fix(workflow): refine variable type checks in LLMNode (#10051)

This commit is contained in:
-LAN- 2024-10-30 16:23:12 +08:00 committed by GitHub
parent 4d38798dd5
commit 3b53e06e0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 128 additions and 5 deletions

View File

@ -349,13 +349,11 @@ class LLMNode(BaseNode[LLMNodeData]):
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is None:
return []
if isinstance(variable, FileSegment):
elif isinstance(variable, FileSegment):
return [variable.value]
if isinstance(variable, ArrayFileSegment):
elif isinstance(variable, ArrayFileSegment):
return variable.value
# FIXME: Temporary fix for empty array,
# all variables added to variable pool should be a Segment instance.
if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise ValueError(f"Invalid variable type: {type(variable)}")

View File

@ -0,0 +1,125 @@
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.workflow import WorkflowType
class TestLLMNode:
@pytest.fixture
def llm_node(self):
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
memory=None,
context=ContextConfig(enabled=False),
vision=VisionConfig(
enabled=True,
configs=VisionConfigOptions(
variable_selector=["sys", "files"],
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
def test_fetch_files_with_file_segment(self, llm_node):
file = File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]
def test_fetch_files_with_array_file_segment(self, llm_node):
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
),
File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files
def test_fetch_files_with_none_segment(self, llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_array_any_segment(self, llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_non_existent_variable(self, llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []