mirror of https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00

Feat/blocking function call (#2247)

This commit is contained in:
parent 1ea18a2922
commit 6d5b386394

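In outline, this commit adds a blocking (non-streaming) function-call path to the agent runners: the runner checks whether the model advertises a stream-tool-call feature, invokes the LLM with streaming disabled when it does not, and re-emits the single blocking result as one stream chunk so the rest of the pipeline stays unchanged. The sketch below illustrates that dispatch with simplified, hypothetical stand-ins (FakeModel, FakeChunk, handle_chunk); it is not the exact Dify API.

    # Minimal sketch of the stream/blocking dispatch introduced by this commit.
    # All classes here are simplified stand-ins, not the real Dify entities.
    from dataclasses import dataclass, field
    from typing import Callable, Iterable, List, Union

    @dataclass
    class FakeChunk:          # stands in for LLMResultChunk
        text: str

    @dataclass
    class FakeResult:         # stands in for a blocking LLMResult
        text: str

    @dataclass
    class FakeModel:
        features: List[str] = field(default_factory=list)

        def invoke(self, prompt: str, stream: bool) -> Union[Iterable[FakeChunk], FakeResult]:
            if stream:
                return (FakeChunk(word) for word in prompt.split())
            return FakeResult(prompt)

    def run(model: FakeModel, prompt: str, handle_chunk: Callable[[FakeChunk], None]) -> None:
        # stream only when the model advertises streamed tool calls
        stream_tool_call = "stream-tool-call" in model.features
        result = model.invoke(prompt, stream=stream_tool_call)

        if stream_tool_call:
            for chunk in result:                  # many incremental chunks
                handle_chunk(chunk)
        else:
            handle_chunk(FakeChunk(result.text))  # one blocking result, re-emitted as a single chunk

    run(FakeModel(features=["stream-tool-call"]), "hello blocking world", print)

The diff below wires the same idea through the assistant runners, the model feature enum, and the individual providers.
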
@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner):
                 memory=memory,
             )
 
+            # change function call strategy based on LLM model
+            llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+            model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+
+            if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
+                agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
+
             # start agent runner
             if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
                 assistant_cot_runner = AssistantCotApplicationRunner(
@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner):
                     prompt_messages=prompt_message,
                     variables_pool=tool_variables,
                     db_variables=tool_conversation_variables,
+                    model_instance=model_instance
                 )
                 invoke_result = assistant_cot_runner.run(
-                    model_instance=model_instance,
                     conversation=conversation,
                     message=message,
                     query=query,
@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner):
                     memory=memory,
                     prompt_messages=prompt_message,
                     variables_pool=tool_variables,
-                    db_variables=tool_conversation_variables
+                    db_variables=tool_conversation_variables,
+                    model_instance=model_instance
                 )
                 invoke_result = assistant_fc_runner.run(
-                    model_instance=model_instance,
                     conversation=conversation,
                     message=message,
                     query=query,

@@ -1,7 +1,7 @@
 import logging
 import json
 
-from typing import Optional, List, Tuple, Union
+from typing import Optional, List, Tuple, Union, cast
 from datetime import datetime
 from mimetypes import guess_extension
 
@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
     AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_manager import ModelInstance
 from core.file.message_file_parser import FileTransferMethod
 
 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  prompt_messages: Optional[List[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
+                 model_instance: ModelInstance = None
                  ) -> None:
         """
         Agent runner
@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.history_prompt_messages = prompt_messages
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
+        self.model_instance = model_instance
 
         # init callback
         self.agent_callback = DifyAgentCallbackHandler()
@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
             MessageAgentThought.message_id == self.message.id,
         ).count()
 
+        # check if model supports stream tool call
+        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
+            self.stream_tool_call = True
+        else:
+            self.stream_tool_call = False
+
     def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
         """
         Repacket app orchestration config

@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
 from models.model import Conversation, Message
 
 class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance,
-        conversation: Conversation,
+    def run(self, conversation: Conversation,
         message: Message,
         query: str,
     ) -> Union[Generator, LLMResult]:
@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             llm_usage.prompt_price += usage.prompt_price
             llm_usage.completion_price += usage.completion_price
 
+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             # continue to run until there is not any tool call
             function_call_state = False
@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         # remove Action: xxx from agent thought
         agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
 
-        if action_name and action_input:
+        if action_name and action_input is not None:
             return AgentScratchpadUnit(
                 agent_response=content,
                 thought=agent_thought,

@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
 
 from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
     SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
-from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
 from core.model_manager import ModelInstance
 from core.application_queue_manager import PublishFrom
 
@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
 logger = logging.getLogger(__name__)
 
 class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance,
-        conversation: Conversation,
+    def run(self, conversation: Conversation,
         message: Message,
         query: str,
     ) -> Generator[LLMResultChunk, None, None]:
@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             llm_usage.prompt_price += usage.prompt_price
             llm_usage.completion_price += usage.completion_price
 
+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             function_call_state = False
 
@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             # recale llm max tokens
             self.recale_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
-            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
+            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_orchestration_config.model_config.parameters,
                 tools=prompt_messages_tools,
                 stop=app_orchestration_config.model_config.stop,
-                stream=True,
+                stream=self.stream_tool_call,
                 user=self.user_id,
                 callbacks=[],
             )
@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
 
             current_llm_usage = None
 
-            for chunk in chunks:
+            if self.stream_tool_call:
+                for chunk in chunks:
+                    # check if there is any tool call
+                    if self.check_tool_calls(chunk):
+                        function_call_state = True
+                        tool_calls.extend(self.extract_tool_calls(chunk))
+                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        try:
+                            tool_call_inputs = json.dumps({
+                                tool_call[1]: tool_call[2] for tool_call in tool_calls
+                            }, ensure_ascii=False)
+                        except json.JSONDecodeError as e:
+                            # ensure ascii to avoid encoding error
+                            tool_call_inputs = json.dumps({
+                                tool_call[1]: tool_call[2] for tool_call in tool_calls
+                            })
+
+                    if chunk.delta.message and chunk.delta.message.content:
+                        if isinstance(chunk.delta.message.content, list):
+                            for content in chunk.delta.message.content:
+                                response += content.data
+                        else:
+                            response += chunk.delta.message.content
+
+                    if chunk.delta.usage:
+                        increase_usage(llm_usage, chunk.delta.usage)
+                        current_llm_usage = chunk.delta.usage
+
+                    yield chunk
+            else:
+                result: LLMResult = chunks
                 # check if there is any tool call
-                if self.check_tool_calls(chunk):
+                if self.check_blocking_tool_calls(result):
                     function_call_state = True
-                    tool_calls.extend(self.extract_tool_calls(chunk))
+                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                     tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                     try:
                         tool_call_inputs = json.dumps({
@@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                             tool_call[1]: tool_call[2] for tool_call in tool_calls
                         })
 
-                if chunk.delta.message and chunk.delta.message.content:
-                    if isinstance(chunk.delta.message.content, list):
-                        for content in chunk.delta.message.content:
+                if result.usage:
+                    increase_usage(llm_usage, result.usage)
+                    current_llm_usage = result.usage
+
+                if result.message and result.message.content:
+                    if isinstance(result.message.content, list):
+                        for content in result.message.content:
                             response += content.data
                     else:
-                        response += chunk.delta.message.content
+                        response += result.message.content
 
-                if chunk.delta.usage:
-                    increase_usage(llm_usage, chunk.delta.usage)
-                    current_llm_usage = chunk.delta.usage
+                if not result.message.content:
+                    result.message.content = ''
 
-                yield chunk
+                yield LLMResultChunk(
+                    model=model_instance.model,
+                    prompt_messages=result.prompt_messages,
+                    system_fingerprint=result.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=result.message,
+                        usage=result.usage,
+                    )
+                )
+
+            if tool_calls:
+                prompt_messages.append(AssistantPromptMessage(
+                    content='',
+                    name='',
+                    tool_calls=[AssistantPromptMessage.ToolCall(
+                        id=tool_call[0],
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call[1],
+                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
+                    ) for tool_call in tool_calls]
+                ))
 
             # save thought
             self.save_agent_thought(
@@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
 
             final_answer += response + '\n'
 
+            # update prompt messages
+            if response.strip():
+                prompt_messages.append(AssistantPromptMessage(
+                    content=response,
+                ))
+
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             )
             self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
 
-            # update prompt messages
-            if response.strip():
-                prompt_messages.append(AssistantPromptMessage(
-                    content=response,
-                ))
-
             # update prompt tool
             for prompt_tool in prompt_messages_tools:
                 self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
@@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False
 
+    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
+        """
+        Check if there is any blocking tool call in llm result
+        """
+        if llm_result.message.tool_calls:
+            return True
+        return False
+
     def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
         """
@@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             ))
 
         return tool_calls
 
+    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+        """
+        Extract blocking tool calls from llm result
+
+        Returns:
+            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
+        """
+        tool_calls = []
+        for prompt_message in llm_result.message.tool_calls:
+            tool_calls.append((
+                prompt_message.id,
+                prompt_message.function.name,
+                json.loads(prompt_message.function.arguments),
+            ))
+
+        return tool_calls
+
     def organize_prompt_messages(self, prompt_template: str,
                                  query: str = None,

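A point worth noting about the function-call runner above: the blocking branch keeps the contract of the streaming branch by wrapping the single LLMResult into one LLMResultChunk before yielding it, so callers keep iterating over a chunk stream either way. A minimal, self-contained sketch of that wrapping, using stand-in dataclasses rather than the real Dify entities:

    # Sketch: re-emit a blocking result as a one-chunk stream (stand-in classes, not Dify's).
    from dataclasses import dataclass
    from typing import Iterator, Optional

    @dataclass
    class Delta:
        index: int
        message: str
        usage: Optional[dict] = None

    @dataclass
    class Chunk:
        model: str
        delta: Delta

    @dataclass
    class BlockingResult:
        model: str
        message: str
        usage: Optional[dict] = None

    def as_chunk_stream(result: BlockingResult) -> Iterator[Chunk]:
        """Yield exactly one chunk built from a blocking result."""
        yield Chunk(model=result.model,
                    delta=Delta(index=0, message=result.message, usage=result.usage))

    for chunk in as_chunk_stream(BlockingResult(model="demo", message="final answer", usage={"total_tokens": 7})):
        print(chunk)
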
@@ -78,6 +78,7 @@ class ModelFeature(Enum):
     MULTI_TOOL_CALL = "multi-tool-call"
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
+    STREAM_TOOL_CALL = "stream-tool-call"
 
 
 class DefaultParameterName(Enum):

@@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
                 features=[
                     ModelFeature.AGENT_THOUGHT,
                     ModelFeature.MULTI_TOOL_CALL,
+                    ModelFeature.STREAM_TOOL_CALL,
                 ],
                 fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
                 model_properties={
@@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
                 features=[
                     ModelFeature.AGENT_THOUGHT,
                     ModelFeature.MULTI_TOOL_CALL,
+                    ModelFeature.STREAM_TOOL_CALL,
                 ],
                 fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
                 model_properties={
@@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
                 features=[
                     ModelFeature.AGENT_THOUGHT,
                     ModelFeature.MULTI_TOOL_CALL,
+                    ModelFeature.STREAM_TOOL_CALL,
                 ],
                 fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
                 model_properties={
@@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
                 features=[
                     ModelFeature.AGENT_THOUGHT,
                     ModelFeature.MULTI_TOOL_CALL,
+                    ModelFeature.STREAM_TOOL_CALL,
                 ],
                 fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
                 model_properties={
@@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
                 features=[
                     ModelFeature.AGENT_THOUGHT,
                     ModelFeature.MULTI_TOOL_CALL,
+                    ModelFeature.STREAM_TOOL_CALL,
                 ],
                 fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
                 model_properties={

@@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                           tools: Optional[list[PromptMessageTool]] = None) -> Generator:
         index = 0
         full_assistant_content = ''
+        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
             delta = chunk.choices[0]
 
-            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
+                delta.delta.function_call is None:
                 continue
 
             # assistant_message_tool_calls = delta.delta.tool_calls
             assistant_message_function_call = delta.delta.function_call
 
+            # extract tool calls from response
+            if delta_assistant_message_function_call_storage is not None:
+                # handle process of stream function call
+                if assistant_message_function_call:
+                    # message has not ended ever
+                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
+                    continue
+                else:
+                    # message has ended
+                    assistant_message_function_call = delta_assistant_message_function_call_storage
+                    delta_assistant_message_function_call_storage = None
+            else:
+                if assistant_message_function_call:
+                    # start of stream function call
+                    delta_assistant_message_function_call_storage = assistant_message_function_call
+                    if delta_assistant_message_function_call_storage.arguments is None:
+                        delta_assistant_message_function_call_storage.arguments = ''
+                    continue
+
             # extract tool calls from response
             # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
             function_call = self._extract_response_function_call(assistant_message_function_call)
@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name
 
         return message_dict
@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         num_tokens = 0
         for tool in tools:
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(tool.get("type")))
             num_tokens += len(encoding.encode('function'))
 
             # calculate num tokens for function object

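The streaming providers solve the inverse problem: a function call arrives split across several deltas, so its arguments are buffered until the fragments stop and only then surfaced as one complete call; the Azure OpenAI handler above and the MiniMax handler further down both follow this pattern. A generic sketch of the accumulation, using plain dicts rather than the providers' delta objects:

    # Sketch: accumulate a function call that is streamed in fragments.
    # Deltas are plain dicts here; real providers use their own delta objects.
    from typing import Dict, Iterable, Iterator, Optional

    def merge_function_calls(deltas: Iterable[Dict]) -> Iterator[Dict]:
        pending: Optional[Dict] = None
        for delta in deltas:
            call = delta.get("function_call")
            if call is not None:
                if pending is None:
                    pending = {"name": call.get("name", ""), "arguments": call.get("arguments") or ""}
                else:
                    pending["arguments"] += call.get("arguments") or ""
                continue  # keep buffering until a non-function-call delta arrives
            if pending is not None:
                yield {"function_call": pending}  # the call is complete, emit it once
                pending = None
            yield delta

    stream = [
        {"function_call": {"name": "get_weather", "arguments": '{"city": '}},
        {"function_call": {"arguments": '"Berlin"}'}},
        {"content": "done"},
    ]
    print(list(merge_function_calls(stream)))
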
@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
-                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # check if last message is user message
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "function", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
 

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16384
@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768

@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
                 continue
 
             for choice in choices:
-                print(choice)
                 message = choice['delta']
                 yield MinimaxMessage(
                     content=message,

@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
             **extra_kwargs
         }
 
+        if tools:
+            body['functions'] = tools
+            body['function_call'] = { 'type': 'auto' }
+
         try:
             response = post(
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
@@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
         """
         handle stream chat generate response
         """
+        function_call_storage = None
         for line in response.iter_lines():
             if not line:
                 continue
@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)
 
-            if data['reply']:
+            if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
                 message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
@@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
                     'total_tokens': total_tokens
                 }
                 message.stop_reason = data['choices'][0]['finish_reason']
+
+                if function_call_storage:
+                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                    function_call_message.function_call = function_call_storage
+                    yield function_call_message
+
                 yield message
                 return
 
@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
                 continue
 
             for choice in choices:
-                message = choice['messages'][0]['text']
-                if not message:
-                    continue
+                message = choice['messages'][0]
+                if 'function_call' in message:
+                    if not function_call_storage:
+                        function_call_storage = message['function_call']
+                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
+                            function_call_storage['arguments'] = ''
+                            continue
+                    else:
+                        function_call_storage['arguments'] += message['function_call']['arguments']
+                        continue
+                else:
+                    if function_call_storage:
+                        message['function_call'] = function_call_storage
+                        function_call_storage = None
 
-                yield MinimaxMessage(
-                    content=message,
-                    role=MinimaxMessage.Role.ASSISTANT.value
-                )
+                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+
+                if 'function_call' in message:
+                    minimax_message.function_call = message['function_call']
+
+                if 'text' in message:
+                    minimax_message.content = message['text']
+
+                yield minimax_message

@@ -2,7 +2,7 @@ from typing import Generator, List
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         """
         client: MinimaxChatCompletionPro = self.model_apis[model]()
 
+        if tools:
+            tools = [{
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters
+            } for tool in tools]
+
         response = client.generate(
             model=model,
             api_key=credentials['minimax_api_key'],
@@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         elif isinstance(prompt_message, UserPromptMessage):
             return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
         elif isinstance(prompt_message, AssistantPromptMessage):
+            if prompt_message.tool_calls:
+                message = MinimaxMessage(
+                    role=MinimaxMessage.Role.ASSISTANT.value,
+                    content=''
+                )
+                message.function_call={
+                    'name': prompt_message.tool_calls[0].function.name,
+                    'arguments': prompt_message.tool_calls[0].function.arguments
+                }
+                return message
             return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
+        elif isinstance(prompt_message, ToolPromptMessage):
+            return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
         else:
             raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
 
@@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                         finish_reason=message.stop_reason if message.stop_reason else None,
                     ),
                 )
+            elif message.function_call:
+                if 'name' not in message.function_call or 'arguments' not in message.function_call:
+                    continue
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(
+                            content='',
+                            tool_calls=[AssistantPromptMessage.ToolCall(
+                                id='',
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=message.function_call['name'],
+                                    arguments=message.function_call['arguments']
+                                )
+                            )]
+                        ),
+                    ),
+                )
             else:
                 yield LLMResultChunk(
                     model=model,

@@ -7,13 +7,23 @@ class MinimaxMessage:
         USER = 'USER'
         ASSISTANT = 'BOT'
         SYSTEM = 'SYSTEM'
+        FUNCTION = 'FUNCTION'
 
     role: str = Role.USER.value
     content: str
     usage: Dict[str, int] = None
     stop_reason: str = ''
+    function_call: Dict[str, Any] = None
 
     def to_dict(self) -> Dict[str, Any]:
+        if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
+            return {
+                'sender_type': 'BOT',
+                'sender_name': '专家',
+                'text': '',
+                'function_call': self.function_call
+            }
+
         return {
             'sender_type': self.role,
             'sender_name': '我' if self.role == 'USER' else '专家',

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 8192

@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name
 
         return message_dict

@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
-                                                         ParameterRule, ParameterType)
+                                                         ParameterRule, ParameterType, ModelFeature)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
+from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
                                                                               XinferenceModelExtraParameter)
 from core.model_runtime.utils import helper
 from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
 
             see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
         """
+        if 'temperature' in model_parameters:
+            if model_parameters['temperature'] < 0.01:
+                model_parameters['temperature'] = 0.01
+            elif model_parameters['temperature'] > 1.0:
+                model_parameters['temperature'] = 0.99
+
         return self._generate(
             model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
             tools=tools, stop=stop, stream=stream, user=user,
@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 credentials['completion_type'] = 'completion'
             else:
                 raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
 
+            if extra_param.support_function_call:
+                credentials['support_function_call'] = True
+
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
 
@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 label=I18nObject(
                     zh_Hans='温度',
                     en_US='Temperature'
-                )
+                ),
             ),
             ParameterRule(
                 name='top_p',
@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             completion_type = LLMMode.COMPLETION.value
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
 
+        support_function_call = credentials.get('support_function_call', False)
+
         entity = AIModelEntity(
             model=model,
@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
+            features=[
+                ModelFeature.TOOL_CALL
+            ] if support_function_call else [],
             model_properties={
                 ModelPropertyKey.MODE: completion_type,
             },
@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
 
             extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
         """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+
+        if credentials['server_url'].endswith('/'):
+            credentials['server_url'] = credentials['server_url'][:-1]
+
         client = OpenAI(
             base_url=f'{credentials["server_url"]}/v1',
             api_key='abc',

@@ -2,7 +2,7 @@ import time
 from typing import Optional
 
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
 
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
 
 class XinferenceTextEmbeddingModel(TextEmbeddingModel):
     """
@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         server_url = credentials['server_url']
         model_uid = credentials['model_uid']
 
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
         client = Client(base_url=server_url)
 
         try:
@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         try:
+            server_url = credentials['server_url']
+            model_uid = credentials['model_uid']
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+
+            if extra_args.max_tokens:
+                credentials['max_tokens'] = extra_args.max_tokens
+
             self._invoke(model=model, credentials=credentials, texts=['ping'])
-        except InvokeAuthorizationError:
+        except (InvokeAuthorizationError, RuntimeError):
             raise CredentialsValidateFailedError('Invalid api key')
 
     @property
@@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
             used to define customizable model schema
         """
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
-            model_properties={},
+            model_properties={
+                ModelPropertyKey.MAX_CHUNKS: 1,
+                ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
+            },
             parameter_rules=[]
         )
 

@@ -1,6 +1,7 @@
 from threading import Lock
 from time import time
 from typing import List
+from os import path

 from requests import get
 from requests.adapters import HTTPAdapter

@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
     model_format: str
     model_handle_type: str
     model_ability: List[str]
+    max_tokens: int = 512
+    support_function_call: bool = False

-    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
+                 support_function_call: bool, max_tokens: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
+        self.support_function_call = support_function_call
+        self.max_tokens = max_tokens

 cache = {}
 cache_lock = Lock()

@@ -49,7 +55,7 @@ class XinferenceHelper:
         get xinference model extra parameter like model_format and model_handle_type
         """

-        url = f'{server_url}/v1/models/{model_uid}'
+        url = path.join(server_url, 'v1/models', model_uid)

         # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()

@@ -66,10 +72,12 @@ class XinferenceHelper:

         response_json = response.json()

-        model_format = response_json['model_format']
-        model_ability = response_json['model_ability']
+        model_format = response_json.get('model_format', 'ggmlv3')
+        model_ability = response_json.get('model_ability', [])

-        if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
+        if response_json.get('model_type') == 'embedding':
+            model_handle_type = 'embedding'
+        elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
             model_handle_type = 'chatglm'
         elif 'generate' in model_ability:
             model_handle_type = 'generate'

@@ -78,8 +86,13 @@ class XinferenceHelper:
         else:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

+        support_function_call = 'tools' in model_ability
+        max_tokens = response_json.get('max_tokens', 512)
+
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
-            model_ability=model_ability
+            model_ability=model_ability,
+            support_function_call=support_function_call,
+            max_tokens=max_tokens
         )

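To show how the two new fields are meant to be consumed, a rough usage sketch follows; the helper and attribute names come from the diff, but the server URL and model UID values are placeholders:

    # Sketch only — URL and UID are placeholder values, not from the change
    extra = XinferenceHelper.get_xinference_extra_parameter(
        server_url='http://localhost:9997',
        model_uid='chatglm3-6b',
    )
    if extra.support_function_call:
        # the deployed model advertises 'tools' in its model_ability list
        pass
    context_size = extra.max_tokens  # 512 unless the server reports max_tokens
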
@@ -2,6 +2,10 @@ model: glm-3-turbo
 label:
   en_US: glm-3-turbo
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:

@@ -2,6 +2,10 @@ model: glm-4
 label:
   en_US: glm-4
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:

@@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     'content': prompt_message.content,
                     'tool_call_id': prompt_message.tool_call_id
                 })
+            elif isinstance(prompt_message, AssistantPromptMessage):
+                if prompt_message.tool_calls:
+                    params['messages'].append({
+                        'role': 'assistant',
+                        'content': prompt_message.content,
+                        'tool_calls': [
+                            {
+                                'id': tool_call.id,
+                                'type': tool_call.type,
+                                'function': {
+                                    'name': tool_call.function.name,
+                                    'arguments': tool_call.function.arguments
+                                }
+                            } for tool_call in prompt_message.tool_calls
+                        ]
+                    })
+                else:
+                    params['messages'].append({
+                        'role': 'assistant',
+                        'content': prompt_message.content
+                    })
             else:
                 params['messages'].append({
                     'role': prompt_message.role.value,

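For a concrete picture of what the new AssistantPromptMessage branch produces, an assistant turn carrying one tool call would serialise into a message entry shaped roughly like the following; all values are invented placeholders, not taken from the change:

    # Placeholder values for illustration only
    {
        'role': 'assistant',
        'content': '',
        'tool_calls': [
            {
                'id': 'call_0',
                'type': 'function',
                'function': {
                    'name': 'get_weather',
                    'arguments': '{"city": "Berlin"}'
                }
            }
        ]
    }
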
@@ -47,7 +47,7 @@ dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 pandas==1.5.3
-xinference-client~=0.6.4
+xinference-client~=0.8.1
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug~=3.0.1

@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
             raise RuntimeError('404 Not Found')

         if 'generate' == model_uid:
-            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'chat' == model_uid:
-            return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'embedding' == model_uid:
-            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'rerank' == model_uid:
-            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         raise RuntimeError('404 Not Found')

     def get(self: Session, url: str, **kwargs):
-        if '/v1/models/' in url:
-            response = Response()
+        response = Response()
+        if 'v1/models/' in url:

             # get model uid
             model_uid = url.split('/')[-1]
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response

             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response

+            if model_uid in ['generate', 'chat']:
+                response.status_code = 200
+                response._content = b'''{
+    "model_type": "LLM",
+    "address": "127.0.0.1:43877",
+    "accelerators": [
+        "0",
+        "1"
+    ],
+    "model_name": "chatglm3-6b",
+    "model_lang": [
+        "en"
+    ],
+    "model_ability": [
+        "generate",
+        "chat"
+    ],
+    "model_description": "latest chatglm3",
+    "model_format": "pytorch",
+    "model_size_in_billions": 7,
+    "quantization": "none",
+    "model_hub": "huggingface",
+    "revision": null,
+    "context_length": 2048,
+    "replica": 1
+}'''
+                return response
+
+            elif model_uid == 'embedding':
+                response.status_code = 200
+                response._content = b'''{
+    "model_type": "embedding",
+    "address": "127.0.0.1:43877",
+    "accelerators": [
+        "0",
+        "1"
+    ],
+    "model_name": "bge",
+    "model_lang": [
+        "en"
+    ],
+    "revision": null,
+    "max_tokens": 512
+}'''
+                return response
+
+        elif 'v1/cluster/auth' in url:
             response.status_code = 200
             response._content = b'''{
-    "model_type": "LLM",
-    "address": "127.0.0.1:43877",
-    "accelerators": [
-        "0",
-        "1"
-    ],
-    "model_name": "chatglm3-6b",
-    "model_lang": [
-        "en"
-    ],
-    "model_ability": [
-        "generate",
-        "chat"
-    ],
-    "model_description": "latest chatglm3",
-    "model_format": "pytorch",
-    "model_size_in_billions": 7,
-    "quantization": "none",
-    "model_hub": "huggingface",
-    "revision": null,
-    "context_length": 2048,
-    "replica": 1
+    "auth": true
 }'''
             return response

+    def _check_cluster_authenticated(self):
+        self._cluster_authed = True
+
     def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
         # check if self._model_uid is a valid uuid
         if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \

@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
         monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
         monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
         monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
         monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)

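For completeness: these mocks are only active when the MOCK_SWITCH environment variable is set to true (see the MOCK line in the hunk header above), which lets the Xinference tests run without a live server. A rough sketch of driving them, assuming pytest as the runner and a guessed test path that may need adjusting to the repository layout:

    # Sketch only — the test directory path below is an assumption, not from the diff
    import os
    import pytest

    os.environ['MOCK_SWITCH'] = 'true'  # activates MockXinferenceClass via setup_xinference_mock
    pytest.main(['api/tests/integration_tests/model_runtime/xinference'])
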