dify/api/core/agent/cot_agent_runner.py

422 lines
17 KiB
Python
Raw Normal View History

import json
2024-04-11 18:34:17 +08:00
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner
2024-04-11 18:34:17 +08:00
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
2024-02-01 18:11:57 +08:00
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
2024-04-11 18:34:17 +08:00
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from models.model import Message
2024-02-01 18:11:57 +08:00
2024-04-11 18:34:17 +08:00
class CotAgentRunner(BaseAgentRunner, ABC):
    """
    Base runner for chain-of-thought (ReAct style) agents.

    Subclasses must implement ``_organize_prompt_messages``; the per-run state
    below is (re)initialised by ``_init_react_state`` and ``run``.
    """

    _is_first_iteration = True
    # providers for which "Observation" is NOT appended to the stop words in run()
    _ignore_observation_providers = ["wenxin"]

    # Per-run state. Every attribute starts as None and is filled in during run(),
    # hence the explicit Optional annotations.
    _historic_prompt_messages: Optional[list[PromptMessage]] = None
    _agent_scratchpad: Optional[list[AgentScratchpadUnit]] = None
    _instruction: Optional[str] = None
    _query: Optional[str] = None
    _prompt_messages_tools: Optional[list[PromptMessage]] = None
2024-03-05 14:02:07 +08:00
def run(
    self,
    message: Message,
    query: str,
    inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
    """
    Run Cot agent application

    Drives the thought/action/observation loop: on every iteration the model is
    invoked with the prompt built by ``_organize_prompt_messages``, its streamed
    output is parsed into a scratchpad unit, and a requested tool (if any) is
    invoked. The loop ends when the model emits a "final answer" action, when no
    action could be extracted, or when the iteration budget is exhausted.

    :param message: message the agent thoughts are attached to
    :param query: user query for this run
    :param inputs: values substituted into ``{{key}}`` placeholders of the instruction
    :return: generator of LLMResultChunk; thought text is streamed as it arrives and
        one last chunk carries the final answer together with the accumulated usage
    :raises ValueError: if the model invocation returns a falsy result
    """
    app_generate_entity = self.application_generate_entity
    self._repack_app_generate_entity(app_generate_entity)
    self._init_react_state(query)

    trace_manager = app_generate_entity.trace_manager

    # check model mode: make the model stop at "Observation" unless the provider is exempt
    if "Observation" not in app_generate_entity.model_conf.stop:
        if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
            app_generate_entity.model_conf.stop.append("Observation")

    app_config = self.app_config

    # init instruction
    inputs = inputs or {}
    instruction = app_config.prompt_template.simple_prompt_template
    self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

    iteration_step = 1
    # configured limit is capped at 5; the extra step is the tool-less final iteration
    max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

    # convert tools into ModelRuntime Tool format
    tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

    function_call_state = True
    llm_usage = {"usage": None}
    final_answer = ""

    def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
        """Fold ``usage`` of one model call into the running total."""
        if not final_llm_usage_dict["usage"]:
            final_llm_usage_dict["usage"] = usage
        else:
            accumulated = final_llm_usage_dict["usage"]
            accumulated.prompt_tokens += usage.prompt_tokens
            accumulated.completion_tokens += usage.completion_tokens
            # keep the token total in step with its two components
            accumulated.total_tokens += usage.total_tokens
            accumulated.prompt_price += usage.prompt_price
            accumulated.completion_price += usage.completion_price
            accumulated.total_price += usage.total_price

    model_instance = self.model_instance

    while function_call_state and iteration_step <= max_iteration_steps:
        # continue to run until there is not any tool call
        function_call_state = False

        if iteration_step == max_iteration_steps:
            # the last iteration, remove all tools
            self._prompt_messages_tools = []

        message_file_ids = []

        agent_thought = self.create_agent_thought(
            message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
        )

        if iteration_step > 1:
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

        # recalc llm max tokens
        prompt_messages = self._organize_prompt_messages()
        self.recalc_llm_max_tokens(self.model_config, prompt_messages)
        # invoke model
        chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=app_generate_entity.model_conf.parameters,
            tools=[],
            stop=app_generate_entity.model_conf.stop,
            stream=True,
            user=self.user_id,
            callbacks=[],
        )

        # check llm result
        if not chunks:
            raise ValueError("failed to invoke llm")

        # filled in by the output parser while the stream is consumed
        usage_dict = {}
        react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
        scratchpad = AgentScratchpadUnit(
            agent_response="",
            thought="",
            action_str="",
            observation="",
            action=None,
        )

        # publish agent thought if it's first iteration
        if iteration_step == 1:
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

        for chunk in react_chunks:
            if isinstance(chunk, AgentScratchpadUnit.Action):
                # detect action
                action_json = json.dumps(chunk.model_dump())
                scratchpad.agent_response += action_json
                scratchpad.action_str = action_json
                scratchpad.action = chunk
            else:
                # plain text belongs to the thought and is streamed to the caller right away
                scratchpad.agent_response += chunk
                scratchpad.thought += chunk
                yield LLMResultChunk(
                    model=self.model_config.model,
                    prompt_messages=prompt_messages,
                    system_fingerprint="",
                    delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                )

        scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
        self._agent_scratchpad.append(scratchpad)

        # get llm usage
        if "usage" in usage_dict:
            increase_usage(llm_usage, usage_dict["usage"])
        else:
            usage_dict["usage"] = LLMUsage.empty_usage()

        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name=scratchpad.action.action_name if scratchpad.action else "",
            tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
            tool_invoke_meta={},
            thought=scratchpad.thought,
            observation="",
            answer=scratchpad.agent_response,
            messages_ids=[],
            llm_usage=usage_dict["usage"],
        )

        if not scratchpad.is_final():
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

        if not scratchpad.action:
            # failed to extract action, return final answer directly
            final_answer = ""
        else:
            if scratchpad.action.action_name.lower() == "final answer":
                # action is final answer, return final answer directly
                action_input = scratchpad.action.action_input
                if isinstance(action_input, dict):
                    try:
                        final_answer = json.dumps(action_input)
                    except (TypeError, ValueError):
                        # json.dumps signals unserializable / circular input with these,
                        # never with JSONDecodeError; fall back to the plain string form
                        final_answer = f"{action_input}"
                elif isinstance(action_input, str):
                    final_answer = action_input
                else:
                    final_answer = f"{action_input}"
            else:
                function_call_state = True
                # action is tool call, invoke tool
                tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                    action=scratchpad.action,
                    tool_instances=tool_instances,
                    message_file_ids=message_file_ids,
                    trace_manager=trace_manager,
                )
                scratchpad.observation = tool_invoke_response
                scratchpad.agent_response = tool_invoke_response

                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=scratchpad.action.action_name,
                    tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                    thought=scratchpad.thought,
                    observation={scratchpad.action.action_name: tool_invoke_response},
                    tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                    answer=scratchpad.agent_response,
                    messages_ids=message_file_ids,
                    llm_usage=usage_dict["usage"],
                )

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool message
            for prompt_tool in self._prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

        iteration_step += 1

    # last chunk: the final answer plus the usage accumulated over all iterations
    yield LLMResultChunk(
        model=model_instance.model,
        prompt_messages=prompt_messages,
        delta=LLMResultChunkDelta(
            index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
        ),
        system_fingerprint="",
    )

    # save agent thought
    self.save_agent_thought(
        agent_thought=agent_thought,
        tool_name="",
        tool_input={},
        tool_invoke_meta={},
        thought=final_answer,
        observation={},
        answer=final_answer,
        messages_ids=[],
    )

    self.update_db_variables(self.variables_pool, self.db_variables_pool)
    # publish end event
    self.queue_manager.publish(
        QueueMessageEndEvent(
            llm_result=LLMResult(
                model=model_instance.model,
                prompt_messages=prompt_messages,
                message=AssistantPromptMessage(content=final_answer),
                usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                system_fingerprint="",
            )
        ),
        PublishFrom.APPLICATION_MANAGER,
    )
