mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
refactor(variables): replace deprecated 'get_any' with 'get' method (#9584)
This commit is contained in:
parent
5838345f48
commit
8f670f31b8
|
@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import deprecated
|
|
||||||
|
|
||||||
from core.file import File, FileAttribute, file_manager
|
from core.file import File, FileAttribute, file_manager
|
||||||
from core.variables import Segment, SegmentGroup, Variable
|
from core.variables import Segment, SegmentGroup, Variable
|
||||||
|
@ -133,26 +132,6 @@ class VariablePool(BaseModel):
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@deprecated("This method is deprecated, use `get` instead.")
|
|
||||||
def get_any(self, selector: Sequence[str], /) -> Any | None:
|
|
||||||
"""
|
|
||||||
Retrieves the value from the variable pool based on the given selector.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
selector (Sequence[str]): The selector used to identify the variable.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: The value associated with the given selector.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the selector is invalid.
|
|
||||||
"""
|
|
||||||
if len(selector) < 2:
|
|
||||||
raise ValueError("Invalid selector")
|
|
||||||
hash_key = hash(tuple(selector[1:]))
|
|
||||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
|
||||||
return value.to_object() if value else None
|
|
||||||
|
|
||||||
def remove(self, selector: Sequence[str], /):
|
def remove(self, selector: Sequence[str], /):
|
||||||
"""
|
"""
|
||||||
Remove variables from the variable pool based on the given selector.
|
Remove variables from the variable pool based on the given selector.
|
||||||
|
|
|
@ -41,10 +41,15 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||||
# Get variables
|
# Get variables
|
||||||
variables = {}
|
variables = {}
|
||||||
for variable_selector in self.node_data.variables:
|
for variable_selector in self.node_data.variables:
|
||||||
variable = variable_selector.variable
|
variable_name = variable_selector.variable
|
||||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||||
|
if variable is None:
|
||||||
variables[variable] = value
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=variables,
|
||||||
|
error=f"Variable `{variable_selector.value_selector}` not found",
|
||||||
|
)
|
||||||
|
variables[variable_name] = variable.to_object()
|
||||||
# Run code
|
# Run code
|
||||||
try:
|
try:
|
||||||
result = CodeExecutor.execute_workflow_code_template(
|
result = CodeExecutor.execute_workflow_code_template(
|
||||||
|
|
|
@ -5,6 +5,7 @@ from typing import Any, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.variables import IntegerSegment
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
BaseGraphEvent,
|
BaseGraphEvent,
|
||||||
|
@ -147,9 +148,16 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||||
|
|
||||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
|
index_variable = variable_pool.get([self.node_id, "index"])
|
||||||
[self.node_id, "index"]
|
if not isinstance(index_variable, IntegerSegment):
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
error=f"Invalid index variable type: {type(index_variable)}",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
|
||||||
event.route_node_state.node_run_result.metadata = metadata
|
event.route_node_state.node_run_result.metadata = metadata
|
||||||
|
|
||||||
yield event
|
yield event
|
||||||
|
@ -181,7 +189,16 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
# append to iteration output variable list
|
# append to iteration output variable list
|
||||||
current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
|
current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
|
||||||
|
if current_iteration_output_variable is None:
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
error=f"Iteration output variable {self.node_data.output_selector} not found",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
current_iteration_output = current_iteration_output_variable.to_object()
|
||||||
outputs.append(current_iteration_output)
|
outputs.append(current_iteration_output)
|
||||||
|
|
||||||
# remove all nodes outputs from variable pool
|
# remove all nodes outputs from variable pool
|
||||||
|
@ -189,11 +206,11 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||||
variable_pool.remove([node_id])
|
variable_pool.remove([node_id])
|
||||||
|
|
||||||
# move to next iteration
|
# move to next iteration
|
||||||
current_index = variable_pool.get([self.node_id, "index"])
|
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||||
if current_index is None:
|
if not isinstance(current_index_variable, IntegerSegment):
|
||||||
raise ValueError(f"iteration {self.node_id} current index not found")
|
raise ValueError(f"iteration {self.node_id} current index not found")
|
||||||
|
|
||||||
next_index = int(current_index.to_object()) + 1
|
next_index = current_index_variable.value + 1
|
||||||
variable_pool.add([self.node_id, "index"], next_index)
|
variable_pool.add([self.node_id, "index"], next_index)
|
||||||
|
|
||||||
if next_index < len(iterator_list_value):
|
if next_index < len(iterator_list_value):
|
||||||
|
@ -205,9 +222,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||||
iteration_node_type=self.node_type,
|
iteration_node_type=self.node_type,
|
||||||
iteration_node_data=self.node_data,
|
iteration_node_data=self.node_data,
|
||||||
index=next_index,
|
index=next_index,
|
||||||
pre_iteration_output=jsonable_encoder(current_iteration_output)
|
pre_iteration_output=jsonable_encoder(current_iteration_output),
|
||||||
if current_iteration_output
|
|
||||||
else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield IterationRunSucceededEvent(
|
yield IterationRunSucceededEvent(
|
||||||
|
|
|
@ -14,6 +14,7 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
|
from core.variables import StringSegment
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
|
@ -39,8 +40,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
# extract variables
|
# extract variables
|
||||||
variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector)
|
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
|
||||||
query = variable
|
if not isinstance(variable, StringSegment):
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs={},
|
||||||
|
error="Query variable is not string type.",
|
||||||
|
)
|
||||||
|
query = variable.value
|
||||||
variables = {"query": query}
|
variables = {"query": query}
|
||||||
if not query:
|
if not query:
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
|
|
|
@ -22,7 +22,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
|
from core.variables import (
|
||||||
|
ArrayAnySegment,
|
||||||
|
ArrayFileSegment,
|
||||||
|
ArraySegment,
|
||||||
|
FileSegment,
|
||||||
|
NoneSegment,
|
||||||
|
ObjectSegment,
|
||||||
|
StringSegment,
|
||||||
|
)
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
@ -263,50 +271,44 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
return variables
|
return variables
|
||||||
|
|
||||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||||
variable = variable_selector.variable
|
variable_name = variable_selector.variable
|
||||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||||
|
if variable is None:
|
||||||
|
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||||
|
|
||||||
def parse_dict(d: dict) -> str:
|
def parse_dict(input_dict: Mapping[str, Any]) -> str:
|
||||||
"""
|
"""
|
||||||
Parse dict into string
|
Parse dict into string
|
||||||
"""
|
"""
|
||||||
# check if it's a context structure
|
# check if it's a context structure
|
||||||
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
|
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
|
||||||
return d["content"]
|
return input_dict["content"]
|
||||||
|
|
||||||
# else, parse the dict
|
# else, parse the dict
|
||||||
try:
|
try:
|
||||||
return json.dumps(d, ensure_ascii=False)
|
return json.dumps(input_dict, ensure_ascii=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
return str(d)
|
return str(input_dict)
|
||||||
|
|
||||||
if isinstance(value, str):
|
if isinstance(variable, ArraySegment):
|
||||||
value = value
|
|
||||||
elif isinstance(value, list):
|
|
||||||
result = ""
|
result = ""
|
||||||
for item in value:
|
for item in variable.value:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
result += parse_dict(item)
|
result += parse_dict(item)
|
||||||
elif isinstance(item, str):
|
|
||||||
result += item
|
|
||||||
elif isinstance(item, int | float):
|
|
||||||
result += str(item)
|
|
||||||
else:
|
else:
|
||||||
result += str(item)
|
result += str(item)
|
||||||
result += "\n"
|
result += "\n"
|
||||||
value = result.strip()
|
value = result.strip()
|
||||||
elif isinstance(value, dict):
|
elif isinstance(variable, ObjectSegment):
|
||||||
value = parse_dict(value)
|
value = parse_dict(variable.value)
|
||||||
elif isinstance(value, int | float):
|
|
||||||
value = str(value)
|
|
||||||
else:
|
else:
|
||||||
value = str(value)
|
value = variable.text
|
||||||
|
|
||||||
variables[variable] = value
|
variables[variable_name] = value
|
||||||
|
|
||||||
return variables
|
return variables
|
||||||
|
|
||||||
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
|
||||||
inputs = {}
|
inputs = {}
|
||||||
prompt_template = node_data.prompt_template
|
prompt_template = node_data.prompt_template
|
||||||
|
|
||||||
|
@ -363,14 +365,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
if not node_data.context.variable_selector:
|
if not node_data.context.variable_selector:
|
||||||
return
|
return
|
||||||
|
|
||||||
context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
|
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
|
||||||
if context_value:
|
if context_value_variable:
|
||||||
if isinstance(context_value, str):
|
if isinstance(context_value_variable, StringSegment):
|
||||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
|
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
||||||
elif isinstance(context_value, list):
|
elif isinstance(context_value_variable, ArraySegment):
|
||||||
context_str = ""
|
context_str = ""
|
||||||
original_retriever_resource = []
|
original_retriever_resource = []
|
||||||
for item in context_value:
|
for item in context_value_variable.value:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
context_str += item + "\n"
|
context_str += item + "\n"
|
||||||
else:
|
else:
|
||||||
|
@ -484,11 +486,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# get conversation id
|
# get conversation id
|
||||||
conversation_id = self.graph_runtime_state.variable_pool.get_any(
|
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||||
)
|
)
|
||||||
if conversation_id is None:
|
if not isinstance(conversation_id_variable, StringSegment):
|
||||||
return None
|
return None
|
||||||
|
conversation_id = conversation_id_variable.value
|
||||||
|
|
||||||
# get conversation
|
# get conversation
|
||||||
conversation = (
|
conversation = (
|
||||||
|
|
|
@ -33,8 +33,13 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
|
||||||
variables = {}
|
variables = {}
|
||||||
for variable_selector in self.node_data.variables:
|
for variable_selector in self.node_data.variables:
|
||||||
variable_name = variable_selector.variable
|
variable_name = variable_selector.variable
|
||||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||||
variables[variable_name] = value
|
if value is None:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
error=f"Variable {variable_name} not found in variable pool",
|
||||||
|
)
|
||||||
|
variables[variable_name] = value.to_object()
|
||||||
# Run code
|
# Run code
|
||||||
try:
|
try:
|
||||||
result = CodeExecutor.execute_workflow_code_template(
|
result = CodeExecutor.execute_workflow_code_template(
|
||||||
|
|
|
@ -19,27 +19,27 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
||||||
|
|
||||||
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
|
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
|
||||||
for selector in self.node_data.variables:
|
for selector in self.node_data.variables:
|
||||||
variable = self.graph_runtime_state.variable_pool.get_any(selector)
|
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||||
if variable is not None:
|
if variable is not None:
|
||||||
outputs = {"output": variable}
|
outputs = {"output": variable.to_object()}
|
||||||
|
|
||||||
inputs = {".".join(selector[1:]): variable}
|
inputs = {".".join(selector[1:]): variable.to_object()}
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for group in self.node_data.advanced_settings.groups:
|
for group in self.node_data.advanced_settings.groups:
|
||||||
for selector in group.variables:
|
for selector in group.variables:
|
||||||
variable = self.graph_runtime_state.variable_pool.get_any(selector)
|
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||||
|
|
||||||
if variable is not None:
|
if variable is not None:
|
||||||
outputs[group.group_name] = {"output": variable}
|
outputs[group.group_name] = {"output": variable.to_object()}
|
||||||
inputs[".".join(selector[1:])] = variable
|
inputs[".".join(selector[1:])] = variable.to_object()
|
||||||
break
|
break
|
||||||
|
|
||||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)
|
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
|
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
|
||||||
) -> Mapping[str, Sequence[str]]:
|
) -> Mapping[str, Sequence[str]]:
|
||||||
"""
|
"""
|
||||||
Extract variable selector to variable mapping
|
Extract variable selector to variable mapping
|
||||||
|
|
|
@ -102,6 +102,8 @@ def test_execute_code(setup_code_executor_mock):
|
||||||
}
|
}
|
||||||
|
|
||||||
node = init_code_node(code_config)
|
node = init_code_node(code_config)
|
||||||
|
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
|
||||||
|
node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)
|
||||||
|
|
||||||
# execute node
|
# execute node
|
||||||
result = node._run()
|
result = node._run()
|
||||||
|
@ -146,6 +148,8 @@ def test_execute_code_output_validator(setup_code_executor_mock):
|
||||||
}
|
}
|
||||||
|
|
||||||
node = init_code_node(code_config)
|
node = init_code_node(code_config)
|
||||||
|
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
|
||||||
|
node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)
|
||||||
|
|
||||||
# execute node
|
# execute node
|
||||||
result = node._run()
|
result = node._run()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user