chore(workflow): Optimize the iteration when selecting a variable from a branch in the output variable causes iteration index err (#8440)

This commit is contained in:
takatost 2024-09-14 18:02:43 +08:00 committed by GitHub
parent d882348f39
commit 88c9834ef2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 80 additions and 127 deletions

View File

@ -689,24 +689,12 @@ class Graph(BaseModel):
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
parallel_start_node_id = None for _, branch_node_ids in parallel_start_node_ids.items():
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()): if set(branch_node_ids) == set(routes_node_ids.keys()):
parallel_start_node_id = p_start_node_id
return True return True
if not parallel_start_node_id:
raise Exception("Parallel start node id not found")
for graph_edge in reverse_edge_mapping[start_node_id]:
if (
graph_edge.source_node_id not in all_routes_node_ids
or graph_edge.source_node_id != parallel_start_node_id
):
return False return False
return True
@classmethod @classmethod
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
""" """

View File

@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,38 +66,6 @@ class IterationNode(BaseNode):
if not iteration_graph: if not iteration_graph:
raise ValueError("iteration graph not found") raise ValueError("iteration graph not found")
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
for leaf_node_id in leaf_node_ids:
node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
if not node_config:
continue
leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
if not leaf_node_iteration_id:
continue
if leaf_node_iteration_id != self.node_id:
continue
iteration_leaf_node_ids.append(leaf_node_id)
# add condition of end nodes to root node
iteration_graph.add_extra_edge(
source_node_id=leaf_node_id,
target_node_id=root_node_id,
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=str(len(iterator_list_value)),
)
],
),
)
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool # append iteration variable (item, index) to variable pool
@ -149,6 +115,7 @@ class IterationNode(BaseNode):
outputs: list[Any] = [] outputs: list[Any] = []
try: try:
for _ in range(len(iterator_list_value)):
# run workflow # run workflow
rst = graph_engine.run() rst = graph_engine.run()
for event in rst: for event in rst:
@ -176,9 +143,33 @@ class IterationNode(BaseNode):
event.route_node_state.node_run_result.metadata = metadata event.route_node_state.node_run_result.metadata = metadata
yield event yield event
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
yield event
# handle iteration run result
if event.route_node_state.node_id in iteration_leaf_node_ids:
# append to iteration output variable list # append to iteration output variable list
current_iteration_output = variable_pool.get_any(self.node_data.output_selector) current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
outputs.append(current_iteration_output) outputs.append(current_iteration_output)
@ -208,32 +199,6 @@ class IterationNode(BaseNode):
if current_iteration_output if current_iteration_output
else None, else None,
) )
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
break
else:
event = cast(InNodeEvent, event)
yield event
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(
iteration_id=self.id, iteration_id=self.id,