chore(workflow): Optimize iteration handling, since selecting a variable from a branch as the output variable causes an iteration index error (#8440)

This commit is contained in:
takatost 2024-09-14 18:02:43 +08:00 committed by GitHub
parent d882348f39
commit 88c9834ef2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 80 additions and 127 deletions

View File

@ -689,23 +689,11 @@ class Graph(BaseModel):
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
parallel_start_node_id = None for _, branch_node_ids in parallel_start_node_ids.items():
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()): if set(branch_node_ids) == set(routes_node_ids.keys()):
parallel_start_node_id = p_start_node_id
return True return True
if not parallel_start_node_id: return False
raise Exception("Parallel start node id not found")
for graph_edge in reverse_edge_mapping[start_node_id]:
if (
graph_edge.source_node_id not in all_routes_node_ids
or graph_edge.source_node_id != parallel_start_node_id
):
return False
return True
@classmethod @classmethod
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:

View File

@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,38 +66,6 @@ class IterationNode(BaseNode):
if not iteration_graph: if not iteration_graph:
raise ValueError("iteration graph not found") raise ValueError("iteration graph not found")
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
for leaf_node_id in leaf_node_ids:
node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
if not node_config:
continue
leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
if not leaf_node_iteration_id:
continue
if leaf_node_iteration_id != self.node_id:
continue
iteration_leaf_node_ids.append(leaf_node_id)
# add condition of end nodes to root node
iteration_graph.add_extra_edge(
source_node_id=leaf_node_id,
target_node_id=root_node_id,
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=str(len(iterator_list_value)),
)
],
),
)
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool # append iteration variable (item, index) to variable pool
@ -149,91 +115,90 @@ class IterationNode(BaseNode):
outputs: list[Any] = [] outputs: list[Any] = []
try: try:
# run workflow for _ in range(len(iterator_list_value)):
rst = graph_engine.run() # run workflow
for event in rst: rst = graph_engine.run()
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: for event in rst:
event.in_iteration_id = self.node_id if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if ( if (
isinstance(event, BaseNodeEvent) isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent) and not isinstance(event, NodeRunStreamChunkEvent)
): ):
continue continue
if isinstance(event, NodeRunSucceededEvent): if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result: if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata metadata = event.route_node_state.node_run_result.metadata
if not metadata: if not metadata:
metadata = {} metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata: if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
[self.node_id, "index"] [self.node_id, "index"]
) )
event.route_node_state.node_run_result.metadata = metadata event.route_node_state.node_run_result.metadata = metadata
yield event yield event
elif isinstance(event, BaseGraphEvent):
# handle iteration run result if isinstance(event, GraphRunFailedEvent):
if event.route_node_state.node_id in iteration_leaf_node_ids: # iteration run failed
# append to iteration output variable list yield IterationRunFailedEvent(
current_iteration_output = variable_pool.get_any(self.node_data.output_selector) iteration_id=self.id,
outputs.append(current_iteration_output) iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
# remove all nodes outputs from variable pool iteration_node_data=self.node_data,
for node_id in iteration_graph.node_ids: start_at=start_at,
variable_pool.remove_node(node_id) inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
# move to next iteration steps=len(iterator_list_value),
current_index = variable_pool.get([self.node_id, "index"]) metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
if current_index is None:
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = int(current_index.to_object()) + 1
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output)
if current_iteration_output
else None,
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error, error=event.error,
) )
)
break yield RunCompletedEvent(
else: run_result=NodeRunResult(
event = cast(InNodeEvent, event) status=WorkflowNodeExecutionStatus.FAILED,
yield event error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
yield event
# append to iteration output variable list
current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
outputs.append(current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove_node(node_id)
# move to next iteration
current_index = variable_pool.get([self.node_id, "index"])
if current_index is None:
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = int(current_index.to_object()) + 1
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output)
if current_iteration_output
else None,
)
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(
iteration_id=self.id, iteration_id=self.id,