fix organize agent's history messages without recalculating tokens (#4324)
Co-authored-by: chenyongzhao <chenyz@mama.cn>
commit afed3610fc, parent 74f38eacda
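In short, this commit introduces an AgentHistoryPromptTransform that trims conversation history against the remaining token budget, and the agent runners (CoT chat, CoT completion, and function calling) now rebuild their prompts through it on each iteration instead of appending to one ever-growing prompt_messages list. The current turn's messages are handed to the transform separately from the stored history, so, as far as the diff shows, history is trimmed with the current turn's size already reserved.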
@@ -128,6 +128,8 @@ class BaseAgentRunner(AppRunner):
             self.files = application_generate_entity.files
         else:
             self.files = []
+        self.query = None
+        self._current_thoughts: list[PromptMessage] = []
 
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
             -> AgentChatAppGenerateEntity:
@@ -545,3 +547,4 @@ class BaseAgentRunner(AppRunner):
             return UserPromptMessage(content=prompt_message_contents)
         else:
             return UserPromptMessage(content=message.query)
+
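Two pieces of per-run state are added here: self.query holds the query for the current run, and self._current_thoughts collects the assistant and tool messages produced during the iteration loop. Both are consumed by the new _organize_prompt_messages helper in the function-calling runner further down.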
@@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool.tool import Tool
 from core.tools.tool_engine import ToolEngine
@@ -373,7 +374,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
         return message
 
-    def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         organize historic prompt messages
         """
@@ -381,6 +382,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         scratchpad: list[AgentScratchpadUnit] = []
         current_scratchpad: AgentScratchpadUnit = None
 
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=current_session_messages or [],
+            history_messages=self.history_prompt_messages,
+            memory=self.memory
+        ).get_prompt()
+
         for message in self.history_prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 current_scratchpad = AgentScratchpadUnit(
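_organize_historic_prompt_messages now takes the current session's messages and forwards them to the transform as prompt_messages, while the stored history goes in history_messages; presumably the _calculate_rest_token call in the new AgentHistoryPromptTransform (inherited from PromptTransform) uses the former to work out how much of the token budget is left for the latter.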
@@ -32,9 +32,6 @@ class CotChatAgentRunner(CotAgentRunner):
         # organize system prompt
         system_message = self._organize_system_prompt()
 
-        # organize historic prompt messages
-        historic_messages = self._historic_prompt_messages
-
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
         if not agent_scratchpad:
@@ -57,6 +54,13 @@ class CotChatAgentRunner(CotAgentRunner):
         query_messages = UserPromptMessage(content=self._query)
 
         if assistant_messages:
+            # organize historic prompt messages
+            historic_messages = self._organize_historic_prompt_messages([
+                system_message,
+                query_messages,
+                *assistant_messages,
+                UserPromptMessage(content='continue')
+            ])
             messages = [
                 system_message,
                 *historic_messages,
@@ -65,6 +69,8 @@ class CotChatAgentRunner(CotAgentRunner):
                 UserPromptMessage(content='continue')
             ]
         else:
+            # organize historic prompt messages
+            historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
             messages = [system_message, *historic_messages, query_messages]
 
         # join all messages
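In both branches the runner now assembles the current turn explicitly — the system prompt, the user query, any assistant messages derived from the scratchpad, and the trailing 'continue' user message — and hands that list to _organize_historic_prompt_messages, so history is trimmed with the current turn already accounted for.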
@@ -19,11 +19,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
 
         return system_prompt
 
-    def _organize_historic_prompt(self) -> str:
+    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
         """
         Organize historic prompt
         """
-        historic_prompt_messages = self._historic_prompt_messages
+        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
        historic_prompt = ""
 
         for message in historic_prompt_messages:
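The completion-style runner gets the same treatment: its historic prompt string is now built from the trimmed messages returned by _organize_historic_prompt_messages rather than from the raw self._historic_prompt_messages.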
@@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
 from models.model import Message
@@ -24,21 +25,18 @@ from models.model import Message
 logger = logging.getLogger(__name__)
 
 class FunctionCallAgentRunner(BaseAgentRunner):
 
     def run(self,
             message: Message, query: str, **kwargs: Any
     ) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
+        self.query = query
         app_generate_entity = self.application_generate_entity
 
         app_config = self.app_config
 
-        prompt_template = app_config.prompt_template.simple_prompt_template or ''
-        prompt_messages = self.history_prompt_messages
-        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
-        prompt_messages = self._organize_user_query(query, prompt_messages)
-
         # convert tools into ModelRuntime Tool format
         tool_instances, prompt_messages_tools = self._init_prompt_tools()
 
@@ -81,6 +79,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             )
 
             # recalc llm max tokens
+            prompt_messages = self._organize_prompt_messages()
             self.recalc_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
             chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@@ -203,7 +202,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             else:
                 assistant_message.content = response
 
-            prompt_messages.append(assistant_message)
+            self._current_thoughts.append(assistant_message)
 
             # save thought
             self.save_agent_thought(
@@ -265,11 +264,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 }
 
                 tool_responses.append(tool_response)
-                prompt_messages = self._organize_assistant_message(
-                    tool_call_id=tool_call_id,
-                    tool_call_name=tool_call_name,
-                    tool_response=tool_response['tool_response'],
-                    prompt_messages=prompt_messages,
-                )
+                if tool_response['tool_response'] is not None:
+                    self._current_thoughts.append(
+                        ToolPromptMessage(
+                            content=tool_response['tool_response'],
+                            tool_call_id=tool_call_id,
+                            name=tool_call_name,
+                        )
+                    )
 
             if len(tool_responses) > 0:
@@ -300,8 +301,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             iteration_step += 1
 
-            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
-
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -393,24 +392,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         return prompt_messages
 
-    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
-                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
-        """
-        Organize assistant message
-        """
-        prompt_messages = deepcopy(prompt_messages)
-
-        if tool_response is not None:
-            prompt_messages.append(
-                ToolPromptMessage(
-                    content=tool_response,
-                    tool_call_id=tool_call_id,
-                    name=tool_call_name,
-                )
-            )
-
-        return prompt_messages
-
     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -429,3 +410,25 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             ])
 
         return prompt_messages
+
+    def _organize_prompt_messages(self):
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
+        query_prompt_messages = self._organize_user_query(self.query, [])
+
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
+            history_messages=self.history_prompt_messages,
+            memory=self.memory
+        ).get_prompt()
+
+        prompt_messages = [
+            *self.history_prompt_messages,
+            *query_prompt_messages,
+            *self._current_thoughts
+        ]
+        if len(self._current_thoughts) != 0:
+            # clear messages after the first iteration
+            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+        return prompt_messages
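Net effect in the function-calling runner: rather than mutating a single prompt_messages list across iterations (and deep-copying it in the removed _organize_assistant_message), each step calls _organize_prompt_messages, which re-derives the prompt from the trimmed history, the current query, and the thoughts accumulated in self._current_thoughts, then recalculates the max tokens against that rebuilt prompt before invoking the model.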
api/core/prompt/agent_history_prompt_transform.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import (
+    ModelConfigWithCredentialsEntity,
+)
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+
+
+class AgentHistoryPromptTransform(PromptTransform):
+    """
+    History Prompt Transform for Agent App
+    """
+    def __init__(self,
+                 model_config: ModelConfigWithCredentialsEntity,
+                 prompt_messages: list[PromptMessage],
+                 history_messages: list[PromptMessage],
+                 memory: Optional[TokenBufferMemory] = None,
+                 ):
+        self.model_config = model_config
+        self.prompt_messages = prompt_messages
+        self.history_messages = history_messages
+        self.memory = memory
+
+    def get_prompt(self) -> list[PromptMessage]:
+        prompt_messages = []
+        num_system = 0
+        for prompt_message in self.history_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                prompt_messages.append(prompt_message)
+                num_system += 1
+
+        if not self.memory:
+            return prompt_messages
+
+        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+
+        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        curr_message_tokens = model_type_instance.get_num_tokens(
+            self.memory.model_instance.model,
+            self.memory.model_instance.credentials,
+            self.history_messages
+        )
+        if curr_message_tokens <= max_token_limit:
+            return self.history_messages
+
+        # number of prompts appended from the current message group
+        num_prompt = 0
+        # append prompt messages in descending order
+        for prompt_message in self.history_messages[::-1]:
+            if isinstance(prompt_message, SystemPromptMessage):
+                continue
+            prompt_messages.append(prompt_message)
+            num_prompt += 1
+            # a message group starts with a UserPromptMessage
+            if isinstance(prompt_message, UserPromptMessage):
+                curr_message_tokens = model_type_instance.get_num_tokens(
+                    self.memory.model_instance.model,
+                    self.memory.model_instance.credentials,
+                    prompt_messages
+                )
+                # if the current message tokens overflow, drop all prompts in the current message group and stop
+                if curr_message_tokens > max_token_limit:
+                    prompt_messages = prompt_messages[:-num_prompt]
+                    break
+                num_prompt = 0
+        # return prompt messages in ascending order
+        message_prompts = prompt_messages[num_system:]
+        message_prompts.reverse()
+
+        # merge system and message prompts
+        prompt_messages = prompt_messages[:num_system]
+        prompt_messages.extend(message_prompts)
+        return prompt_messages
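To make the trimming walk in get_prompt concrete, here is a small standalone sketch of the same logic. It is an illustration only, not repository code: the Msg dataclass and trim_history function are made up for this example, and real token counting is replaced by the one-token-per-message stub that the unit test below also uses.

from dataclasses import dataclass


@dataclass
class Msg:
    role: str   # 'system' | 'user' | 'assistant' | 'tool'
    text: str


def trim_history(history: list[Msg], max_token_limit: int) -> list[Msg]:
    # System messages are always kept, in order, just like get_prompt() does.
    kept = [m for m in history if m.role == 'system']
    num_system = len(kept)

    # One "token" per message stands in for a real tokenizer here.
    if len(history) <= max_token_limit:
        return history

    num_in_group = 0
    # Walk the non-system history backwards; a group ends (looking backwards)
    # at the user message that started it.
    for m in reversed([m for m in history if m.role != 'system']):
        kept.append(m)
        num_in_group += 1
        if m.role == 'user':
            if len(kept) > max_token_limit:
                # The group that overflows the budget is dropped whole.
                kept = kept[:-num_in_group]
                break
            num_in_group = 0

    # Restore chronological order for the non-system part and re-merge.
    tail = kept[num_system:]
    tail.reverse()
    return kept[:num_system] + tail


if __name__ == '__main__':
    history = [
        Msg('system', 'System Prompt 1'), Msg('user', 'User Prompt 1'),
        Msg('assistant', 'Assistant Thought 1'), Msg('tool', 'Tool 1-1'), Msg('tool', 'Tool 1-2'),
        Msg('system', 'System Prompt 2'), Msg('user', 'User Prompt 2'),
        Msg('assistant', 'Assistant Thought 2'), Msg('tool', 'Tool 2-1'), Msg('tool', 'Tool 2-2'),
        Msg('user', 'User Prompt 3'), Msg('assistant', 'Assistant Thought 3'),
    ]
    print([m.text for m in trim_history(history, 5)])
    # -> ['System Prompt 1', 'System Prompt 2', 'User Prompt 3', 'Assistant Thought 3']
    print(len(trim_history(history, 20)))
    # -> 12: the whole history fits, so it is returned untrimmed

Running the sketch keeps the two system prompts plus the most recent user turn for a budget of 5, and all 12 messages for a budget of 20 — the same shapes the unit test below asserts.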
@@ -0,0 +1,77 @@
+from unittest.mock import MagicMock
+
+from core.app.entities.app_invoke_entities import (
+    ModelConfigWithCredentialsEntity,
+)
+from core.entities.provider_configuration import ProviderModelBundle
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from models.model import Conversation
+
+
+def test_get_prompt():
+    prompt_messages = [
+        SystemPromptMessage(content='System Template'),
+        UserPromptMessage(content='User Query'),
+    ]
+    history_messages = [
+        SystemPromptMessage(content='System Prompt 1'),
+        UserPromptMessage(content='User Prompt 1'),
+        AssistantPromptMessage(content='Assistant Thought 1'),
+        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
+        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
+        SystemPromptMessage(content='System Prompt 2'),
+        UserPromptMessage(content='User Prompt 2'),
+        AssistantPromptMessage(content='Assistant Thought 2'),
+        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
+        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
+        UserPromptMessage(content='User Prompt 3'),
+        AssistantPromptMessage(content='Assistant Thought 3'),
+    ]
+
+    # use message number instead of token for testing
+    def side_effect_get_num_tokens(*args):
+        return len(args[2])
+    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
+    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
+
+    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
+    provider_model_bundle_mock.model_type_instance = large_language_model_mock
+
+    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
+    model_config_mock.model = 'openai'
+    model_config_mock.credentials = {}
+    model_config_mock.provider_model_bundle = provider_model_bundle_mock
+
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    transform = AgentHistoryPromptTransform(
+        model_config=model_config_mock,
+        prompt_messages=prompt_messages,
+        history_messages=history_messages,
+        memory=memory
+    )
+
+    max_token_limit = 5
+    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
+    result = transform.get_prompt()
+
+    assert len(result) <= max_token_limit
+    assert len(result) == 4
+
+    max_token_limit = 20
+    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
+    result = transform.get_prompt()
+
+    assert len(result) <= max_token_limit
+    assert len(result) == 12
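Working the trimming loop by hand with the one-message-one-token stub: a budget of 5 keeps the two system prompts plus only the most recent user turn ('User Prompt 3' and 'Assistant Thought 3'), hence len(result) == 4; a budget of 20 exceeds the 12 history messages, so get_prompt returns the history untrimmed and len(result) == 12.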