refactor(node.py): streamline template rendering

Removed the `_render_basic_message` function and integrated its logic directly into the `LLMNode` class. This reduces redundancy and simplifies the handling of message templates by utilizing `convert_template` more directly. This change enhances code readability and maintainability.
This commit is contained in:
-LAN- 2024-11-14 23:35:20 +08:00
parent b860a893c8
commit f68d6bd5e2

View File

@ -36,7 +36,6 @@ from core.variables import (
FileSegment, FileSegment,
NoneSegment, NoneSegment,
ObjectSegment, ObjectSegment,
SegmentGroup,
StringSegment, StringSegment,
) )
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
@ -878,11 +877,11 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages.append(prompt_message) prompt_messages.append(prompt_message)
else: else:
# Get segment group from basic message # Get segment group from basic message
segment_group = _render_basic_message( if context:
template=message.text, template = message.text.replace("{#context#}", context)
context=context, else:
variable_pool=self.graph_runtime_state.variable_pool, template = message.text
) segment_group = self.graph_runtime_state.variable_pool.convert_template(template)
# Process segments for images # Process segments for images
file_contents = [] file_contents = []
@ -926,11 +925,11 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
) )
else: else:
result_text = _render_basic_message( if context:
template=template.text, template = template.text.replace("{#context#}", context)
context=context, else:
variable_pool=self.graph_runtime_state.variable_pool, template = template.text
).text result_text = self.graph_runtime_state.variable_pool.convert_template(template).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message) prompt_messages.append(prompt_message)
return prompt_messages return prompt_messages
@ -967,18 +966,3 @@ def _render_jinja2_message(
) )
result_text = code_execute_resp["result"] result_text = code_execute_resp["result"]
return result_text return result_text
def _render_basic_message(
*,
template: str,
context: str | None,
variable_pool: VariablePool,
) -> SegmentGroup:
if not template:
return SegmentGroup(value=[])
if context:
template = template.replace("{#context#}", context)
return variable_pool.convert_template(template)