mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Fix/agent react output parser (#2689)
This commit is contained in:
parent
7f891939f1
commit
38e5952417
|
@ -28,6 +28,9 @@ from models.model import Conversation, Message
|
||||||
|
|
||||||
|
|
||||||
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||||
|
_is_first_iteration = True
|
||||||
|
_ignore_observation_providers = ['wenxin']
|
||||||
|
|
||||||
def run(self, conversation: Conversation,
|
def run(self, conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -42,10 +45,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||||
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
|
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
|
||||||
|
|
||||||
# check model mode
|
if 'Observation' not in app_orchestration_config.model_config.stop:
|
||||||
if self.app_orchestration_config.model_config.mode == "completion":
|
if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
|
||||||
# TODO: stop words
|
|
||||||
if 'Observation' not in app_orchestration_config.model_config.stop:
|
|
||||||
app_orchestration_config.model_config.stop.append('Observation')
|
app_orchestration_config.model_config.stop.append('Observation')
|
||||||
|
|
||||||
# override inputs
|
# override inputs
|
||||||
|
@ -202,6 +203,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
|
||||||
agent_scratchpad.append(scratchpad)
|
agent_scratchpad.append(scratchpad)
|
||||||
|
|
||||||
# get llm usage
|
# get llm usage
|
||||||
|
@ -255,9 +257,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||||
# invoke tool
|
# invoke tool
|
||||||
error_response = None
|
error_response = None
|
||||||
try:
|
try:
|
||||||
|
if isinstance(tool_call_args, str):
|
||||||
|
try:
|
||||||
|
tool_call_args = json.loads(tool_call_args)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
tool_response = tool_instance.invoke(
|
tool_response = tool_instance.invoke(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
|
tool_parameters=tool_call_args
|
||||||
)
|
)
|
||||||
# transform tool response to llm friendly response
|
# transform tool response to llm friendly response
|
||||||
tool_response = self.transform_tool_invoke_messages(tool_response)
|
tool_response = self.transform_tool_invoke_messages(tool_response)
|
||||||
|
@ -466,7 +474,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||||
if isinstance(message, AssistantPromptMessage):
|
if isinstance(message, AssistantPromptMessage):
|
||||||
current_scratchpad = AgentScratchpadUnit(
|
current_scratchpad = AgentScratchpadUnit(
|
||||||
agent_response=message.content,
|
agent_response=message.content,
|
||||||
thought=message.content,
|
thought=message.content or 'I am thinking about how to help you',
|
||||||
action_str='',
|
action_str='',
|
||||||
action=None,
|
action=None,
|
||||||
observation=None,
|
observation=None,
|
||||||
|
@ -546,7 +554,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||||
|
|
||||||
result = ''
|
result = ''
|
||||||
for scratchpad in agent_scratchpad:
|
for scratchpad in agent_scratchpad:
|
||||||
result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
|
result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
|
||||||
|
next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -621,21 +630,24 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||||
))
|
))
|
||||||
|
|
||||||
# add assistant message
|
# add assistant message
|
||||||
if len(agent_scratchpad) > 0:
|
if len(agent_scratchpad) > 0 and not self._is_first_iteration:
|
||||||
prompt_messages.append(AssistantPromptMessage(
|
prompt_messages.append(AssistantPromptMessage(
|
||||||
content=(agent_scratchpad[-1].thought or '')
|
content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
|
||||||
))
|
))
|
||||||
|
|
||||||
# add user message
|
# add user message
|
||||||
if len(agent_scratchpad) > 0:
|
if len(agent_scratchpad) > 0 and not self._is_first_iteration:
|
||||||
prompt_messages.append(UserPromptMessage(
|
prompt_messages.append(UserPromptMessage(
|
||||||
content=(agent_scratchpad[-1].observation or ''),
|
content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
|
||||||
))
|
))
|
||||||
|
|
||||||
|
self._is_first_iteration = False
|
||||||
|
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
elif mode == "completion":
|
elif mode == "completion":
|
||||||
# parse agent scratchpad
|
# parse agent scratchpad
|
||||||
agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
|
agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
|
||||||
|
self._is_first_iteration = False
|
||||||
# parse prompt messages
|
# parse prompt messages
|
||||||
return [UserPromptMessage(
|
return [UserPromptMessage(
|
||||||
content=first_prompt.replace("{{instruction}}", instruction)
|
content=first_prompt.replace("{{instruction}}", instruction)
|
||||||
|
|
|
@ -174,7 +174,18 @@ class Tool(BaseModel, ABC):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
|
||||||
|
# check if tool_parameters is a string
|
||||||
|
if isinstance(tool_parameters, str):
|
||||||
|
# check if this tool has only one parameter
|
||||||
|
parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
|
||||||
|
if parameters and len(parameters) == 1:
|
||||||
|
tool_parameters = {
|
||||||
|
parameters[0].name: tool_parameters
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||||
|
|
||||||
# update tool_parameters
|
# update tool_parameters
|
||||||
if self.runtime.runtime_parameters:
|
if self.runtime.runtime_parameters:
|
||||||
tool_parameters.update(self.runtime.runtime_parameters)
|
tool_parameters.update(self.runtime.runtime_parameters)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user