From 38e5952417732aa14763f6b2d5bfde1aae92feb0 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:02:07 +0800 Subject: [PATCH] Fix/agent react output parser (#2689) --- api/core/features/assistant_cot_runner.py | 34 +++++++++++++++-------- api/core/tools/tool/tool.py | 13 ++++++++- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index 8fcbff983d..3762ddcf62 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -28,6 +28,9 @@ from models.model import Conversation, Message class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): + _is_first_iteration = True + _ignore_observation_providers = ['wenxin'] + def run(self, conversation: Conversation, message: Message, query: str, @@ -42,10 +45,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): agent_scratchpad: list[AgentScratchpadUnit] = [] self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages) - # check model mode - if self.app_orchestration_config.model_config.mode == "completion": - # TODO: stop words - if 'Observation' not in app_orchestration_config.model_config.stop: + if 'Observation' not in app_orchestration_config.model_config.stop: + if app_orchestration_config.model_config.provider not in self._ignore_observation_providers: app_orchestration_config.model_config.stop.append('Observation') # override inputs @@ -202,6 +203,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ) ) + scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you' agent_scratchpad.append(scratchpad) # get llm usage @@ -255,9 +257,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): # invoke tool error_response = None try: + if isinstance(tool_call_args, str): + try: + tool_call_args = json.loads(tool_call_args) + except json.JSONDecodeError: + pass + tool_response = tool_instance.invoke( user_id=self.user_id, - tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args) + tool_parameters=tool_call_args ) # transform tool response to llm friendly response tool_response = self.transform_tool_invoke_messages(tool_response) @@ -466,7 +474,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): if isinstance(message, AssistantPromptMessage): current_scratchpad = AgentScratchpadUnit( agent_response=message.content, - thought=message.content, + thought=message.content or 'I am thinking about how to help you', action_str='', action=None, observation=None, @@ -546,7 +554,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): result = '' for scratchpad in agent_scratchpad: - result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n" + result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \ + next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available') return result @@ -621,21 +630,24 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): )) # add assistant message - if len(agent_scratchpad) > 0: + if len(agent_scratchpad) > 0 and not self._is_first_iteration: prompt_messages.append(AssistantPromptMessage( - content=(agent_scratchpad[-1].thought or '') + content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''), )) # add user message - if len(agent_scratchpad) > 0: + if len(agent_scratchpad) > 0 and not self._is_first_iteration: prompt_messages.append(UserPromptMessage( - content=(agent_scratchpad[-1].observation or ''), + content=(agent_scratchpad[-1].observation or 'It seems that no response is available'), )) + self._is_first_iteration = False + return prompt_messages elif mode == "completion": # parse agent scratchpad agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad) + self._is_first_iteration = False # parse prompt messages return [UserPromptMessage( content=first_prompt.replace("{{instruction}}", instruction) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 9f343d6000..192793897e 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -174,7 +174,18 @@ class Tool(BaseModel, ABC): return result - def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: + def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]: + # check if tool_parameters is a string + if isinstance(tool_parameters, str): + # check if this tool has only one parameter + parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM] + if parameters and len(parameters) == 1: + tool_parameters = { + parameters[0].name: tool_parameters + } + else: + raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") + # update tool_parameters if self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters)