fix(api/model_runtime/azure/llm): Switch to tool_call. (#5541)

-LAN- authored on 2024-06-24 15:35:21 +08:00, committed by GitHub
parent 41ceb6a4eb
commit ba67206bb9
2 changed files with 137 additions and 142 deletions
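
The change drops the deprecated Chat Completions functions/function_call parameters in favor of the tools/tool_calls interface. A minimal sketch of the two request shapes, using an invented get_weather tool that is not part of this commit:

# Legacy shape (deprecated): a flat function list; the reply carries a single
# message.function_call.
legacy_kwargs = {
    "functions": [{
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }],
}

# Current shape: each function sits inside a typed tool entry; the reply carries
# a list of message.tool_calls, each with an id that the follow-up
# {"role": "tool", "tool_call_id": ...} message must echo back.
current_kwargs = {
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the weather for a city.",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }],
}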

View File

@@ -1,14 +1,13 @@
import copy
import logging
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
from openai.types import Completion
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageFunction,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
@@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPrope
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
@@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
base_model_name = credentials.get('base_model_name')
if not base_model_name:
raise ValueError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
return self._chat_generate(
model=model,
@@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
user=user
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
ModelPropertyKey.MODE)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
) -> int:
base_model_name = credentials.get('base_model_name')
if not base_model_name:
raise ValueError('Base Model Name is required')
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity:
raise ValueError(f'Base Model Name {base_model_name} is invalid')
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
# chat model
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
else:
# text completion model; tool calling is not supported
return self._num_tokens_from_string(credentials, prompt_messages[0].content)
content = prompt_messages[0].content
assert isinstance(content, str)
return self._num_tokens_from_string(credentials, content)
def validate_credentials(self, model: str, credentials: dict) -> None:
if 'openai_api_base' not in credentials:
@@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
base_model_name = credentials.get('base_model_name')
if not base_model_name:
raise CredentialsValidateFailedError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
base_model_name = credentials.get('base_model_name')
if not base_model_name:
raise ValueError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None
def _generate(self, model: str, credentials: dict,
@@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(
self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]
):
assistant_text = response.choices[0].text
# transform assistant message to prompt message
@@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
# transform usage
@@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]
) -> Generator:
full_text = ''
for chunk in response:
if len(chunk.choices) == 0:
@@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
completion_tokens = chunk.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, full_text)
# transform usage
@@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
extra_model_kwargs = {}
if tools:
# extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
extra_model_kwargs['functions'] = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
# extra_model_kwargs['functions'] = [{
# "name": tool.name,
# "description": tool.description,
# "parameters": tool.parameters
# } for tool in tools]
if stop:
extra_model_kwargs['stop'] = stop
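
Each PromptMessageTool is now wrapped in a PromptMessageFunction and serialized with helper.dump_model, assumed here to be a thin wrapper over pydantic serialization. A sketch with stand-in models (the class names below are illustrative, not the runtime's real entities) of the payload this is expected to produce:

from pydantic import BaseModel

class StandInTool(BaseModel):  # stand-in for PromptMessageTool
    name: str
    description: str
    parameters: dict

class StandInFunctionWrapper(BaseModel):  # stand-in for PromptMessageFunction
    type: str = "function"
    function: StandInTool

payload = StandInFunctionWrapper(function=StandInTool(
    name="get_weather",
    description="Look up the weather for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}},
)).model_dump()
# payload == {"type": "function", "function": {"name": "get_weather", ...}},
# which matches the tools entry shape the Chat Completions API expects.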
@@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
extra_model_kwargs['user'] = user
# chat model
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
messages=messages,
model=model,
stream=stream,
**model_parameters,
@@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
def _handle_chat_generate_response(
self, model: str, credentials: dict, response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
):
assistant_message = response.choices[0].message
# assistant_message_tool_calls = assistant_message.tool_calls
assistant_message_function_call = assistant_message.function_call
assistant_message_tool_calls = assistant_message.tool_calls
# extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
tool_calls = [function_call] if function_call else []
tool_calls = []
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
@@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
response = LLMResult(
result = LLMResult(
model=response.model or model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
@@ -325,58 +354,34 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
system_fingerprint=response.system_fingerprint,
)
return response
return result
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> Generator:
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
):
index = 0
full_assistant_content = ''
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
real_model = model
system_fingerprint = None
completion = ''
tool_calls = []
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
# extract tool calls from response
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
# Skip chunks with an empty delta, e.g. when the content filter's streaming mode is set to "asynchronous modified filter"
if delta.delta is None or (
delta.finish_reason is None
and (delta.delta.content is None or delta.delta.content == '')
and delta.delta.function_call is None
):
if delta.finish_reason is None and not delta.delta.content:
continue
# assistant_message_tool_calls = delta.delta.tool_calls
assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if delta_assistant_message_function_call_storage is not None:
# handle process of stream function call
if assistant_message_function_call:
# message has not ended ever
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
continue
else:
# message has ended
assistant_message_function_call = delta_assistant_message_function_call_storage
delta_assistant_message_function_call_storage = None
else:
if assistant_message_function_call:
# start of stream function call
delta_assistant_message_function_call_storage = assistant_message_function_call
if delta_assistant_message_function_call_storage.arguments is None:
delta_assistant_message_function_call_storage.arguments = ''
continue
# extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
@@ -426,54 +431,56 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
)
@staticmethod
def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-> list[AssistantPromptMessage.ToolCall]:
def _update_tool_calls(
        tool_calls: list[AssistantPromptMessage.ToolCall],
        tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
) -> None:
if tool_calls_response:
for response_tool_call in tool_calls_response:
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name,
arguments=response_tool_call.function.arguments
)
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name,
arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id,
type=response_tool_call.type,
function=function
)
tool_calls.append(tool_call)
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
index = response_tool_call.index
if index < len(tool_calls):
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
if response_tool_call.function:
tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
else:
assert response_tool_call.id is not None
assert response_tool_call.type is not None
assert response_tool_call.function is not None
assert response_tool_call.function.name is not None
assert response_tool_call.function.arguments is not None
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id,
type=response_tool_call.type,
function=function
)
tool_calls.append(tool_call)
return tool_calls
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name,
arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id,
type=response_tool_call.type,
function=function
)
tool_calls.append(tool_call)
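
In streaming mode each ChoiceDeltaToolCall carries an index plus a fragment of the arguments string, and _update_tool_calls grows tool_calls[index] in place as chunks arrive. A minimal sketch of that accumulation with hand-built deltas (the import path for the class is an assumption based on the package layout visible in the imports above):

from openai.types.chat.chat_completion_chunk import (
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)

from core.model_runtime.model_providers.azure_openai.llm.llm import AzureOpenAILargeLanguageModel

# Hand-built deltas imitating three consecutive stream chunks for one tool call:
# the first names the function, the rest append argument fragments.
deltas = [
    ChoiceDeltaToolCall(index=0, id="call_1", type="function",
                        function=ChoiceDeltaToolCallFunction(name="get_weather", arguments="")),
    ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"city": ')),
    ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris"}')),
]

tool_calls = []
for delta in deltas:
    AzureOpenAILargeLanguageModel._update_tool_calls(tool_calls=tool_calls, tool_calls_response=[delta])

assert tool_calls[0].function.arguments == '{"city": "Paris"}'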
@staticmethod
def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-> AssistantPromptMessage.ToolCall:
tool_call = None
if response_function_call:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.name,
arguments=response_function_call.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.name,
type="function",
function=function
)
return tool_call
@staticmethod
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
def _convert_prompt_message_to_dict(message: PromptMessage):
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
assert message.content is not None
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
@@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
# message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
# message.tool_calls]
function_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
}
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
# message_dict = {
# "role": "tool",
# "content": message.content,
# "tool_call_id": message.tool_call_id
# }
message_dict = {
"role": "function",
"role": "tool",
"name": message.name,
"content": message.content,
"name": message.tool_call_id
"tool_call_id": message.tool_call_id
}
else:
raise ValueError(f"Got unknown type {message}")
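
After the switch, an assistant turn with tool calls and its tool reply serialize as below (field values invented for illustration); the tool message's tool_call_id must echo the id of the assistant entry it answers:

assistant_message_dict = {
    "role": "assistant",
    "content": None,
    "tool_calls": [{
        "id": "call_1",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
    }],
}
tool_message_dict = {
    "role": "tool",
    "name": "get_weather",
    "content": '{"temperature_c": 21}',
    "tool_call_id": "call_1",  # matches assistant_message_dict["tool_calls"][0]["id"]
}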
@@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return num_tokens
def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_messages(
self, credentials: dict, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
@@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if key == "tool_calls":
for tool_call in value:
assert isinstance(tool_call, dict)
for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key))
if t_key == "function":
@@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode(parameters['title']))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
num_tokens += len(encoding.encode(parameters['type']))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
for key, value in parameters['properties'].items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
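
Tool schemas count against the prompt, so the counter walks every key and value of each tool's JSON schema and encodes it with tiktoken. The underlying primitive, assuming the cl100k_base encoding used by the gpt-3.5-turbo and gpt-4 base models (the runtime presumably resolves the encoding from the credentials):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
# Every schema fragment is priced the same way a plain string is:
num_tokens = len(encoding.encode("Look up the weather for a city."))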
@@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return num_tokens
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
def _get_ai_model_entity(base_model_name: str, model: str):
for ai_model_entity in LLM_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None

View File

@@ -73,17 +73,15 @@ class MockChatClass:
return FunctionCall(name=function_name, arguments=dumps(parameters))
@staticmethod
def generate_tool_calls(
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> Optional[list[ChatCompletionMessageToolCall]]:
def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
list_tool_calls = []
if not tools or len(tools) == 0:
return None
tool: ChatCompletionToolParam = tools[0]
tool = tools[0]
if tools['type'] != 'function':
if 'type' in tool and tool['type'] != 'function':
return None
function = tool['function']
function_call = MockChatClass.generate_function_call(functions=[function])
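
For reference, a sketch of how the mock could assemble one tool call from the generated function call; the import path matches the openai v1 SDK, while the id value is invented:

from json import dumps
from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)

tool_call = ChatCompletionMessageToolCall(
    id="mock_tool_call_id",  # invented; any unique string works for the mock
    type="function",
    function=Function(name="get_weather", arguments=dumps({"city": "Paris"})),
)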