Refactor: CoT runner (#2157)

This commit is contained in:
Yeuoly 2024-01-24 12:09:30 +08:00 committed by GitHub
parent c8fb619d37
commit 48d5628fd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 15 deletions

View File

@ -19,8 +19,6 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from models.model import Conversation, Message
logger = logging.getLogger(__name__)
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
@ -93,6 +91,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
prompt_messages_tools = []
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id,
message='',
@ -100,6 +99,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
tool_input='',
messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt messages
@ -138,6 +139,10 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if llm_result.usage:
increse_usage(llm_usage, llm_result.usage)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.save_agent_thought(agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input=scratchpad.action.action_input if scratchpad.action else '',
@ -187,7 +192,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
tool_call_args = scratchpad.action.action_input
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
logger.error(f"failed to find tool instance: {tool_call_name}")
answer = f"there is not a tool named {tool_call_name}"
self.save_agent_thought(agent_thought=agent_thought,
tool_name='',
@ -237,7 +241,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if error_response:
observation = error_response
logger.error(error_response)
else:
observation = self._convert_tool_response_to_str(tool_response)
@ -543,13 +546,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# add assistant message
if len(agent_scratchpad) > 0:
prompt_messages.append(AssistantPromptMessage(
content=(agent_scratchpad[-1].thought or '') + "\n" + (agent_scratchpad[-1].observation or '')
content=(agent_scratchpad[-1].thought or '')
))
# add user message
if len(agent_scratchpad) > 0:
prompt_messages.append(UserPromptMessage(
content=input,
content=(agent_scratchpad[-1].observation or ''),
))
return prompt_messages

View File

@ -2,13 +2,13 @@ identity:
author: Dify
name: youtube
label:
en_US: Youtube
zh_Hans: Youtube
pt_BR: Youtube
en_US: YouTube
zh_Hans: YouTube
pt_BR: YouTube
description:
en_US: Youtube
zh_Hans: Youtube油管是全球最大的视频分享网站用户可以在上面上传、观看和分享视频。
pt_BR: Youtube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos.
en_US: YouTube
zh_Hans: YouTube油管是全球最大的视频分享网站用户可以在上面上传、观看和分享视频。
pt_BR: YouTube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos.
icon: icon.png
credentials_for_provider:
google_api_key: