diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index efd8ace653..1e4f89480e 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -868,8 +868,10 @@ class LLMNode(BaseNode[LLMNodeData]):
                             image_contents.append(image_content)
 
             # Create message with text from all segments
-            prompt_message = _combine_text_message_with_role(text=segment_group.text, role=message.role)
-            prompt_messages.append(prompt_message)
+            plain_text = segment_group.text
+            if plain_text:
+                prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
+                prompt_messages.append(prompt_message)
 
             if image_contents:
                 # Create message with image contents
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 5c83cddfd8..0b78d81c89 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -363,11 +363,49 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 ),
             ],
         ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                UserPromptMessage(
+                    content=[
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            },
+        ),
     ]
 
     for scenario in test_scenarios:
         model_config.model_schema.features = scenario.features
 
+        for k, v in scenario.file_variables.items():
+            selector = k.split(".")
+            llm_node.graph_runtime_state.variable_pool.add(selector, v)
+
         # Call the method under test
         prompt_messages, _ = llm_node._fetch_prompt_messages(
             user_query=scenario.user_query,
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
index ab5f2d620e..8e39445baf 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
@@ -1,3 +1,5 @@
+from collections.abc import Mapping, Sequence
+
 from pydantic import BaseModel, Field
 
 from core.file import File
@@ -11,10 +13,13 @@ class LLMNodeTestScenario(BaseModel):
 
     description: str = Field(..., description="Description of the test scenario")
     user_query: str = Field(..., description="User query input")
-    user_files: list[File] = Field(default_factory=list, description="List of user files")
+    user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
    vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
     vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
-    features: list[ModelFeature] = Field(default_factory=list, description="List of model features")
+    features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
     window_size: int = Field(..., description="Window size for memory")
-    prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
-    expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing")
+    prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
+    file_variables: Mapping[str, File | Sequence[File]] = Field(
+        default_factory=dict, description="List of file variables"
+    )
+    expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")