import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)


class BaseAgentRunner(AppRunner):
    """Shared base for agent-style app runners.

    Holds the per-run context (app config, model config, conversation,
    message, tool pool) and provides helpers to convert tools into
    prompt-message tool schemas, persist agent thoughts, and rebuild the
    agent conversation history from the database.
    """

    def __init__(
        self,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
        variables_pool: Optional[ToolRuntimeVariablePool] = None,
        db_variables: Optional[ToolConversationVariables] = None,
        # FIX: default is None, so the annotation must be Optional.
        model_instance: Optional[ModelInstance] = None,
    ) -> None:
        """Capture run context and pre-compute model/tool capabilities.

        :param tenant_id: tenant (workspace) id of the running app
        :param application_generate_entity: the generate request entity
        :param conversation: conversation record this run belongs to
        :param app_config: agent-chat app configuration
        :param model_config: model config with resolved credentials
        :param config: agent entity config (strategy, tools, ...)
        :param queue_manager: queue manager for streaming events
        :param message: message record being answered
        :param user_id: id of the requesting user
        :param memory: optional token-buffer memory for history
        :param prompt_messages: seed prompt messages (system prompt etc.)
        :param variables_pool: tool runtime variable pool
        :param db_variables: persisted tool conversation variables
        :param model_instance: model instance used for capability checks
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )
        # get how many agent thoughts have been created; new thoughts are
        # positioned after the existing ones.
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        self.stream_tool_call = bool(
            model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or [])
        )

        # check if model supports vision; only then forward uploaded files
        if model_schema and ModelFeature.VISION in (model_schema.features or []):
            self.files = application_generate_entity.files
        else:
            self.files = []
        self.query = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity, normalizing a missing simple prompt
        template to an empty string.
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        Convert an agent tool entity to a prompt message tool schema.

        :return: the JSON-schema-style tool description plus the resolved
            runtime tool entity
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            # only parameters filled by the LLM go into the schema
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            # file-typed parameters cannot be expressed to the LLM
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        Convert a dataset retriever tool to a prompt message tool schema.
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
        """
        Init tools: resolve configured agent tools and dataset tools into
        runtime instances plus their prompt-message schemas.
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # best-effort: the api tool may have been deleted; skip it
                # but leave a trace for debugging (FIX: was fully silent).
                logger.warning("Failed to load agent tool %s, skipping", tool.tool_name, exc_info=True)
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        Refresh a prompt message tool schema from the tool's current
        runtime parameters (parameters may change between iterations).
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create and persist a new agent thought row, positioned after all
        existing thoughts of this message.

        :param message_id: id of the message being answered
        :param message: the LLM message text at this step
        :param tool_name: tool name(s) invoked (";"-separated when multiple)
        :param tool_input: serialized tool input
        :param messages_ids: ids of message files attached to this thought
        :return: the persisted (refreshed) thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str,
        tool_input: Union[str, dict],
        thought: str,
        observation: Union[str, dict],
        tool_invoke_meta: Union[str, dict],
        answer: str,
        messages_ids: list[str],
        # FIX: default is None, so the annotation must be Optional.
        llm_usage: Optional[LLMUsage] = None,
    ) -> None:
        """
        Save agent thought: update the persisted thought row with the
        latest step results and usage accounting.

        FIX: annotated ``-> None`` — the method never returned a value
        despite the previous ``-> MessageAgentThought`` annotation.

        :raises ValueError: if the thought row no longer exists
        """
        # re-fetch within this session so the update is tracked
        agent_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        # FIX: guard against a vanished row instead of raising AttributeError
        if agent_thought is None:
            raise ValueError("agent thought not found")

        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation is not None:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            agent_thought.observation = observation

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty; backfill a label per tool name
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(";") if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    # fall back to the raw tool name for both locales
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        Convert tool variables to db variables and persist them.

        NOTE: the ``db_variables`` argument is kept for interface
        compatibility but is not used — the row is always re-fetched by
        conversation id so the update happens in the current session.
        """
        record = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == self.message.conversation_id,
            )
            .first()
        )

        # FIX: guard against a missing row instead of raising AttributeError
        if record is not None:
            record.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
            record.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
            db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history: rebuild the prompt-message history of this
        conversation from persisted messages and agent thoughts.

        :param prompt_messages: seed messages; any SystemPromptMessage in
            them is kept at the front of the result
        :return: history as [system?, user, assistant(+tool calls), ...]
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        # restore chronological order within the current thread
        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            # unparsable input: assume empty args per tool
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            # unparsable observation: reuse raw text per tool
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        """
        Build the user prompt message for a historical message, attaching
        its files as prompt-message contents when file upload is enabled.
        """
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if files:
            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

            if file_extra_config:
                file_objs = file_factory.build_from_message_files(
                    message_files=files, tenant_id=self.tenant_id, config=file_extra_config
                )
            else:
                file_objs = []

            if not file_objs:
                return UserPromptMessage(content=message.query)
            else:
                prompt_message_contents: list[PromptMessageContent] = []
                prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                for file_obj in file_objs:
                    prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))

                return UserPromptMessage(content=prompt_message_contents)
        else:
            return UserPromptMessage(content=message.query)