From 77f9e8ce0f1ed8b5242dd76f557121eb452fe83c Mon Sep 17 00:00:00 2001
From: Chenhe Gu
Date: Thu, 4 Jan 2024 01:16:51 +0800
Subject: [PATCH] add example api url endpoint in placeholder (#1887)

Co-authored-by: takatost
---
 .../openai_api_compatible/llm/llm.py          | 146 ++++++++++++------
 .../openai_api_compatible.yaml                |   4 +-
 .../text_embedding/text_embedding.py          |  26 +++-
 3 files changed, 119 insertions(+), 57 deletions(-)

diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index bfb6aff164..71c15c7f88 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -1,5 +1,6 @@
 import logging
 from decimal import Decimal
+from urllib.parse import urljoin
 
 import requests
 import json
@@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.utils import helper
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
-    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
+    AssistantPromptMessage, PromptMessageContent, \
+    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
+    ToolPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
+    DefaultParameterName, \
     ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import InvokeError
@@ -70,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :return:
         """
         return self._num_tokens_from_messages(model, prompt_messages, tools)
-    
+
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
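A note on the `urljoin` import this patch introduces: `urljoin` drops the last path segment of a base URL that lacks a trailing slash, which is why every hunk below normalizes `endpoint_url` with a trailing `/` before joining. A minimal sketch of the behavior (the endpoint URL is illustrative):

```python
from urllib.parse import urljoin

# Without a trailing slash, the last segment is treated as a resource and replaced.
print(urljoin('https://api.openai.com/v1', 'chat/completions'))
# -> https://api.openai.com/chat/completions  ('v1' is lost)

# With the trailing slash the patch appends, 'v1' is kept as a path directory.
print(urljoin('https://api.openai.com/v1/', 'chat/completions'))
# -> https://api.openai.com/v1/chat/completions
```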
@@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 headers["Authorization"] = f"Bearer {api_key}"
 
             endpoint_url = credentials['endpoint_url']
+            if not endpoint_url.endswith('/'):
+                endpoint_url += '/'
 
             # prepare the payload for a simple ping to the model
             data = {
@@ -105,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                         "content": "ping"
                     },
                 ]
+                endpoint_url = urljoin(endpoint_url, 'chat/completions')
             elif completion_type is LLMMode.COMPLETION:
                 data['prompt'] = 'ping'
+                endpoint_url = urljoin(endpoint_url, 'completions')
             else:
                 raise ValueError("Unsupported completion type for model configuration.")
-            
+
             # send a post request to validate the credentials
             response = requests.post(
                 endpoint_url,
@@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             )
 
             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if (completion_type is LLMMode.CHAT
+                    and ('object' not in json_result or json_result['object'] != 'chat.completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
+            elif (completion_type is LLMMode.COMPLETION
+                  and ('object' not in json_result or json_result['object'] != 'text_completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'text_completion\'')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
 
@@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
-                ModelPropertyKey.MODE: 'chat'
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MODE: credentials.get('mode'),
             },
             parameter_rules=[
                 ParameterRule(
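The rewritten `validate_credentials` above now pings the joined `chat/completions` or `completions` route and inspects the response body instead of trusting the status code alone. A hedged sketch of the inputs it works against; the key names match what the code reads, while every value is a placeholder:

```python
# Hypothetical credentials for an OpenAI-compatible provider (values illustrative).
credentials = {
    'endpoint_url': 'https://api.openai.com/v1',  # trailing slash is now optional
    'api_key': 'my-api-key',                      # placeholder, not a real key
    'mode': 'chat',                               # or 'completion'
    'context_size': '4096',                       # cast to int by the schema hunk above
}

# Minimal bodies that pass the new 'object' check; anything else now fails fast.
chat_ok = {'object': 'chat.completion'}           # accepted when mode == 'chat'
completion_ok = {'object': 'text_completion'}     # accepted when mode == 'completion'
```

The `object` check catches misconfigured endpoints that return 200 with an HTML page or a different JSON shape, which the old status-code-only check let through.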
@@ -197,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return entity
 
-    # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True, \
+                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
 
@@ -223,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             headers["Authorization"] = f"Bearer {api_key}"
 
         endpoint_url = credentials["endpoint_url"]
-        
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
         data = {
             "model": model,
             "stream": stream,
@@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         completion_type = LLMMode.value_of(credentials['mode'])
 
         if completion_type is LLMMode.CHAT:
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
             data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         elif completion_type == LLMMode.COMPLETION:
+            endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
         else:
             raise ValueError("Unsupported completion type for model configuration.")
@@ -245,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             data["tool_choice"] = "auto"
 
             for tool in tools:
-                formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
-            
+                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
+
             data["tools"] = formatted_tools
 
         if stop:
@@ -254,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         if user:
             data["user"] = user
-        
+
         response = requests.post(
             endpoint_url,
             headers=headers,
@@ -275,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, 
-                                        prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
 
@@ -313,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             if chunk:
                 decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
 
+                chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
                     yield create_final_llm_result_chunk(
-                        index=chunk_index + 1, 
+                        index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered."
                    )
 
-                if len(chunk_json['choices']) == 0:
+                if not chunk_json or len(chunk_json['choices']) == 0:
                     continue
 
-                delta = chunk_json['choices'][0]['delta']
-                chunk_index = chunk_json['choices'][0]['index']
+                choice = chunk_json['choices'][0]
+                chunk_index = choice['index'] if 'index' in choice else chunk_index
 
-                if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    if delta.get('content') is None or delta.get('content') == '':
+                        continue
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        # function_call = self._extract_response_function_call(assistant_message_function_call)
+                        # tool_calls = [function_call] if function_call else []
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta.get('content', ''),
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta.get('content', '')
+                elif 'text' in choice:
+                    if choice.get('text') is None or choice.get('text') == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=choice.get('text', '')
+                    )
+
+                    full_assistant_content += choice.get('text', '')
+                else:
                     continue
-
-                assistant_message_tool_calls = delta.get('tool_calls', None)
-                # assistant_message_function_call = delta.delta.function_call
-
-                # extract tool calls from response
-                if assistant_message_tool_calls:
-                    tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                    # function_call = self._extract_response_function_call(assistant_message_function_call)
-                    # tool_calls = [function_call] if function_call else []
-
-                # transform assistant message to prompt message
-                assistant_prompt_message = AssistantPromptMessage(
-                    content=delta.get('content', ''),
-                    tool_calls=tool_calls if assistant_message_tool_calls else []
-                )
-
-                full_assistant_content += delta.get('content', '')
 
                 # check payload indicator for completion
                 if chunk_json['choices'][0].get('finish_reason') is not None:
-
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
                         finish_reason=chunk_json['choices'][0]['finish_reason']
                     )
-
                 else:
                     yield LLMResultChunk(
                         model=model,
@@ -373,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     message=AssistantPromptMessage(content=""),
                     finish_reason="End of stream."
                 )
-        
-    def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, 
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
-        
+
+            chunk_index += 1
+
+    def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
+                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+
         response_json = response.json()
 
         completion_type = LLMMode.value_of(credentials['mode'])
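The reworked stream handler above no longer assumes a chat-style `delta`; it branches on whether a choice carries `delta` (chat) or `text` (completion). A sketch of the two server-sent chunk shapes it parses (payloads are illustrative):

```python
import json

# A chat-mode SSE line; the handler strips the 'data: ' prefix first.
raw = 'data: {"choices": [{"index": 0, "delta": {"content": "Hi"}}]}'
chunk_json = json.loads(raw.strip().lstrip('data: ').lstrip())

choice = chunk_json['choices'][0]
if 'delta' in choice:        # chat/completions stream
    text = choice['delta'].get('content', '')
elif 'text' in choice:       # completions stream, e.g. {"text": "Hi", "index": 0}
    text = choice.get('text', '')
print(text)  # -> Hi
```

One caveat worth a follow-up: `str.lstrip('data: ')` strips a character set rather than the literal prefix; it works here only because JSON payloads begin with `{`, a character outside that set.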
@@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
+                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                                              in
                                               message.tool_calls]
             # function_call = message.tool_calls[0]
             # message_dict["function_call"] = {
@@ -484,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message_dict["name"] = message.name
 
         return message_dict
-    
+
     def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
 
@@ -507,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         """
         Approximate num tokens with GPT2 tokenizer.
         """
-        
+
        tokens_per_message = 3
        tokens_per_name = 1
-        
+
         num_tokens = 0
         messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
         for message in messages_dict:
@@ -599,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     num_tokens += self._get_num_tokens_by_gpt2(required_field)
 
         return num_tokens
-    
+
     def _extract_response_tool_calls(self, response_tool_calls: list[dict]) \
             -> list[AssistantPromptMessage.ToolCall]:
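The hunks above only re-space `_num_tokens_from_messages`, but its overhead scheme is easy to restate: 3 framing tokens per message plus 1 extra token per `name` field, with every field value tokenized. A compact sketch of that loop; the `len(...split())` tokenizer is a stand-in assumption, the real code uses the shared GPT-2 helper:

```python
# Illustrative restatement of the counting loop above; not the shipped helper.
def approx_num_tokens(messages: list[dict]) -> int:
    tokens_per_message = 3   # framing overhead per message
    tokens_per_name = 1      # surcharge when a 'name' field is present
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(str(value).split())  # stand-in for the GPT-2 tokenizer
            if key == 'name':
                num_tokens += tokens_per_name
    return num_tokens

print(approx_num_tokens([{'role': 'user', 'content': 'ping'}]))  # 3 + 1 + 1 = 5
```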
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
index 5980ad17d4..b2a4af0057 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
@@ -33,8 +33,8 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入您的 API endpoint URL
-        en_US: Enter your API endpoint URL
+        zh_Hans: Base URL, e.g. https://api.openai.com/v1
+        en_US: Base URL, e.g. https://api.openai.com/v1
     - variable: mode
       show_on:
         - variable: __model_type
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
index 8fbbade99d..d59a30e599 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
@@ -1,6 +1,7 @@
 import time
 from decimal import Decimal
 from typing import Optional
+from urllib.parse import urljoin
 
 import requests
 import json
@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
 
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
 
-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = urljoin(endpoint_url, 'embeddings')
 
         extra_model_kwargs = {}
         if user:
@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             if api_key:
                 headers["Authorization"] = f"Bearer {api_key}"
 
+            endpoint_url = credentials.get('endpoint_url')
+            if not endpoint_url.endswith('/'):
+                endpoint_url += '/'
 
-            endpoint_url = credentials['endpoint_url']
+            endpoint_url = urljoin(endpoint_url, 'embeddings')
 
             payload = {
                 'input': 'ping',
@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             )
 
             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
 
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if 'model' not in json_result:
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             model_type=ModelType.TEXT_EMBEDDING,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                 ModelPropertyKey.MAX_CHUNKS: 1,
             },
             parameter_rules=[],
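The embedding changes mirror the LLM side: normalize the trailing slash, join the `embeddings` route, and require a `model` field in the validation body. A minimal end-to-end sketch (URL and model name are illustrative, not defaults the code supplies):

```python
from urllib.parse import urljoin

import requests

endpoint_url = 'https://api.openai.com/v1'  # as entered by the user
if not endpoint_url.endswith('/'):
    endpoint_url += '/'
url = urljoin(endpoint_url, 'embeddings')   # -> https://api.openai.com/v1/embeddings

response = requests.post(
    url,
    headers={'Content-Type': 'application/json', 'Authorization': 'Bearer my-api-key'},
    json={'input': 'ping', 'model': 'text-embedding-ada-002'},  # illustrative model id
    timeout=10,
)
assert response.status_code == 200
assert 'model' in response.json()  # the shape check the patch adds
```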