From ba181197c2fb6f433eb7dd7c5f031bb64c3fad9b Mon Sep 17 00:00:00 2001 From: themanforfree Date: Thu, 18 Jul 2024 18:58:46 +0800 Subject: [PATCH] feat: api_key support for xinference (#6417) Signed-off-by: themanforfree --- .../model_providers/xinference/llm/llm.py | 4 ++- .../xinference/rerank/rerank.py | 26 ++++++++++----- .../xinference/speech2text/speech2text.py | 33 +++++++++++-------- .../text_embedding/text_embedding.py | 17 +++++----- .../xinference/xinference.yaml | 9 +++++ 5 files changed, 58 insertions(+), 31 deletions(-) diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 0ef63f8e23..988bb0ce44 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -453,9 +453,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] + api_key = credentials.get('api_key') or "abc" + client = OpenAI( base_url=f'{credentials["server_url"]}/v1', - api_key='abc', + api_key=api_key, max_retries=3, timeout=60, ) diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 17b85862c9..649898f47a 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -44,15 +44,23 @@ class XinferenceRerankModel(RerankModel): docs=[] ) - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + api_key = credentials.get('api_key') + if server_url.endswith('/'): + server_url = server_url[:-1] + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + + try: + handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers) + response = handle.rerank( + documents=docs, + query=query, + top_n=top_n, + ) + except RuntimeError as e: + raise InvokeServerUnavailableError(str(e)) - handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={}) - response = handle.rerank( - documents=docs, - query=query, - top_n=top_n, - ) rerank_documents = [] for idx, result in enumerate(response['results']): @@ -102,7 +110,7 @@ class XinferenceRerankModel(RerankModel): if not isinstance(xinference_client, RESTfulRerankModelHandle): raise InvokeBadRequestError( 'please check model type, the model you want to invoke is not a rerank model') - + self.invoke( model=model, credentials=credentials, diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index f60d8d3443..9ee3621317 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -99,9 +99,9 @@ class XinferenceSpeech2TextModel(Speech2TextModel): } def _speech2text_invoke( - self, - model: str, - credentials: dict, + self, + model: str, + credentials: dict, file: IO[bytes], language: Optional[str] = None, prompt: Optional[str] = None, @@ -121,17 +121,24 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit. :return: text for given audio file """ - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + api_key = credentials.get('api_key') + if server_url.endswith('/'): + server_url = server_url[:-1] + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} - handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={}) - response = handle.transcriptions( - audio=file, - language = language, - prompt = prompt, - response_format = response_format, - temperature = temperature - ) + try: + handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers) + response = handle.transcriptions( + audio=file, + language=language, + prompt=prompt, + response_format=response_format, + temperature=temperature + ) + except RuntimeError as e: + raise InvokeServerUnavailableError(str(e)) return response["text"] diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index e8429cecd4..11f1e29cb3 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -43,16 +43,17 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ server_url = credentials['server_url'] model_uid = credentials['model_uid'] - + api_key = credentials.get('api_key') if server_url.endswith('/'): server_url = server_url[:-1] + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} try: - handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={}) + handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers) embeddings = handle.create_embedding(input=texts) except RuntimeError as e: - raise InvokeServerUnavailableError(e) - + raise InvokeServerUnavailableError(str(e)) + """ for convenience, the response json is like: class Embedding(TypedDict): @@ -106,7 +107,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): try: if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - + server_url = credentials['server_url'] model_uid = credentials['model_uid'] extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) @@ -117,7 +118,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): server_url = server_url[:-1] client = Client(base_url=server_url) - + try: handle = client.get_model(model_uid=model_uid) except RuntimeError as e: @@ -151,7 +152,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): KeyError ] } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -186,7 +187,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ used to define customizable model schema """ - + entity = AIModelEntity( model=model, label=I18nObject( diff --git a/api/core/model_runtime/model_providers/xinference/xinference.yaml b/api/core/model_runtime/model_providers/xinference/xinference.yaml index 28ffc0389e..9496c66fdd 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference.yaml +++ b/api/core/model_runtime/model_providers/xinference/xinference.yaml @@ -46,3 +46,12 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的Model UID en_US: Enter the model uid + - variable: api_key + label: + zh_Hans: API密钥 + en_US: API key + type: text-input + required: false + placeholder: + zh_Hans: 在此输入您的API密钥 + en_US: Enter the api key