refactor(core): decouple LLMNode prompt handling
Moved prompt handling functions out of the `LLMNode` class to improve modularity and separation of concerns. This refactor allows better reuse and testing of prompt-related functions. Adjusted existing logic to fetch queries and handle context and memory configurations more effectively. Updated tests to align with the new structure and ensure continued functionality.
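To see the shape of the change outside the diff: helpers that used to read their collaborators off `self` now receive them as explicit keyword-only arguments. The sketch below is a simplified, hypothetical illustration of that pattern (the class, field, and function names are stand-ins, not the real LLMNode API); it only shows why module-level helpers with explicit dependencies are easier to reuse and unit-test.

from dataclasses import dataclass
from typing import Sequence


@dataclass
class FakeVariablePool:
    # Stand-in for VariablePool: plain str.format substitution.
    variables: dict

    def convert_template(self, template: str) -> str:
        return template.format(**self.variables)


class NodeBefore:
    # Before: the helper is a method and silently depends on self.variable_pool,
    # so testing it means constructing a node with its full runtime state.
    def __init__(self, variable_pool: FakeVariablePool):
        self.variable_pool = variable_pool

    def _render_messages(self, templates: Sequence[str]) -> list:
        return [self.variable_pool.convert_template(t) for t in templates]


def render_messages(*, templates: Sequence[str], variable_pool: FakeVariablePool) -> list:
    # After: a module-level function with explicit, keyword-only dependencies,
    # callable from any node (or test) that can supply a variable pool.
    return [variable_pool.convert_template(t) for t in templates]


if __name__ == "__main__":
    pool = FakeVariablePool(variables={"name": "Dify"})
    assert render_messages(templates=["Hello {name}"], variable_pool=pool) == ["Hello Dify"]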
This commit is contained in:
parent f68d6bd5e2
commit 4e360ec19a
@@ -38,7 +38,6 @@ from core.variables import (
    ObjectSegment,
    StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
@@ -135,10 +134,7 @@ class LLMNode(BaseNode[LLMNodeData]):
        # fetch prompt messages
        if self.node_data.memory:
            query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
            if not query:
                raise VariableNotFoundError("Query not found")
            query = query.text
            query = self.node_data.memory.query_prompt_template
        else:
            query = None
@@ -152,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
            memory_config=self.node_data.memory,
            vision_enabled=self.node_data.vision.enabled,
            vision_detail=self.node_data.vision.configs.detail,
            variable_pool=self.graph_runtime_state.variable_pool,
            jinja2_variables=self.node_data.prompt_config.jinja2_variables,
        )

        process_data = {
@@ -550,15 +548,25 @@ class LLMNode(BaseNode[LLMNodeData]):
        memory_config: MemoryConfig | None = None,
        vision_enabled: bool = False,
        vision_detail: ImagePromptMessageContent.DETAIL,
        variable_pool: VariablePool,
        jinja2_variables: Sequence[VariableSelector],
    ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
        prompt_messages = []

        if isinstance(prompt_template, list):
            # For chat model
            prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context))
            prompt_messages.extend(
                _handle_list_messages(
                    messages=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                    vision_detail_config=vision_detail,
                )
            )

            # Get memory messages for chat mode
            memory_messages = self._handle_memory_chat_mode(
            memory_messages = _handle_memory_chat_mode(
                memory=memory,
                memory_config=memory_config,
                model_config=model_config,
@@ -568,14 +576,34 @@ class LLMNode(BaseNode[LLMNodeData]):
            # Add current query to the prompt messages
            if user_query:
                prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)]))
                message = LLMNodeChatModelMessage(
                    text=user_query,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
                prompt_messages.extend(
                    _handle_list_messages(
                        messages=[message],
                        context="",
                        jinja2_variables=[],
                        variable_pool=variable_pool,
                        vision_detail_config=vision_detail,
                    )
                )

        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
            # For completion model
            prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context))
            prompt_messages.extend(
                _handle_completion_template(
                    template=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
            )

            # Get memory text for completion model
            memory_text = self._handle_memory_completion_mode(
            memory_text = _handle_memory_completion_mode(
                memory=memory,
                memory_config=memory_config,
                model_config=model_config,
@@ -628,7 +656,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                if (
                    (
                        content_item.type == PromptMessageContentType.IMAGE
                        and (not vision_enabled or ModelFeature.VISION not in model_config.model_schema.features)
                        and ModelFeature.VISION not in model_config.model_schema.features
                    )
                    or (
                        content_item.type == PromptMessageContentType.DOCUMENT
@@ -662,73 +690,6 @@ class LLMNode(BaseNode[LLMNodeData]):
            stop = model_config.stop
        return filtered_prompt_messages, stop

    def _handle_memory_chat_mode(
        self,
        *,
        memory: TokenBufferMemory | None,
        memory_config: MemoryConfig | None,
        model_config: ModelConfigWithCredentialsEntity,
    ) -> Sequence[PromptMessage]:
        memory_messages = []
        # Get messages from memory for chat model
        if memory and memory_config:
            rest_tokens = self._calculate_rest_token([], model_config)
            memory_messages = memory.get_history_prompt_messages(
                max_token_limit=rest_tokens,
                message_limit=memory_config.window.size if memory_config.window.enabled else None,
            )
        return memory_messages

    def _handle_memory_completion_mode(
        self,
        *,
        memory: TokenBufferMemory | None,
        memory_config: MemoryConfig | None,
        model_config: ModelConfigWithCredentialsEntity,
    ) -> str:
        memory_text = ""
        # Get history text from memory for completion model
        if memory and memory_config:
            rest_tokens = self._calculate_rest_token([], model_config)
            if not memory_config.role_prefix:
                raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
            memory_text = memory.get_history_prompt_text(
                max_token_limit=rest_tokens,
                message_limit=memory_config.window.size if memory_config.window.enabled else None,
                human_prefix=memory_config.role_prefix.user,
                ai_prefix=memory_config.role_prefix.assistant,
            )
        return memory_text

    def _calculate_rest_token(
        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
    ) -> int:
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )

            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(str(parameter_rule.use_template))
                        or 0
                    )

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens

    @classmethod
    def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
        provider_model_bundle = model_instance.provider_model_bundle
@@ -862,78 +823,6 @@ class LLMNode(BaseNode[LLMNodeData]):
            },
        }

    def _handle_list_messages(
        self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str]
    ) -> Sequence[PromptMessage]:
        prompt_messages = []
        for message in messages:
            if message.edition_type == "jinja2":
                result_text = _render_jinja2_message(
                    template=message.jinja2_text or "",
                    jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
                    variable_pool=self.graph_runtime_state.variable_pool,
                )
                prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
                prompt_messages.append(prompt_message)
            else:
                # Get segment group from basic message
                if context:
                    template = message.text.replace("{#context#}", context)
                else:
                    template = message.text
                segment_group = self.graph_runtime_state.variable_pool.convert_template(template)

                # Process segments for images
                file_contents = []
                for segment in segment_group.value:
                    if isinstance(segment, ArrayFileSegment):
                        for file in segment.value:
                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
                                file_content = file_manager.to_prompt_message_content(
                                    file, image_detail_config=self.node_data.vision.configs.detail
                                )
                                file_contents.append(file_content)
                    if isinstance(segment, FileSegment):
                        file = segment.value
                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
                            file_content = file_manager.to_prompt_message_content(
                                file, image_detail_config=self.node_data.vision.configs.detail
                            )
                            file_contents.append(file_content)

                # Create message with text from all segments
                plain_text = segment_group.text
                if plain_text:
                    prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
                    prompt_messages.append(prompt_message)

                if file_contents:
                    # Create message with image contents
                    prompt_message = UserPromptMessage(content=file_contents)
                    prompt_messages.append(prompt_message)

        return prompt_messages

    def _handle_completion_template(
        self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str]
    ) -> Sequence[PromptMessage]:
        prompt_messages = []
        if template.edition_type == "jinja2":
            result_text = _render_jinja2_message(
                template=template.jinja2_text or "",
                jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
                variable_pool=self.graph_runtime_state.variable_pool,
            )
        else:
            if context:
                template = template.text.replace("{#context#}", context)
            else:
                template = template.text
            result_text = self.graph_runtime_state.variable_pool.convert_template(template).text
        prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
        prompt_messages.append(prompt_message)
        return prompt_messages


def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
    match role:
@@ -966,3 +855,165 @@ def _render_jinja2_message(
    )
    result_text = code_execute_resp["result"]
    return result_text


def _handle_list_messages(
    *,
    messages: Sequence[LLMNodeChatModelMessage],
    context: Optional[str],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
    vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
    prompt_messages = []
    for message in messages:
        if message.edition_type == "jinja2":
            result_text = _render_jinja2_message(
                template=message.jinja2_text or "",
                jinjia2_variables=jinja2_variables,
                variable_pool=variable_pool,
            )
            prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
            prompt_messages.append(prompt_message)
        else:
            # Get segment group from basic message
            if context:
                template = message.text.replace("{#context#}", context)
            else:
                template = message.text
            segment_group = variable_pool.convert_template(template)

            # Process segments for images
            file_contents = []
            for segment in segment_group.value:
                if isinstance(segment, ArrayFileSegment):
                    for file in segment.value:
                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
                            file_content = file_manager.to_prompt_message_content(
                                file, image_detail_config=vision_detail_config
                            )
                            file_contents.append(file_content)
                if isinstance(segment, FileSegment):
                    file = segment.value
                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
                        file_content = file_manager.to_prompt_message_content(
                            file, image_detail_config=vision_detail_config
                        )
                        file_contents.append(file_content)

            # Create message with text from all segments
            plain_text = segment_group.text
            if plain_text:
                prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
                prompt_messages.append(prompt_message)

            if file_contents:
                # Create message with image contents
                prompt_message = UserPromptMessage(content=file_contents)
                prompt_messages.append(prompt_message)

    return prompt_messages


def _calculate_rest_token(
    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
    rest_tokens = 2000

    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        model_instance = ModelInstance(
            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
        )

        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
                    model_config.parameters.get(parameter_rule.name)
                    or model_config.parameters.get(str(parameter_rule.use_template))
                    or 0
                )

        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
        rest_tokens = max(rest_tokens, 0)

    return rest_tokens
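As a toy illustration of the budget arithmetic in `_calculate_rest_token` above (this snippet is not part of the commit): the helper reserves the model's context window minus the configured `max_tokens` and the tokens already spent on the current prompt, floored at zero, and falls back to 2000 when the model schema exposes no CONTEXT_SIZE.

def rest_token_budget(context_size, max_tokens, prompt_tokens):
    # Mirrors only the fallback and the subtraction shown above.
    if not context_size:
        return 2000
    return max(context_size - max_tokens - prompt_tokens, 0)


assert rest_token_budget(8192, 1024, 300) == 6868
assert rest_token_budget(None, 1024, 300) == 2000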
def _handle_memory_chat_mode(
    *,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
    memory_messages = []
    # Get messages from memory for chat model
    if memory and memory_config:
        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
        memory_messages = memory.get_history_prompt_messages(
            max_token_limit=rest_tokens,
            message_limit=memory_config.window.size if memory_config.window.enabled else None,
        )
    return memory_messages


def _handle_memory_completion_mode(
    *,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> str:
    memory_text = ""
    # Get history text from memory for completion model
    if memory and memory_config:
        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
        if not memory_config.role_prefix:
            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
        memory_text = memory.get_history_prompt_text(
            max_token_limit=rest_tokens,
            message_limit=memory_config.window.size if memory_config.window.enabled else None,
            human_prefix=memory_config.role_prefix.user,
            ai_prefix=memory_config.role_prefix.assistant,
        )
    return memory_text


def _handle_completion_template(
    *,
    template: LLMNodeCompletionModelPromptTemplate,
    context: Optional[str],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
    """Handle completion template processing outside of LLMNode class.

    Args:
        template: The completion model prompt template
        context: Optional context string
        jinja2_variables: Variables for jinja2 template rendering
        variable_pool: Variable pool for template conversion

    Returns:
        Sequence of prompt messages
    """
    prompt_messages = []
    if template.edition_type == "jinja2":
        result_text = _render_jinja2_message(
            template=template.jinja2_text or "",
            jinjia2_variables=jinja2_variables,
            variable_pool=variable_pool,
        )
    else:
        if context:
            template_text = template.text.replace("{#context#}", context)
        else:
            template_text = template.text
        result_text = variable_pool.convert_template(template_text).text
    prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
    prompt_messages.append(prompt_message)
    return prompt_messages
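One practical consequence of moving these helpers to module level is that the memory handling can now be exercised without building an LLMNode or a graph runtime state. A minimal sketch of such a test follows; the import path `core.workflow.nodes.llm.node` and the test layout are assumptions, not part of this commit.

from unittest.mock import MagicMock

# Assumed import path for the new module-level helpers.
from core.workflow.nodes.llm.node import (
    _handle_memory_chat_mode,
    _handle_memory_completion_mode,
)


def test_memory_helpers_without_memory():
    model_config = MagicMock()  # never touched when memory is None

    # Chat mode: no memory configured -> no history messages.
    assert _handle_memory_chat_mode(memory=None, memory_config=None, model_config=model_config) == []

    # Completion mode: no memory configured -> empty history text.
    assert _handle_memory_completion_mode(memory=None, memory_config=None, model_config=model_config) == ""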
@@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
        )
        prompt_messages, stop = self._fetch_prompt_messages(
            prompt_template=prompt_template,
            system_query=query,
            user_query=query,
            memory=memory,
            model_config=model_config,
            files=files,
            user_files=files,
            vision_enabled=node_data.vision.enabled,
            vision_detail=node_data.vision.configs.detail,
            variable_pool=variable_pool,
            jinja2_variables=[],
        )

        # handle invoke result
@@ -240,6 +240,8 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
        memory_config=None,
        vision_enabled=False,
        vision_detail=fake_vision_detail,
        variable_pool=llm_node.graph_runtime_state.variable_pool,
        jinja2_variables=[],
    )

    assert prompt_messages == [UserPromptMessage(content=fake_query)]
@@ -368,7 +370,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
            description="Prompt template with variable selector of File",
            user_query=fake_query,
            user_files=[],
            vision_enabled=True,
            vision_enabled=False,
            vision_detail=fake_vision_detail,
            features=[ModelFeature.VISION],
            window_size=fake_window_size,

@@ -471,6 +473,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
            memory_config=memory_config,
            vision_enabled=scenario.vision_enabled,
            vision_detail=scenario.vision_detail,
            variable_pool=llm_node.graph_runtime_state.variable_pool,
            jinja2_variables=[],
        )

        # Verify the result