mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
feat: api_key support for xinference (#6417)
Signed-off-by: themanforfree <themanforfree@gmail.com>
This commit is contained in:
parent
218930c897
commit
ba181197c2
|
@ -453,9 +453,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||||
if credentials['server_url'].endswith('/'):
|
if credentials['server_url'].endswith('/'):
|
||||||
credentials['server_url'] = credentials['server_url'][:-1]
|
credentials['server_url'] = credentials['server_url'][:-1]
|
||||||
|
|
||||||
|
api_key = credentials.get('api_key') or "abc"
|
||||||
|
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
base_url=f'{credentials["server_url"]}/v1',
|
base_url=f'{credentials["server_url"]}/v1',
|
||||||
api_key='abc',
|
api_key=api_key,
|
||||||
max_retries=3,
|
max_retries=3,
|
||||||
timeout=60,
|
timeout=60,
|
||||||
)
|
)
|
||||||
|
|
|
@ -44,15 +44,23 @@ class XinferenceRerankModel(RerankModel):
|
||||||
docs=[]
|
docs=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
if credentials['server_url'].endswith('/'):
|
server_url = credentials['server_url']
|
||||||
credentials['server_url'] = credentials['server_url'][:-1]
|
model_uid = credentials['model_uid']
|
||||||
|
api_key = credentials.get('api_key')
|
||||||
|
if server_url.endswith('/'):
|
||||||
|
server_url = server_url[:-1]
|
||||||
|
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||||
|
|
||||||
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
|
try:
|
||||||
|
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
|
||||||
response = handle.rerank(
|
response = handle.rerank(
|
||||||
documents=docs,
|
documents=docs,
|
||||||
query=query,
|
query=query,
|
||||||
top_n=top_n,
|
top_n=top_n,
|
||||||
)
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise InvokeServerUnavailableError(str(e))
|
||||||
|
|
||||||
|
|
||||||
rerank_documents = []
|
rerank_documents = []
|
||||||
for idx, result in enumerate(response['results']):
|
for idx, result in enumerate(response['results']):
|
||||||
|
|
|
@ -121,17 +121,24 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||||
:param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
|
:param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
|
||||||
:return: text for given audio file
|
:return: text for given audio file
|
||||||
"""
|
"""
|
||||||
if credentials['server_url'].endswith('/'):
|
server_url = credentials['server_url']
|
||||||
credentials['server_url'] = credentials['server_url'][:-1]
|
model_uid = credentials['model_uid']
|
||||||
|
api_key = credentials.get('api_key')
|
||||||
|
if server_url.endswith('/'):
|
||||||
|
server_url = server_url[:-1]
|
||||||
|
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||||
|
|
||||||
handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
|
try:
|
||||||
|
handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers)
|
||||||
response = handle.transcriptions(
|
response = handle.transcriptions(
|
||||||
audio=file,
|
audio=file,
|
||||||
language = language,
|
language=language,
|
||||||
prompt = prompt,
|
prompt=prompt,
|
||||||
response_format = response_format,
|
response_format=response_format,
|
||||||
temperature = temperature
|
temperature=temperature
|
||||||
)
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise InvokeServerUnavailableError(str(e))
|
||||||
|
|
||||||
return response["text"]
|
return response["text"]
|
||||||
|
|
||||||
|
|
|
@ -43,15 +43,16 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||||
"""
|
"""
|
||||||
server_url = credentials['server_url']
|
server_url = credentials['server_url']
|
||||||
model_uid = credentials['model_uid']
|
model_uid = credentials['model_uid']
|
||||||
|
api_key = credentials.get('api_key')
|
||||||
if server_url.endswith('/'):
|
if server_url.endswith('/'):
|
||||||
server_url = server_url[:-1]
|
server_url = server_url[:-1]
|
||||||
|
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
|
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
|
||||||
embeddings = handle.create_embedding(input=texts)
|
embeddings = handle.create_embedding(input=texts)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise InvokeServerUnavailableError(e)
|
raise InvokeServerUnavailableError(str(e))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
for convenience, the response json is like:
|
for convenience, the response json is like:
|
||||||
|
|
|
@ -46,3 +46,12 @@ model_credential_schema:
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入您的Model UID
|
zh_Hans: 在此输入您的Model UID
|
||||||
en_US: Enter the model uid
|
en_US: Enter the model uid
|
||||||
|
- variable: api_key
|
||||||
|
label:
|
||||||
|
zh_Hans: API密钥
|
||||||
|
en_US: API key
|
||||||
|
type: text-input
|
||||||
|
required: false
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的API密钥
|
||||||
|
en_US: Enter the api key
|
||||||
|
|
Loading…
Reference in New Issue
Block a user