From 97fab7649bd9022c119e3c4fa5ea1d49a64d328a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 19:54:21 +0800 Subject: [PATCH] feat(tests): refactor LLMNode tests for clarity Refactor test scenarios in LLMNode unit tests by introducing a new `LLMNodeTestScenario` class to enhance readability and consistency. This change simplifies the test case management by encapsulating scenario data and reduces redundancy in specifying test configurations. Improves test clarity and maintainability by using a structured approach. --- .../core/workflow/nodes/llm/test_node.py | 62 ++++++++++--------- .../core/workflow/nodes/llm/test_scenarios.py | 20 ++++++ 2 files changed, 52 insertions(+), 30 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 99400b21b0..5c83cddfd8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -39,6 +39,7 @@ from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom from models.provider import ProviderType from models.workflow import WorkflowType +from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario class MockTokenBufferMemory: @@ -224,7 +225,6 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, - related_id="1", ) ] @@ -280,13 +280,15 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Test scenarios covering different file input combinations test_scenarios = [ - { - "description": "No files", - "user_query": fake_query, - "user_files": [], - "features": [], - "window_size": fake_window_size, - "prompt_template": [ + LLMNodeTestScenario( + description="No files", + user_query=fake_query, + user_files=[], + features=[], + vision_enabled=False, + vision_detail=None, + window_size=fake_window_size, + prompt_template=[ LLMNodeChatModelMessage( text=fake_context, role=PromptMessageRole.SYSTEM, @@ -303,7 +305,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): edition_type="basic", ), ], - "expected_messages": [ + expected_messages=[ SystemPromptMessage(content=fake_context), UserPromptMessage(content=fake_context), AssistantPromptMessage(content=fake_assistant_prompt), @@ -312,11 +314,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): + [ UserPromptMessage(content=fake_query), ], - }, - { - "description": "User files", - "user_query": fake_query, - "user_files": [ + ), + LLMNodeTestScenario( + description="User files", + user_query=fake_query, + user_files=[ File( tenant_id="test", type=FileType.IMAGE, @@ -325,11 +327,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): remote_url=fake_remote_url, ) ], - "vision_enabled": True, - "vision_detail": fake_vision_detail, - "features": [ModelFeature.VISION], - "window_size": fake_window_size, - "prompt_template": [ + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ LLMNodeChatModelMessage( text=fake_context, role=PromptMessageRole.SYSTEM, @@ -346,7 +348,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): edition_type="basic", ), ], - "expected_messages": [ + expected_messages=[ SystemPromptMessage(content=fake_context), UserPromptMessage(content=fake_context), AssistantPromptMessage(content=fake_assistant_prompt), @@ -360,27 +362,27 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ] ), ], - }, + ), ] for scenario in test_scenarios: - model_config.model_schema.features = scenario["features"] + model_config.model_schema.features = scenario.features # Call the method under test prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=fake_query, - user_files=scenario["user_files"], + user_query=scenario.user_query, + user_files=scenario.user_files, context=fake_context, memory=memory, model_config=model_config, - prompt_template=scenario["prompt_template"], + prompt_template=scenario.prompt_template, memory_config=memory_config, - vision_enabled=True, - vision_detail=fake_vision_detail, + vision_enabled=scenario.vision_enabled, + vision_detail=scenario.vision_detail, ) # Verify the result - assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}" + assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}" assert ( - prompt_messages == scenario["expected_messages"] - ), f"Message content mismatch in scenario: {scenario['description']}" + prompt_messages == scenario.expected_messages + ), f"Message content mismatch in scenario: {scenario.description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py new file mode 100644 index 0000000000..ab5f2d620e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field + +from core.file import File +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelFeature +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage + + +class LLMNodeTestScenario(BaseModel): + """Test scenario for LLM node testing.""" + + description: str = Field(..., description="Description of the test scenario") + user_query: str = Field(..., description="User query input") + user_files: list[File] = Field(default_factory=list, description="List of user files") + vision_enabled: bool = Field(default=False, description="Whether vision is enabled") + vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") + features: list[ModelFeature] = Field(default_factory=list, description="List of model features") + window_size: int = Field(..., description="Window size for memory") + prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") + expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing")