From 88c9834ef2231a90506e1eddd7eccde1976bccdb Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 14 Sep 2024 18:02:43 +0800 Subject: [PATCH] chore(workflow): Optimize the iteration when selecting a variable from a branch in the output variable causes iteration index err (#8440) --- .../workflow/graph_engine/entities/graph.py | 16 +- .../nodes/iteration/iteration_node.py | 191 +++++++----------- 2 files changed, 80 insertions(+), 127 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 1d7e9158d8..1175f4af2a 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -689,23 +689,11 @@ class Graph(BaseModel): parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) - parallel_start_node_id = None - for p_start_node_id, branch_node_ids in parallel_start_node_ids.items(): + for _, branch_node_ids in parallel_start_node_ids.items(): if set(branch_node_ids) == set(routes_node_ids.keys()): - parallel_start_node_id = p_start_node_id return True - if not parallel_start_node_id: - raise Exception("Parallel start node id not found") - - for graph_edge in reverse_edge_mapping[start_node_id]: - if ( - graph_edge.source_node_id not in all_routes_node_ids - or graph_edge.source_node_id != parallel_start_node_id - ): - return False - - return True + return False @classmethod def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 4d944e93db..6f20745daf 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.iteration.entities import IterationNodeData -from core.workflow.utils.condition.entities import Condition from models.workflow import WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) @@ -68,38 +66,6 @@ class IterationNode(BaseNode): if not iteration_graph: raise ValueError("iteration graph not found") - leaf_node_ids = iteration_graph.get_leaf_node_ids() - iteration_leaf_node_ids = [] - for leaf_node_id in leaf_node_ids: - node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id) - if not node_config: - continue - - leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id") - if not leaf_node_iteration_id: - continue - - if leaf_node_iteration_id != self.node_id: - continue - - iteration_leaf_node_ids.append(leaf_node_id) - - # add condition of end nodes to root node - iteration_graph.add_extra_edge( - source_node_id=leaf_node_id, - target_node_id=root_node_id, - run_condition=RunCondition( - type="condition", - conditions=[ - Condition( - variable_selector=[self.node_id, "index"], - comparison_operator="<", - value=str(len(iterator_list_value)), - ) - ], - ), - ) - variable_pool = self.graph_runtime_state.variable_pool # append iteration variable (item, index) to variable pool @@ -149,91 +115,90 @@ class IterationNode(BaseNode): outputs: list[Any] = [] try: - # run workflow - rst = graph_engine.run() - for event in rst: - if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: - event.in_iteration_id = self.node_id + for _ in range(len(iterator_list_value)): + # run workflow + rst = graph_engine.run() + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id - if ( - isinstance(event, BaseNodeEvent) - and event.node_type == NodeType.ITERATION_START - and not isinstance(event, NodeRunStreamChunkEvent) - ): - continue + if ( + isinstance(event, BaseNodeEvent) + and event.node_type == NodeType.ITERATION_START + and not isinstance(event, NodeRunStreamChunkEvent) + ): + continue - if isinstance(event, NodeRunSucceededEvent): - if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} + if isinstance(event, NodeRunSucceededEvent): + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( - [self.node_id, "index"] - ) - event.route_node_state.node_run_result.metadata = metadata + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( + [self.node_id, "index"] + ) + event.route_node_state.node_run_result.metadata = metadata - yield event - - # handle iteration run result - if event.route_node_state.node_id in iteration_leaf_node_ids: - # append to iteration output variable list - current_iteration_output = variable_pool.get_any(self.node_data.output_selector) - outputs.append(current_iteration_output) - - # remove all nodes outputs from variable pool - for node_id in iteration_graph.node_ids: - variable_pool.remove_node(node_id) - - # move to next iteration - current_index = variable_pool.get([self.node_id, "index"]) - if current_index is None: - raise ValueError(f"iteration {self.node_id} current index not found") - - next_index = int(current_index.to_object()) + 1 - variable_pool.add([self.node_id, "index"], next_index) - - if next_index < len(iterator_list_value): - variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - - yield IterationRunNextEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - index=next_index, - pre_iteration_output=jsonable_encoder(current_iteration_output) - if current_iteration_output - else None, - ) - elif isinstance(event, BaseGraphEvent): - if isinstance(event, GraphRunFailedEvent): - # iteration run failed - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, - ) - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + yield event + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, ) - ) - break - else: - event = cast(InNodeEvent, event) - yield event + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return + else: + event = cast(InNodeEvent, event) + yield event + + # append to iteration output variable list + current_iteration_output = variable_pool.get_any(self.node_data.output_selector) + outputs.append(current_iteration_output) + + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove_node(node_id) + + # move to next iteration + current_index = variable_pool.get([self.node_id, "index"]) + if current_index is None: + raise ValueError(f"iteration {self.node_id} current index not found") + + next_index = int(current_index.to_object()) + 1 + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + pre_iteration_output=jsonable_encoder(current_iteration_output) + if current_iteration_output + else None, + ) yield IterationRunSucceededEvent( iteration_id=self.id,