Mirror of https://github.com/langgenius/dify.git
chore(workflow): Optimize the iteration when selecting a variable from a branch in the output variable causes iteration index err (#8440)
commit 88c9834ef2
parent d882348f39
@@ -689,23 +689,11 @@ class Graph(BaseModel):
 
                     parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
 
-        parallel_start_node_id = None
-        for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
+        for _, branch_node_ids in parallel_start_node_ids.items():
             if set(branch_node_ids) == set(routes_node_ids.keys()):
-                parallel_start_node_id = p_start_node_id
                 return True
 
-        if not parallel_start_node_id:
-            raise Exception("Parallel start node id not found")
-
-        for graph_edge in reverse_edge_mapping[start_node_id]:
-            if (
-                graph_edge.source_node_id not in all_routes_node_ids
-                or graph_edge.source_node_id != parallel_start_node_id
-            ):
-                return False
-
-        return True
+        return False
 
     @classmethod
     def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
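Note on the hunk above: the surviving logic reduces to a set comparison. A candidate parallel start node qualifies only if the branches fanning out from it are exactly the set of routes being checked; when no candidate matches, the method now answers False instead of raising "Parallel start node id not found". A minimal, self-contained sketch of that check follows; the helper name and sample data are illustrative only, not part of the Dify codebase.

# Editorial sketch of the simplified branch check; hypothetical names.
def covers_all_routes(parallel_start_node_ids: dict[str, list[str]],
                      routes_node_ids: dict[str, list[str]]) -> bool:
    # A parallel start node qualifies only when its outgoing branches
    # are exactly the set of branch nodes under consideration.
    for branch_node_ids in parallel_start_node_ids.values():
        if set(branch_node_ids) == set(routes_node_ids.keys()):
            return True
    # This path used to raise; returning False lets the caller treat
    # the graph as a non-match instead of failing the whole run.
    return False


if __name__ == "__main__":
    routes = {"if-true": ["llm-1"], "if-false": ["llm-2"]}
    starts = {"if-else": ["if-true", "if-false"]}
    assert covers_all_routes(starts, routes) is True
    assert covers_all_routes({"other": ["if-true"]}, routes) is False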
@@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunSucceededEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
-from core.workflow.graph_engine.entities.run_condition import RunCondition
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.event import RunCompletedEvent, RunEvent
 from core.workflow.nodes.iteration.entities import IterationNodeData
-from core.workflow.utils.condition.entities import Condition
 from models.workflow import WorkflowNodeExecutionStatus
 
 logger = logging.getLogger(__name__)
@@ -68,38 +66,6 @@ class IterationNode(BaseNode):
         if not iteration_graph:
             raise ValueError("iteration graph not found")
 
-        leaf_node_ids = iteration_graph.get_leaf_node_ids()
-        iteration_leaf_node_ids = []
-        for leaf_node_id in leaf_node_ids:
-            node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
-            if not node_config:
-                continue
-
-            leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
-            if not leaf_node_iteration_id:
-                continue
-
-            if leaf_node_iteration_id != self.node_id:
-                continue
-
-            iteration_leaf_node_ids.append(leaf_node_id)
-
-            # add condition of end nodes to root node
-            iteration_graph.add_extra_edge(
-                source_node_id=leaf_node_id,
-                target_node_id=root_node_id,
-                run_condition=RunCondition(
-                    type="condition",
-                    conditions=[
-                        Condition(
-                            variable_selector=[self.node_id, "index"],
-                            comparison_operator="<",
-                            value=str(len(iterator_list_value)),
-                        )
-                    ],
-                ),
-            )
-
         variable_pool = self.graph_runtime_state.variable_pool
 
         # append iteration variable (item, index) to variable pool
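Note on the hunk above: the deleted block wired every leaf node of the iteration sub-graph back to the iteration's start node through a RunCondition on the node's "index" variable, so the graph engine itself looped until index reached the list length. With the change in the next hunk, the IterationNode drives the loop explicitly, so the conditional back edge (and the RunCondition and Condition imports dropped earlier) is no longer needed. A rough, editorial sketch of the two styles under toy assumptions; run_iteration_body stands in for one graph_engine.run() and is not a real Dify function.

# Editorial sketch, not Dify code: engine-driven loop via a condition-gated
# back edge versus the explicit loop the node itself now runs.
items = ["apple", "banana", "cherry"]

def run_iteration_body(index: int, item: str) -> None:
    print(f"iteration {index}: {item}")

# Before: the leaf node's conditional back edge re-entered the start node
# for as long as the condition on the "index" variable held.
index = 0
while index < len(items):            # the condition the extra edge carried
    run_iteration_body(index, items[index])
    index += 1                        # index bumped as part of the body run

# After: the iteration node drives the loop directly, once per list element.
for index, item in enumerate(items):
    run_iteration_body(index, item)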
@@ -149,91 +115,90 @@ class IterationNode(BaseNode):
 
         outputs: list[Any] = []
         try:
-            # run workflow
-            rst = graph_engine.run()
-            for event in rst:
-                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
-                    event.in_iteration_id = self.node_id
-
-                if (
-                    isinstance(event, BaseNodeEvent)
-                    and event.node_type == NodeType.ITERATION_START
-                    and not isinstance(event, NodeRunStreamChunkEvent)
-                ):
-                    continue
-
-                if isinstance(event, NodeRunSucceededEvent):
-                    if event.route_node_state.node_run_result:
-                        metadata = event.route_node_state.node_run_result.metadata
-                        if not metadata:
-                            metadata = {}
-
-                        if NodeRunMetadataKey.ITERATION_ID not in metadata:
-                            metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
-                            metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
-                                [self.node_id, "index"]
-                            )
-                            event.route_node_state.node_run_result.metadata = metadata
-
-                    yield event
-
-                    # handle iteration run result
-                    if event.route_node_state.node_id in iteration_leaf_node_ids:
-                        # append to iteration output variable list
-                        current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
-                        outputs.append(current_iteration_output)
-
-                        # remove all nodes outputs from variable pool
-                        for node_id in iteration_graph.node_ids:
-                            variable_pool.remove_node(node_id)
-
-                        # move to next iteration
-                        current_index = variable_pool.get([self.node_id, "index"])
-                        if current_index is None:
-                            raise ValueError(f"iteration {self.node_id} current index not found")
-
-                        next_index = int(current_index.to_object()) + 1
-                        variable_pool.add([self.node_id, "index"], next_index)
-
-                        if next_index < len(iterator_list_value):
-                            variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
-
-                        yield IterationRunNextEvent(
-                            iteration_id=self.id,
-                            iteration_node_id=self.node_id,
-                            iteration_node_type=self.node_type,
-                            iteration_node_data=self.node_data,
-                            index=next_index,
-                            pre_iteration_output=jsonable_encoder(current_iteration_output)
-                            if current_iteration_output
-                            else None,
-                        )
-                elif isinstance(event, BaseGraphEvent):
-                    if isinstance(event, GraphRunFailedEvent):
-                        # iteration run failed
-                        yield IterationRunFailedEvent(
-                            iteration_id=self.id,
-                            iteration_node_id=self.node_id,
-                            iteration_node_type=self.node_type,
-                            iteration_node_data=self.node_data,
-                            start_at=start_at,
-                            inputs=inputs,
-                            outputs={"output": jsonable_encoder(outputs)},
-                            steps=len(iterator_list_value),
-                            metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
-                            error=event.error,
-                        )
-
-                        yield RunCompletedEvent(
-                            run_result=NodeRunResult(
-                                status=WorkflowNodeExecutionStatus.FAILED,
-                                error=event.error,
-                            )
-                        )
-
-                        break
-                else:
-                    event = cast(InNodeEvent, event)
-                    yield event
+            for _ in range(len(iterator_list_value)):
+                # run workflow
+                rst = graph_engine.run()
+                for event in rst:
+                    if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
+                        event.in_iteration_id = self.node_id
+
+                    if (
+                        isinstance(event, BaseNodeEvent)
+                        and event.node_type == NodeType.ITERATION_START
+                        and not isinstance(event, NodeRunStreamChunkEvent)
+                    ):
+                        continue
+
+                    if isinstance(event, NodeRunSucceededEvent):
+                        if event.route_node_state.node_run_result:
+                            metadata = event.route_node_state.node_run_result.metadata
+                            if not metadata:
+                                metadata = {}
+
+                            if NodeRunMetadataKey.ITERATION_ID not in metadata:
+                                metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
+                                metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
+                                    [self.node_id, "index"]
+                                )
+                                event.route_node_state.node_run_result.metadata = metadata
+
+                        yield event
+                    elif isinstance(event, BaseGraphEvent):
+                        if isinstance(event, GraphRunFailedEvent):
+                            # iteration run failed
+                            yield IterationRunFailedEvent(
+                                iteration_id=self.id,
+                                iteration_node_id=self.node_id,
+                                iteration_node_type=self.node_type,
+                                iteration_node_data=self.node_data,
+                                start_at=start_at,
+                                inputs=inputs,
+                                outputs={"output": jsonable_encoder(outputs)},
+                                steps=len(iterator_list_value),
+                                metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+                                error=event.error,
+                            )
+
+                            yield RunCompletedEvent(
+                                run_result=NodeRunResult(
+                                    status=WorkflowNodeExecutionStatus.FAILED,
+                                    error=event.error,
+                                )
+                            )
+                            return
+                    else:
+                        event = cast(InNodeEvent, event)
+                        yield event
+
+                # append to iteration output variable list
+                current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
+                outputs.append(current_iteration_output)
+
+                # remove all nodes outputs from variable pool
+                for node_id in iteration_graph.node_ids:
+                    variable_pool.remove_node(node_id)
+
+                # move to next iteration
+                current_index = variable_pool.get([self.node_id, "index"])
+                if current_index is None:
+                    raise ValueError(f"iteration {self.node_id} current index not found")
+
+                next_index = int(current_index.to_object()) + 1
+                variable_pool.add([self.node_id, "index"], next_index)
+
+                if next_index < len(iterator_list_value):
+                    variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
+
+                yield IterationRunNextEvent(
+                    iteration_id=self.id,
+                    iteration_node_id=self.node_id,
+                    iteration_node_type=self.node_type,
+                    iteration_node_data=self.node_data,
+                    index=next_index,
+                    pre_iteration_output=jsonable_encoder(current_iteration_output)
+                    if current_iteration_output
+                    else None,
+                )
 
             yield IterationRunSucceededEvent(
                 iteration_id=self.id,
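Note on the hunk above: previously, output collection and the index bump lived inside the NodeRunSucceededEvent handler and only fired when the succeeded node was one of the pre-computed iteration_leaf_node_ids; the commit title suggests that when the configured output variable came from a branch, the expected leaf never matched and the iteration index went wrong. The new code runs the sub-graph once per element in an explicit for loop and, after each pass, unconditionally collects the output, clears the body nodes' outputs, and advances the index. A compact, editorial sketch of that control flow under toy assumptions; run_iteration, run_body, and the dict-based pool are hypothetical stand-ins, not Dify APIs.

# Editorial sketch of the new control flow with toy stand-ins.
from typing import Any, Callable, Iterator


def run_iteration(items: list[Any],
                  run_body: Callable[[dict], Iterator[Any]]) -> Iterator[Any]:
    pool = {"index": 0, "item": items[0] if items else None}
    outputs: list[Any] = []

    for _ in range(len(items)):
        # run the body subgraph once and forward its events
        yield from run_body(pool)

        # collect this pass's output, then advance the index, regardless
        # of which branch inside the body produced the output
        outputs.append(pool.get("output"))
        next_index = pool["index"] + 1
        pool["index"] = next_index
        if next_index < len(items):
            pool["item"] = items[next_index]

    yield {"succeeded": True, "outputs": outputs}


if __name__ == "__main__":
    def body(pool: dict) -> Iterator[Any]:
        # a body whose result comes from one of two "branches"
        pool["output"] = pool["item"].upper() if pool["index"] % 2 == 0 else pool["item"]
        yield {"node_succeeded": True, "index": pool["index"]}

    for event in run_iteration(["a", "b", "c"], body):
        print(event)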