def _handle_invoke_action(
    self,
    action: AgentScratchpadUnit.Action,
    tool_instances: dict[str, Tool],
    message_file_ids: list[str],
    trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
    """
    Invoke the tool requested by an agent action.

    :param action: action whose name selects the tool and whose input holds the arguments
    :param tool_instances: available tools keyed by tool name
    :param message_file_ids: list extended in place with the id of every file the tool produced
    :param trace_manager: trace manager forwarded to the tool engine
    :return: observation, meta
    """
    name = action.action_name
    arguments = action.action_input

    tool = tool_instances.get(name)
    if not tool:
        # unknown tool: report it back as the observation instead of raising
        error_text = f"there is not a tool named {name}"
        return error_text, ToolInvokeMeta.error_instance(error_text)

    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            # not JSON: hand the raw string to the tool unchanged
            pass

    # invoke tool
    observation, produced_files, invoke_meta = ToolEngine.agent_invoke(
        tool=tool,
        tool_parameters=arguments,
        user_id=self.user_id,
        tenant_id=self.tenant_id,
        message=self.message,
        invoke_from=self.application_generate_entity.invoke_from,
        agent_tool_callback=self.agent_callback,
        trace_manager=trace_manager,
    )

    # publish files
    for file_id, save_as in produced_files:
        if save_as:
            self.variables_pool.set_file(tool_name=name, value=file_id, name=save_as)

        # publish message file
        file_event = QueueMessageFileEvent(message_file_id=file_id)
        self.queue_manager.publish(file_event, PublishFrom.APPLICATION_MANAGER)
        # add message file ids
        message_file_ids.append(file_id)

    return observation, invoke_meta
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
    """
    Build an ``AgentScratchpadUnit.Action`` from a dict with the keys
    ``"action"`` and ``"action_input"``.
    """
    name = action["action"]
    payload = action["action_input"]
    return AgentScratchpadUnit.Action(action_name=name, action_input=payload)
2024-02-21 10:45:59 +08:00
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
fill in inputs from external data tools
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception as e:
continue
return instruction
2024-04-11 18:34:17 +08:00
def _init_react_state(self, query) -> None:
"""
init agent scratchpad
"""
2024-04-11 18:34:17 +08:00
self._query = query
self._agent_scratchpad = []
self._historic_prompt_messages = self._organize_historic_prompt_messages()
2024-04-11 18:34:17 +08:00
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
    """
    Organize the prompt messages for the next model invocation.

    Abstract hook that concrete runners must implement. ``run`` calls it once
    per iteration and passes the returned list to the model as its prompt.

    :return: prompt messages to send to the model
    """
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
2024-04-11 18:34:17 +08:00
"""
message = ""
2024-04-11 18:34:17 +08:00
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
else:
message += f"Thought: {scratchpad.thought}\n\n"
if scratchpad.action_str:
message += f"Action: {scratchpad.action_str}\n\n"
if scratchpad.observation:
message += f"Observation: {scratchpad.observation}\n\n"
return message
def _organize_historic_prompt_messages(
    self, current_session_messages: Optional[list[PromptMessage]] = None
) -> list[PromptMessage]:
    """
    Organize historic prompt messages

    Walks ``self.history_prompt_messages`` and folds each run of assistant / tool
    messages between two user messages into one assistant message formatted by
    ``_format_assistant_message``; user messages are kept as they are. The result
    is then passed through ``AgentHistoryPromptTransform``.

    :param current_session_messages: messages of the current session handed to the
        history transform; treated as an empty list when omitted
    :return: prompt messages produced by the history transform
    """
    result: list[PromptMessage] = []
    scratchpads: list[AgentScratchpadUnit] = []
    current_scratchpad: Optional[AgentScratchpadUnit] = None

    for message in self.history_prompt_messages:
        if isinstance(message, AssistantPromptMessage):
            if not current_scratchpad:
                current_scratchpad = AgentScratchpadUnit(
                    agent_response=message.content,
                    thought=message.content or "I am thinking about how to help you",
                    action_str="",
                    action=None,
                    observation=None,
                )
                scratchpads.append(current_scratchpad)
            if message.tool_calls:
                try:
                    # only the first tool call is reflected in the scratchpad
                    current_scratchpad.action = AgentScratchpadUnit.Action(
                        action_name=message.tool_calls[0].function.name,
                        action_input=json.loads(message.tool_calls[0].function.arguments),
                    )
                    current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                except Exception:
                    # best effort: a malformed historic tool call must not break the run.
                    # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
                    pass
        elif isinstance(message, ToolPromptMessage):
            if current_scratchpad:
                current_scratchpad.observation = message.content
        elif isinstance(message, UserPromptMessage):
            # a user message closes the assistant turn collected so far
            if scratchpads:
                result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                scratchpads = []
                current_scratchpad = None

            result.append(message)

    # flush a trailing assistant turn that no user message followed
    if scratchpads:
        result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

    historic_prompts = AgentHistoryPromptTransform(
        model_config=self.model_config,
        prompt_messages=current_session_messages or [],
        history_messages=result,
        memory=self.memory,
    ).get_prompt()
    return historic_prompts