feat: api_key support for xinference (#6417)

Signed-off-by: themanforfree <themanforfree@gmail.com>
This commit is contained in:
themanforfree 2024-07-18 18:58:46 +08:00 committed by GitHub
parent 218930c897
commit ba181197c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 58 additions and 31 deletions

View File

@@ -453,9 +453,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]

+        api_key = credentials.get('api_key') or "abc"
+
         client = OpenAI(
             base_url=f'{credentials["server_url"]}/v1',
-            api_key='abc',
+            api_key=api_key,
             max_retries=3,
             timeout=60,
         )

View File

@@ -44,15 +44,23 @@ class XinferenceRerankModel(RerankModel):
                 docs=[]
             )

-        if credentials['server_url'].endswith('/'):
-            credentials['server_url'] = credentials['server_url'][:-1]
+        server_url = credentials['server_url']
+        model_uid = credentials['model_uid']
+        api_key = credentials.get('api_key')
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]

-        handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
-        response = handle.rerank(
-            documents=docs,
-            query=query,
-            top_n=top_n,
-        )
+        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
+        try:
+            handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
+            response = handle.rerank(
+                documents=docs,
+                query=query,
+                top_n=top_n,
+            )
+        except RuntimeError as e:
+            raise InvokeServerUnavailableError(str(e))

         rerank_documents = []
         for idx, result in enumerate(response['results']):

View File

@@ -121,17 +121,24 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
        :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
        :return: text for given audio file
        """
-        if credentials['server_url'].endswith('/'):
-            credentials['server_url'] = credentials['server_url'][:-1]
+        server_url = credentials['server_url']
+        model_uid = credentials['model_uid']
+        api_key = credentials.get('api_key')
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]

-        handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
-        response = handle.transcriptions(
-            audio=file,
-            language = language,
-            prompt = prompt,
-            response_format = response_format,
-            temperature = temperature
-        )
+        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
+        try:
+            handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers)
+            response = handle.transcriptions(
+                audio=file,
+                language=language,
+                prompt=prompt,
+                response_format=response_format,
+                temperature=temperature
+            )
+        except RuntimeError as e:
+            raise InvokeServerUnavailableError(str(e))

         return response["text"]

View File

@@ -43,15 +43,16 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
        """
        server_url = credentials['server_url']
        model_uid = credentials['model_uid']
+        api_key = credentials.get('api_key')
        if server_url.endswith('/'):
            server_url = server_url[:-1]

+        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
        try:
-            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
+            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
            embeddings = handle.create_embedding(input=texts)
        except RuntimeError as e:
-            raise InvokeServerUnavailableError(e)
+            raise InvokeServerUnavailableError(str(e))

        """
        for convenience, the response json is like:

View File

@@ -46,3 +46,12 @@ model_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的Model UID
         en_US: Enter the model uid
+    - variable: api_key
+      label:
+        zh_Hans: API密钥
+        en_US: API key
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的API密钥
+        en_US: Enter the api key