Feat/blocking function call (#2247)

Yeuoly 2024-01-30 15:25:37 +08:00 committed by GitHub
parent 1ea18a2922
commit 6d5b386394
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 429 additions and 94 deletions
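
This change teaches the agent runners to fall back to a blocking (non-streaming) LLM invocation when a model can expose tool calls but cannot stream them: the base assistant runner records a stream_tool_call flag from the model schema, the function-call runner invokes the model with stream=stream_tool_call, and a blocking LLMResult is converted back into a single LLMResultChunk so the rest of the pipeline stays unchanged. A minimal sketch of that dispatch, using simplified stand-in classes rather than the real Dify entities:

    # Sketch only: illustrates the stream/blocking tool-call dispatch this commit adds.
    # The class and field names below are simplified stand-ins, not the real Dify entities.
    from dataclasses import dataclass, field
    from typing import Generator, List, Union

    @dataclass
    class LLMResult:              # blocking response: one complete message
        content: str
        tool_calls: List[dict] = field(default_factory=list)

    @dataclass
    class LLMResultChunk:         # streaming response item
        delta: str
        tool_calls: List[dict] = field(default_factory=list)

    def run(model, supports_stream_tool_call: bool) -> Generator[LLMResultChunk, None, None]:
        # invoke in streaming mode only when the model declares stream-tool-call support
        result: Union[Generator[LLMResultChunk, None, None], LLMResult] = model.invoke(
            stream=supports_stream_tool_call
        )
        if supports_stream_tool_call:
            yield from result                     # chunks pass through unchanged
        else:
            # blocking mode: wrap the single LLMResult into one synthetic chunk
            yield LLMResultChunk(delta=result.content, tool_calls=result.tool_calls)

The real runner also accumulates tool calls and token usage from either path; the sketch only shows the branching.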

View File

@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner):
             memory=memory,
         )
+
+        # change function call strategy based on LLM model
+        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
+            agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
+
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
             assistant_cot_runner = AssistantCotApplicationRunner(
@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner):
                 prompt_messages=prompt_message,
                 variables_pool=tool_variables,
                 db_variables=tool_conversation_variables,
+                model_instance=model_instance
             )
             invoke_result = assistant_cot_runner.run(
-                model_instance=model_instance,
                 conversation=conversation,
                 message=message,
                 query=query,
@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner):
                 memory=memory,
                 prompt_messages=prompt_message,
                 variables_pool=tool_variables,
-                db_variables=tool_conversation_variables
+                db_variables=tool_conversation_variables,
+                model_instance=model_instance
             )
             invoke_result = assistant_fc_runner.run(
-                model_instance=model_instance,
                 conversation=conversation,
                 message=message,
                 query=query,

View File

@@ -1,7 +1,7 @@
 import logging
 import json
-from typing import Optional, List, Tuple, Union
+from typing import Optional, List, Tuple, Union, cast
 from datetime import datetime
 from mimetypes import guess_extension
@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
     AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_manager import ModelInstance
 from core.file.message_file_parser import FileTransferMethod

 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  prompt_messages: Optional[List[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
+                 model_instance: ModelInstance = None
                  ) -> None:
         """
         Agent runner
@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.history_prompt_messages = prompt_messages
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
+        self.model_instance = model_instance

         # init callback
         self.agent_callback = DifyAgentCallbackHandler()
@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
             MessageAgentThought.message_id == self.message.id,
         ).count()
+
+        # check if model supports stream tool call
+        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
+            self.stream_tool_call = True
+        else:
+            self.stream_tool_call = False

     def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
         """
         Repacket app orchestration config

View File

@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
 from models.model import Conversation, Message

 class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance,
-            conversation: Conversation,
+    def run(self, conversation: Conversation,
             message: Message,
             query: str,
             ) -> Union[Generator, LLMResult]:
@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             llm_usage.prompt_price += usage.prompt_price
             llm_usage.completion_price += usage.completion_price

+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             # continue to run until there is not any tool call
             function_call_state = False
@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         # remove Action: xxx from agent thought
         agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

-        if action_name and action_input:
+        if action_name and action_input is not None:
             return AgentScratchpadUnit(
                 agent_response=content,
                 thought=agent_thought,

View File

@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
 from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
     SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
-from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
 from core.model_manager import ModelInstance
 from core.application_queue_manager import PublishFrom
@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
 logger = logging.getLogger(__name__)

 class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance,
-            conversation: Conversation,
+    def run(self, conversation: Conversation,
             message: Message,
             query: str,
             ) -> Generator[LLMResultChunk, None, None]:
@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             llm_usage.prompt_price += usage.prompt_price
             llm_usage.completion_price += usage.completion_price

+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             function_call_state = False
@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             # recale llm max tokens
             self.recale_llm_max_tokens(self.model_config, prompt_messages)

             # invoke model
-            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
+            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_orchestration_config.model_config.parameters,
                 tools=prompt_messages_tools,
                 stop=app_orchestration_config.model_config.stop,
-                stream=True,
+                stream=self.stream_tool_call,
                 user=self.user_id,
                 callbacks=[],
             )
@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             current_llm_usage = None

-            for chunk in chunks:
-                # check if there is any tool call
-                if self.check_tool_calls(chunk):
-                    function_call_state = True
-                    tool_calls.extend(self.extract_tool_calls(chunk))
-                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
-                    try:
-                        tool_call_inputs = json.dumps({
+            if self.stream_tool_call:
+                for chunk in chunks:
+                    # check if there is any tool call
+                    if self.check_tool_calls(chunk):
+                        function_call_state = True
+                        tool_calls.extend(self.extract_tool_calls(chunk))
+                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        try:
+                            tool_call_inputs = json.dumps({
+                                tool_call[1]: tool_call[2] for tool_call in tool_calls
+                            }, ensure_ascii=False)
+                        except json.JSONDecodeError as e:
+                            # ensure ascii to avoid encoding error
+                            tool_call_inputs = json.dumps({
+                                tool_call[1]: tool_call[2] for tool_call in tool_calls
+                            })
+
+                    if chunk.delta.message and chunk.delta.message.content:
+                        if isinstance(chunk.delta.message.content, list):
+                            for content in chunk.delta.message.content:
+                                response += content.data
+                        else:
+                            response += chunk.delta.message.content
+
+                    if chunk.delta.usage:
+                        increase_usage(llm_usage, chunk.delta.usage)
+                        current_llm_usage = chunk.delta.usage
+
+                    yield chunk
+            else:
+                result: LLMResult = chunks
+                # check if there is any tool call
+                if self.check_blocking_tool_calls(result):
+                    function_call_state = True
+                    tool_calls.extend(self.extract_blocking_tool_calls(result))
+                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                    try:
+                        tool_call_inputs = json.dumps({
@@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                             tool_call[1]: tool_call[2] for tool_call in tool_calls
                         })

-                if chunk.delta.message and chunk.delta.message.content:
-                    if isinstance(chunk.delta.message.content, list):
-                        for content in chunk.delta.message.content:
-                            response += content.data
-                    else:
-                        response += chunk.delta.message.content
-
-                if chunk.delta.usage:
-                    increase_usage(llm_usage, chunk.delta.usage)
-                    current_llm_usage = chunk.delta.usage
-
-                yield chunk
+                if result.usage:
+                    increase_usage(llm_usage, result.usage)
+                    current_llm_usage = result.usage
+
+                if result.message and result.message.content:
+                    if isinstance(result.message.content, list):
+                        for content in result.message.content:
+                            response += content.data
+                    else:
+                        response += result.message.content
+
+                if not result.message.content:
+                    result.message.content = ''
+
+                yield LLMResultChunk(
+                    model=model_instance.model,
+                    prompt_messages=result.prompt_messages,
+                    system_fingerprint=result.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=result.message,
+                        usage=result.usage,
+                    )
+                )
+
+            if tool_calls:
+                prompt_messages.append(AssistantPromptMessage(
+                    content='',
+                    name='',
+                    tool_calls=[AssistantPromptMessage.ToolCall(
+                        id=tool_call[0],
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call[1],
+                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        )
+                    ) for tool_call in tool_calls]
+                ))
# save thought # save thought
self.save_agent_thought( self.save_agent_thought(
@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
final_answer += response + '\n' final_answer += response + '\n'
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# call tools # call tools
tool_responses = [] tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls: for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             )

             self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

-            # update prompt messages
-            if response.strip():
-                prompt_messages.append(AssistantPromptMessage(
-                    content=response,
-                ))
-
             # update prompt tool
             for prompt_tool in prompt_messages_tools:
                 self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
if llm_result_chunk.delta.message.tool_calls: if llm_result_chunk.delta.message.tool_calls:
return True return True
return False return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
""" """
@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
)) ))
return tool_calls return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
json.loads(prompt_message.function.arguments),
))
return tool_calls
def organize_prompt_messages(self, prompt_template: str, def organize_prompt_messages(self, prompt_template: str,
query: str = None, query: str = None,
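
Both extract_tool_calls and the new extract_blocking_tool_calls normalize tool calls into plain (tool_call_id, tool_call_name, tool_call_args) tuples before the runner replays them as an AssistantPromptMessage. A small illustrative sketch of that normalization; FakeToolCall and FakeFunction are stand-ins, not Dify classes:

    # Sketch only: mirrors the (tool_call_id, tool_call_name, tool_call_args) tuple
    # convention used by extract_tool_calls / extract_blocking_tool_calls above.
    import json
    from typing import Any, Dict, List, Tuple

    class FakeFunction:
        def __init__(self, name: str, arguments: str):
            self.name = name
            self.arguments = arguments

    class FakeToolCall:
        def __init__(self, id: str, name: str, arguments: str):
            self.id = id
            self.function = FakeFunction(name, arguments)

    def extract(tool_calls: List[FakeToolCall]) -> List[Tuple[str, str, Dict[str, Any]]]:
        # arguments arrive as a JSON string and are decoded once, up front
        return [(c.id, c.function.name, json.loads(c.function.arguments)) for c in tool_calls]

    calls = extract([FakeToolCall('call-1', 'get_weather', '{"city": "Berlin"}')])
    assert calls == [('call-1', 'get_weather', {'city': 'Berlin'})]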

View File

@ -78,6 +78,7 @@ class ModelFeature(Enum):
MULTI_TOOL_CALL = "multi-tool-call" MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought" AGENT_THOUGHT = "agent-thought"
VISION = "vision" VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
class DefaultParameterName(Enum): class DefaultParameterName(Enum):

View File

@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={

View File

@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None) -> Generator: tools: Optional[list[PromptMessageTool]] = None) -> Generator:
index = 0 index = 0
full_assistant_content = '' full_assistant_content = ''
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
real_model = model real_model = model
system_fingerprint = None system_fingerprint = None
completion = '' completion = ''
@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             delta = chunk.choices[0]

-            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
+                delta.delta.function_call is None:
                 continue

             # assistant_message_tool_calls = delta.delta.tool_calls
             assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if delta_assistant_message_function_call_storage is not None:
# handle process of stream function call
if assistant_message_function_call:
# message has not ended ever
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
continue
else:
# message has ended
assistant_message_function_call = delta_assistant_message_function_call_storage
delta_assistant_message_function_call_storage = None
else:
if assistant_message_function_call:
# start of stream function call
delta_assistant_message_function_call_storage = assistant_message_function_call
if delta_assistant_message_function_call_storage.arguments is None:
delta_assistant_message_function_call_storage.arguments = ''
continue
# extract tool calls from response # extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call) function_call = self._extract_response_function_call(assistant_message_function_call)
@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict
@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         num_tokens = 0
         for tool in tools:
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(tool.get("type")))
             num_tokens += len(encoding.encode('function'))

             # calculate num tokens for function object

View File

@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
-                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # check if last message is user message
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "function", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")

View File

@ -4,6 +4,8 @@ label:
model_type: llm model_type: llm
features: features:
- agent-thought - agent-thought
- tool-call
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 16384 context_size: 16384

View File

@ -4,6 +4,8 @@ label:
model_type: llm model_type: llm
features: features:
- agent-thought - agent-thought
- tool-call
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 32768 context_size: 32768

View File

@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
                 continue

             for choice in choices:
-                print(choice)
                 message = choice['delta']
                 yield MinimaxMessage(
                     content=message,

View File

@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
             **extra_kwargs
         }

+        if tools:
+            body['functions'] = tools
+            body['function_call'] = { 'type': 'auto' }
+
         try:
             response = post(
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
""" """
handle stream chat generate response handle stream chat generate response
""" """
function_call_storage = None
for line in response.iter_lines(): for line in response.iter_lines():
if not line: if not line:
continue continue
@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)

-            if data['reply']:
+            if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
                 message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
'total_tokens': total_tokens 'total_tokens': total_tokens
} }
message.stop_reason = data['choices'][0]['finish_reason'] message.stop_reason = data['choices'][0]['finish_reason']
if function_call_storage:
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
function_call_message.function_call = function_call_storage
yield function_call_message
yield message yield message
return return
@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
                 continue

             for choice in choices:
-                message = choice['messages'][0]['text']
-                if not message:
-                    continue
-                yield MinimaxMessage(
-                    content=message,
-                    role=MinimaxMessage.Role.ASSISTANT.value
-                )
+                message = choice['messages'][0]
+
+                if 'function_call' in message:
+                    if not function_call_storage:
+                        function_call_storage = message['function_call']
+                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
+                            function_call_storage['arguments'] = ''
+                            continue
+                    else:
+                        function_call_storage['arguments'] += message['function_call']['arguments']
+                        continue
+                else:
+                    if function_call_storage:
+                        message['function_call'] = function_call_storage
+                        function_call_storage = None
+
+                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+
+                if 'function_call' in message:
+                    minimax_message.function_call = message['function_call']
+
+                if 'text' in message:
+                    minimax_message.content = message['text']
+
+                yield minimax_message
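
The handler above buffers the streamed function_call while argument fragments keep arriving, and only emits it once the stream moves past it; the Azure handler earlier in this commit uses the same idea with delta_assistant_message_function_call_storage. A minimal sketch of that buffering, with plain dicts standing in for the provider deltas (field names are assumptions):

    # Sketch only: illustrates the stream function-call buffering used above.
    # Chunks are plain dicts standing in for provider deltas; keys are assumptions.
    from typing import Dict, Generator, Iterable, Optional

    def assemble_function_calls(chunks: Iterable[Dict]) -> Generator[Dict, None, None]:
        storage: Optional[Dict] = None
        for chunk in chunks:
            call = chunk.get('function_call')
            if call is not None:
                if storage is None:
                    storage = {'name': call.get('name', ''), 'arguments': call.get('arguments') or ''}
                else:
                    storage['arguments'] += call.get('arguments') or ''
                continue                           # keep buffering until a non-function chunk arrives
            if storage is not None:
                yield {'function_call': storage}   # flush the completed call
                storage = None
            yield chunk

    # example: three argument fragments are merged into one call before the text chunk
    merged = list(assemble_function_calls([
        {'function_call': {'name': 'get_weather', 'arguments': '{"ci'}},
        {'function_call': {'arguments': 'ty": "Be'}},
        {'function_call': {'arguments': 'rlin"}'}},
        {'text': 'done'},
    ]))
    assert merged[0] == {'function_call': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'}}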

View File

@@ -2,7 +2,7 @@ from typing import Generator, List
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
""" """
client: MinimaxChatCompletionPro = self.model_apis[model]() client: MinimaxChatCompletionPro = self.model_apis[model]()
if tools:
tools = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
response = client.generate( response = client.generate(
model=model, model=model,
api_key=credentials['minimax_api_key'], api_key=credentials['minimax_api_key'],
@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
elif isinstance(prompt_message, UserPromptMessage): elif isinstance(prompt_message, UserPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
elif isinstance(prompt_message, AssistantPromptMessage): elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.function_call={
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
}
return message
return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
elif isinstance(prompt_message, ToolPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
else: else:
raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
finish_reason=message.stop_reason if message.stop_reason else None, finish_reason=message.stop_reason if message.stop_reason else None,
), ),
) )
elif message.function_call:
if 'name' not in message.function_call or 'arguments' not in message.function_call:
continue
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content='',
tool_calls=[AssistantPromptMessage.ToolCall(
id='',
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.function_call['name'],
arguments=message.function_call['arguments']
)
)]
),
),
)
else: else:
yield LLMResultChunk( yield LLMResultChunk(
model=model, model=model,

View File

@ -7,13 +7,23 @@ class MinimaxMessage:
USER = 'USER' USER = 'USER'
ASSISTANT = 'BOT' ASSISTANT = 'BOT'
SYSTEM = 'SYSTEM' SYSTEM = 'SYSTEM'
FUNCTION = 'FUNCTION'
role: str = Role.USER.value role: str = Role.USER.value
content: str content: str
usage: Dict[str, int] = None usage: Dict[str, int] = None
stop_reason: str = '' stop_reason: str = ''
function_call: Dict[str, Any] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
return {
'sender_type': 'BOT',
'sender_name': '专家',
'text': '',
'function_call': self.function_call
}
return { return {
'sender_type': self.role, 'sender_type': self.role,
'sender_name': '' if self.role == 'USER' else '专家', 'sender_name': '' if self.role == 'USER' else '专家',

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 4096 context_size: 4096

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 16385 context_size: 16385

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 16385 context_size: 16385

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 16385 context_size: 16385

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 4096 context_size: 4096

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 128000

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 128000

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 32768 context_size: 32768

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 128000

View File

@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 8192 context_size: 8192

View File

@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict

View File

@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
-                                                         ParameterRule, ParameterType)
+                                                         ParameterRule, ParameterType, ModelFeature)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
+from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
                                                                                   XinferenceModelExtraParameter)
 from core.model_runtime.utils import helper
 from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
""" """
if 'temperature' in model_parameters:
if model_parameters['temperature'] < 0.01:
model_parameters['temperature'] = 0.01
elif model_parameters['temperature'] > 1.0:
model_parameters['temperature'] = 0.99
return self._generate( return self._generate(
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
tools=tools, stop=stop, stream=stream, user=user, tools=tools, stop=stop, stream=stream, user=user,
@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
credentials['completion_type'] = 'completion' credentials['completion_type'] = 'completion'
else: else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported') raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
if extra_param.support_function_call:
credentials['support_function_call'] = True
except RuntimeError as e: except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message) message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else: else:
raise ValueError(f"Unknown message type {type(message)}") raise ValueError(f"Unknown message type {type(message)}")
@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 label=I18nObject(
                     zh_Hans='温度',
                     en_US='Temperature'
-                )
+                ),
             ),
             ParameterRule(
                 name='top_p',
@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_type = LLMMode.COMPLETION.value completion_type = LLMMode.COMPLETION.value
else: else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
support_function_call = credentials.get('support_function_call', False)
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
), ),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
features=[
ModelFeature.TOOL_CALL
] if support_function_call else [],
model_properties={ model_properties={
ModelPropertyKey.MODE: completion_type, ModelPropertyKey.MODE: completion_type,
}, },
@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
""" """
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
client = OpenAI( client = OpenAI(
base_url=f'{credentials["server_url"]}/v1', base_url=f'{credentials["server_url"]}/v1',
api_key='abc', api_key='abc',

View File

@@ -2,7 +2,7 @@ import time
 from typing import Optional

 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

 class XinferenceTextEmbeddingModel(TextEmbeddingModel):
     """
@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         server_url = credentials['server_url']
         model_uid = credentials['model_uid']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
         client = Client(base_url=server_url)

         try:
@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         try:
+            server_url = credentials['server_url']
+            model_uid = credentials['model_uid']
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+
+            if extra_args.max_tokens:
+                credentials['max_tokens'] = extra_args.max_tokens
+
             self._invoke(model=model, credentials=credentials, texts=['ping'])
-        except InvokeAuthorizationError:
+        except (InvokeAuthorizationError, RuntimeError):
             raise CredentialsValidateFailedError('Invalid api key')
@property @property
@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
""" """
used to define customizable model schema used to define customizable model schema
""" """
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
label=I18nObject( label=I18nObject(
@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
-            model_properties={},
+            model_properties={
+                ModelPropertyKey.MAX_CHUNKS: 1,
+                ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
+            },
             parameter_rules=[]
         )

View File

@ -1,6 +1,7 @@
from threading import Lock from threading import Lock
from time import time from time import time
from typing import List from typing import List
from os import path
from requests import get from requests import get
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
     model_format: str
     model_handle_type: str
     model_ability: List[str]
+    max_tokens: int = 512
+    support_function_call: bool = False

-    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
+                 support_function_call: bool, max_tokens: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
+        self.support_function_call = support_function_call
+        self.max_tokens = max_tokens

 cache = {}
 cache_lock = Lock()
@@ -49,7 +55,7 @@
         get xinference model extra parameter like model_format and model_handle_type
         """

-        url = f'{server_url}/v1/models/{model_uid}'
+        url = path.join(server_url, 'v1/models', model_uid)

         # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()
@@ -66,10 +72,12 @@
         response_json = response.json()

-        model_format = response_json['model_format']
-        model_ability = response_json['model_ability']
+        model_format = response_json.get('model_format', 'ggmlv3')
+        model_ability = response_json.get('model_ability', [])

-        if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
+        if response_json.get('model_type') == 'embedding':
+            model_handle_type = 'embedding'
+        elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
             model_handle_type = 'chatglm'
         elif 'generate' in model_ability:
             model_handle_type = 'generate'
@@ -78,8 +86,13 @@
         else:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

+        support_function_call = 'tools' in model_ability
+        max_tokens = response_json.get('max_tokens', 512)
+
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
-            model_ability=model_ability
+            model_ability=model_ability,
+            support_function_call=support_function_call,
+            max_tokens=max_tokens
         )

View File

@ -2,6 +2,10 @@ model: glm-3-turbo
label: label:
en_US: glm-3-turbo en_US: glm-3-turbo
model_type: llm model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
parameter_rules: parameter_rules:

View File

@ -2,6 +2,10 @@ model: glm-4
label: label:
en_US: glm-4 en_US: glm-4
model_type: llm model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
parameter_rules: parameter_rules:

View File

@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
'content': prompt_message.content, 'content': prompt_message.content,
'tool_call_id': prompt_message.tool_call_id 'tool_call_id': prompt_message.tool_call_id
}) })
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content,
'tool_calls': [
{
'id': tool_call.id,
'type': tool_call.type,
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments
}
} for tool_call in prompt_message.tool_calls
]
})
else:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content
})
else: else:
params['messages'].append({ params['messages'].append({
'role': prompt_message.role.value, 'role': prompt_message.role.value,

View File

@@ -47,7 +47,7 @@ dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 pandas==1.5.3
-xinference-client~=0.6.4
+xinference-client~=0.8.1
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug~=3.0.1

View File

@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
             raise RuntimeError('404 Not Found')

         if 'generate' == model_uid:
-            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'chat' == model_uid:
-            return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'embedding' == model_uid:
-            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'rerank' == model_uid:
-            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         raise RuntimeError('404 Not Found')

     def get(self: Session, url: str, **kwargs):
-        if '/v1/models/' in url:
-            response = Response()
+        response = Response()
+        if 'v1/models/' in url:
             # get model uid
             model_uid = url.split('/')[-1]
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response

             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response
if model_uid in ['generate', 'chat']:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
return response
elif model_uid == 'embedding':
response.status_code = 200
response._content = b'''{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
return response
elif 'v1/cluster/auth' in url:
-            response.status_code = 200
-            response._content = b'''{
-                "model_type": "LLM",
-                "address": "127.0.0.1:43877",
-                "accelerators": [
-                    "0",
-                    "1"
-                ],
-                "model_name": "chatglm3-6b",
-                "model_lang": [
-                    "en"
-                ],
-                "model_ability": [
-                    "generate",
-                    "chat"
-                ],
-                "model_description": "latest chatglm3",
-                "model_format": "pytorch",
-                "model_size_in_billions": 7,
-                "quantization": "none",
-                "model_hub": "huggingface",
-                "revision": null,
-                "context_length": 2048,
-                "replica": 1
-            }'''
-            return response
+            response.status_code = 200
+            response._content = b'''{
+                "auth": true
+            }'''
+            return response
def _check_cluster_authenticated(self):
self._cluster_authed = True
def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict: def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
# check if self._model_uid is a valid uuid # check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
def setup_xinference_mock(request, monkeypatch: MonkeyPatch): def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
if MOCK: if MOCK:
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model) monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get) monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding) monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank) monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)