mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
fix(api/model_runtime/azure/llm): Switch to tool_call. (#5541)
This commit is contained in:
parent
41ceb6a4eb
commit
ba67206bb9
|
@ -1,14 +1,13 @@
|
|||
import copy
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI, Stream
|
||||
from openai.types import Completion
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
|
@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageFunction,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
|
@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPrope
|
|||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
|
||||
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
|
||||
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||
base_model_name = credentials.get('base_model_name')
|
||||
if not base_model_name:
|
||||
raise ValueError('Base Model Name is required')
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
return self._chat_generate(
|
||||
model=model,
|
||||
|
@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
user=user
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
|
||||
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
|
||||
ModelPropertyKey.MODE)
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
base_model_name = credentials.get('base_model_name')
|
||||
if not base_model_name:
|
||||
raise ValueError('Base Model Name is required')
|
||||
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
if not model_entity:
|
||||
raise ValueError(f'Base Model Name {base_model_name} is invalid')
|
||||
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
|
||||
|
||||
if model_mode == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
else:
|
||||
# text completion model, do not support tool calling
|
||||
return self._num_tokens_from_string(credentials, prompt_messages[0].content)
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
return self._num_tokens_from_string(credentials,content)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if 'openai_api_base' not in credentials:
|
||||
|
@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
if 'base_model_name' not in credentials:
|
||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||
|
||||
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||
base_model_name = credentials.get('base_model_name')
|
||||
if not base_model_name:
|
||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if not ai_model_entity:
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||
base_model_name = credentials.get('base_model_name')
|
||||
if not base_model_name:
|
||||
raise ValueError('Base Model Name is required')
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
return ai_model_entity.entity if ai_model_entity else None
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
|
@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: Completion,
|
||||
prompt_messages: list[PromptMessage]
|
||||
):
|
||||
assistant_text = response.choices[0].text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
|
@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
completion_tokens = response.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, content)
|
||||
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
|
||||
|
||||
# transform usage
|
||||
|
@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: Stream[Completion],
|
||||
prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
full_text = ''
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
|
@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
completion_tokens = chunk.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, content)
|
||||
completion_tokens = self._num_tokens_from_string(credentials, full_text)
|
||||
|
||||
# transform usage
|
||||
|
@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
extra_model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
# extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||
extra_model_kwargs['functions'] = [{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters
|
||||
} for tool in tools]
|
||||
extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||
# extra_model_kwargs['functions'] = [{
|
||||
# "name": tool.name,
|
||||
# "description": tool.description,
|
||||
# "parameters": tool.parameters
|
||||
# } for tool in tools]
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs['stop'] = stop
|
||||
|
@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
extra_model_kwargs['user'] = user
|
||||
|
||||
# chat model
|
||||
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
response = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
messages=messages,
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
|
@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self, model: str, credentials: dict, response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None
|
||||
):
|
||||
assistant_message = response.choices[0].message
|
||||
# assistant_message_tool_calls = assistant_message.tool_calls
|
||||
assistant_message_function_call = assistant_message.function_call
|
||||
assistant_message_tool_calls = assistant_message.tool_calls
|
||||
|
||||
# extract tool calls from response
|
||||
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
function_call = self._extract_response_function_call(assistant_message_function_call)
|
||||
tool_calls = [function_call] if function_call else []
|
||||
tool_calls = []
|
||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
|
@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
result = LLMResult(
|
||||
model=response.model or model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
|
@ -325,58 +354,34 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
system_fingerprint=response.system_fingerprint,
|
||||
)
|
||||
|
||||
return response
|
||||
return result
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> Generator:
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None
|
||||
):
|
||||
index = 0
|
||||
full_assistant_content = ''
|
||||
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
|
||||
real_model = model
|
||||
system_fingerprint = None
|
||||
completion = ''
|
||||
tool_calls = []
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
# extract tool calls from response
|
||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
|
||||
|
||||
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
|
||||
if delta.delta is None or (
|
||||
delta.finish_reason is None
|
||||
and (delta.delta.content is None or delta.delta.content == '')
|
||||
and delta.delta.function_call is None
|
||||
):
|
||||
if delta.finish_reason is None and not delta.delta.content:
|
||||
continue
|
||||
|
||||
# assistant_message_tool_calls = delta.delta.tool_calls
|
||||
assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if delta_assistant_message_function_call_storage is not None:
|
||||
# handle process of stream function call
|
||||
if assistant_message_function_call:
|
||||
# message has not ended ever
|
||||
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
|
||||
continue
|
||||
else:
|
||||
# message has ended
|
||||
assistant_message_function_call = delta_assistant_message_function_call_storage
|
||||
delta_assistant_message_function_call_storage = None
|
||||
else:
|
||||
if assistant_message_function_call:
|
||||
# start of stream function call
|
||||
delta_assistant_message_function_call_storage = assistant_message_function_call
|
||||
if delta_assistant_message_function_call_storage.arguments is None:
|
||||
delta_assistant_message_function_call_storage.arguments = ''
|
||||
continue
|
||||
|
||||
# extract tool calls from response
|
||||
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
function_call = self._extract_response_function_call(assistant_message_function_call)
|
||||
tool_calls = [function_call] if function_call else []
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
|
@ -426,54 +431,56 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
|
||||
-> list[AssistantPromptMessage.ToolCall]:
|
||||
def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
|
||||
if tool_calls_response:
|
||||
for response_tool_call in tool_calls_response:
|
||||
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name,
|
||||
arguments=response_tool_call.function.arguments
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name,
|
||||
arguments=response_tool_call.function.arguments
|
||||
)
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id,
|
||||
type=response_tool_call.type,
|
||||
function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
|
||||
index = response_tool_call.index
|
||||
if index < len(tool_calls):
|
||||
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
|
||||
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
|
||||
if response_tool_call.function:
|
||||
tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
|
||||
tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
|
||||
else:
|
||||
assert response_tool_call.id is not None
|
||||
assert response_tool_call.type is not None
|
||||
assert response_tool_call.function is not None
|
||||
assert response_tool_call.function.name is not None
|
||||
assert response_tool_call.function.arguments is not None
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id,
|
||||
type=response_tool_call.type,
|
||||
function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name,
|
||||
arguments=response_tool_call.function.arguments
|
||||
)
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id,
|
||||
type=response_tool_call.type,
|
||||
function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
|
||||
-> AssistantPromptMessage.ToolCall:
|
||||
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function_call.name,
|
||||
arguments=response_function_call.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_function_call.name,
|
||||
type="function",
|
||||
function=function
|
||||
)
|
||||
|
||||
return tool_call
|
||||
|
||||
@staticmethod
|
||||
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
|
||||
|
||||
def _convert_prompt_message_to_dict(message: PromptMessage):
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
assert message.content is not None
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
|
@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
}
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
# message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
|
||||
# message.tool_calls]
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
}
|
||||
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
# message_dict = {
|
||||
# "role": "tool",
|
||||
# "content": message.content,
|
||||
# "tool_call_id": message.tool_call_id
|
||||
# }
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"role": "tool",
|
||||
"name": message.name,
|
||||
"content": message.content,
|
||||
"name": message.tool_call_id
|
||||
"tool_call_id": message.tool_call_id
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def _num_tokens_from_messages(
|
||||
self, credentials: dict, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
|
@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
assert isinstance(tool_call, dict)
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
if t_key == "function":
|
||||
|
@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode(parameters['title']))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
num_tokens += len(encoding.encode(parameters['type']))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
for key, value in parameters['properties'].items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
|
@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||
def _get_ai_model_entity(base_model_name: str, model: str):
|
||||
for ai_model_entity in LLM_BASE_MODELS:
|
||||
if ai_model_entity.base_model_name == base_model_name:
|
||||
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
|
||||
|
@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
ai_model_entity_copy.entity.label.en_US = model
|
||||
ai_model_entity_copy.entity.label.zh_Hans = model
|
||||
return ai_model_entity_copy
|
||||
|
||||
return None
|
||||
|
|
|
@ -73,17 +73,15 @@ class MockChatClass:
|
|||
return FunctionCall(name=function_name, arguments=dumps(parameters))
|
||||
|
||||
@staticmethod
|
||||
def generate_tool_calls(
|
||||
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
) -> Optional[list[ChatCompletionMessageToolCall]]:
|
||||
def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
|
||||
list_tool_calls = []
|
||||
if not tools or len(tools) == 0:
|
||||
return None
|
||||
tool: ChatCompletionToolParam = tools[0]
|
||||
tool = tools[0]
|
||||
|
||||
if tools['type'] != 'function':
|
||||
if 'type' in tools and tools['type'] != 'function':
|
||||
return None
|
||||
|
||||
|
||||
function = tool['function']
|
||||
|
||||
function_call = MockChatClass.generate_function_call(functions=[function])
|
||||
|
|
Loading…
Reference in New Issue
Block a user