add example api url endpoint in placeholder (#1887)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Chenhe Gu 2024-01-04 01:16:51 +08:00 committed by GitHub
parent 5ca4c4a44d
commit 77f9e8ce0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 119 additions and 57 deletions

View File

@ -1,5 +1,6 @@
import logging
from decimal import Decimal
from urllib.parse import urljoin
import requests
import json
@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.utils import helper
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
AssistantPromptMessage, PromptMessageContent, \
PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
ToolPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
DefaultParameterName, \
ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError
@ -70,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
:return:
"""
return self._num_tokens_from_messages(model, prompt_messages, tools)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials['endpoint_url']
if not endpoint_url.endswith('/'):
endpoint_url += '/'
# prepare the payload for a simple ping to the model
data = {
@ -105,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"content": "ping"
},
]
endpoint_url = urljoin(endpoint_url, 'chat/completions')
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping'
endpoint_url = urljoin(endpoint_url, 'completions')
else:
raise ValueError("Unsupported completion type for model configuration.")
# send a post request to validate the credentials
response = requests.post(
endpoint_url,
@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
)
if response.status_code != 200:
raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
if (completion_type is LLMMode.CHAT
and ('object' not in json_result or json_result['object'] != 'chat.completion')):
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
elif (completion_type is LLMMode.COMPLETION
and ('object' not in json_result or json_result['object'] != 'text_completion')):
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response object, must be \'text_completion\'')
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
ModelPropertyKey.MODE: 'chat'
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MODE: credentials.get('mode'),
},
parameter_rules=[
ParameterRule(
@ -197,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
@ -223,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials["endpoint_url"]
if not endpoint_url.endswith('/'):
endpoint_url += '/'
data = {
"model": model,
"stream": stream,
@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
completion_type = LLMMode.value_of(credentials['mode'])
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, 'chat/completions')
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
elif completion_type == LLMMode.COMPLETION:
endpoint_url = urljoin(endpoint_url, 'completions')
data['prompt'] = prompt_messages[0].content
else:
raise ValueError("Unsupported completion type for model configuration.")
@ -245,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
data["tool_choice"] = "auto"
for tool in tools:
formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
data["tools"] = formatted_tools
if stop:
@ -254,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if user:
data["user"] = user
response = requests.post(
endpoint_url,
headers=headers,
@ -275,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@ -313,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if chunk:
decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
if len(chunk_json['choices']) == 0:
if not chunk_json or len(chunk_json['choices']) == 0:
continue
delta = chunk_json['choices'][0]['delta']
chunk_index = chunk_json['choices'][0]['index']
choice = chunk_json['choices'][0]
chunk_index = choice['index'] if 'index' in choice else chunk_index
if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
if 'delta' in choice:
delta = choice['delta']
if delta.get('content') is None or delta.get('content') == '':
continue
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
# function_call = self._extract_response_function_call(assistant_message_function_call)
# tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''),
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta.get('content', '')
elif 'text' in choice:
if choice.get('text') is None or choice.get('text') == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=choice.get('text', '')
)
full_assistant_content += choice.get('text', '')
else:
continue
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
# function_call = self._extract_response_function_call(assistant_message_function_call)
# tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''),
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta.get('content', '')
# check payload indicator for completion
if chunk_json['choices'][0].get('finish_reason') is not None:
yield create_final_llm_result_chunk(
index=chunk_index,
message=assistant_prompt_message,
finish_reason=chunk_json['choices'][0]['finish_reason']
)
else:
yield LLMResultChunk(
model=model,
@ -373,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=AssistantPromptMessage(content=""),
finish_reason="End of stream."
)
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> LLMResult:
chunk_index += 1
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> LLMResult:
response_json = response.json()
completion_type = LLMMode.value_of(credentials['mode'])
@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
in
message.tool_calls]
# function_call = message.tool_calls[0]
# message_dict["function_call"] = {
@ -484,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(self, model: str, text: str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
@ -507,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"""
Approximate num tokens with GPT2 tokenizer.
"""
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
@ -599,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
num_tokens += self._get_num_tokens_by_gpt2(required_field)
return num_tokens
def _extract_response_tool_calls(self,
response_tool_calls: list[dict]) \
-> list[AssistantPromptMessage.ToolCall]:

View File

@ -33,8 +33,8 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 API endpoint URL
en_US: Enter your API endpoint URL
zh_Hans: Base URL, eg. https://api.openai.com/v1
en_US: Base URL, eg. https://api.openai.com/v1
- variable: mode
show_on:
- variable: __model_type

View File

@ -1,6 +1,7 @@
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import requests
import json
@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = credentials['endpoint_url']
endpoint_url = urljoin(endpoint_url, 'embeddings')
extra_model_kwargs = {}
if user:
@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = credentials['endpoint_url']
endpoint_url = urljoin(endpoint_url, 'embeddings')
payload = {
'input': 'ping',
@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
)
if response.status_code != 200:
raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
if 'model' not in json_result:
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response')
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[],