diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 80b50770a1..dd9a682626 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -330,9 +330,9 @@ class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): source: The multi-agent orchestrator instance node_id: ID of the node about to execute invocation_state: Configuration that user passes in - cancel_node: A user defined message that when set, will cancel the node execution with status FAILED. - The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the - node using a default cancel message. + cancel_node: A user defined message that when set, will skip the node execution and mark it as skipped, + allowing downstream nodes to continue executing. The message will be emitted under a MultiAgentNodeCancel + event. If set to `True`, Strands will skip the node using a default cancel message. """ source: "MultiAgentBase" diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index dc3258f688..051bd143a1 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -28,6 +28,7 @@ class Status(Enum): PENDING: Task has not started execution yet. EXECUTING: Task is currently running. COMPLETED: Task finished successfully. + SKIPPED: Task was intentionally bypassed via cancel_node; downstream nodes still execute. FAILED: Task encountered an error and could not complete. INTERRUPTED: Task was interrupted by user. """ @@ -35,6 +36,7 @@ class Status(Enum): PENDING = "pending" EXECUTING = "executing" COMPLETED = "completed" + SKIPPED = "skipped" FAILED = "failed" INTERRUPTED = "interrupted" @@ -43,8 +45,8 @@ class Status(Enum): class NodeResult: """Unified result from node execution - handles both Agent and nested MultiAgentBase results.""" - # Core result data - single AgentResult, nested MultiAgentResult, or Exception - result: Union[AgentResult, "MultiAgentResult", Exception] + # Core result data - single AgentResult, nested MultiAgentResult, Exception, or None (skipped) + result: Union[AgentResult, "MultiAgentResult", Exception, None] # Execution metadata execution_time: int = 0 @@ -58,8 +60,8 @@ class NodeResult: def get_agent_results(self) -> list[AgentResult]: """Get all AgentResult objects from this node, flattened if nested.""" - if isinstance(self.result, Exception): - return [] # No agent results for exceptions + if self.result is None or isinstance(self.result, Exception): + return [] elif isinstance(self.result, AgentResult): return [self.result] else: @@ -71,8 +73,10 @@ def get_agent_results(self) -> list[AgentResult]: def to_dict(self) -> dict[str, Any]: """Convert NodeResult to JSON-serializable dict, ignoring state field.""" - if isinstance(self.result, Exception): - result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)} + if self.result is None: + result_data: dict[str, Any] = {"type": "skipped"} + elif isinstance(self.result, Exception): + result_data = {"type": "exception", "message": str(self.result)} elif isinstance(self.result, AgentResult): result_data = self.result.to_dict() else: @@ -96,8 +100,10 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": raise TypeError("NodeResult.from_dict: missing 'result'") raw = data["result"] - result: AgentResult | MultiAgentResult | Exception - if isinstance(raw, dict) and raw.get("type") == "agent_result": + result: AgentResult | MultiAgentResult | Exception | None + if isinstance(raw, dict) and raw.get("type") == "skipped": + result = None + elif isinstance(raw, dict) and raw.get("type") == "agent_result": result = AgentResult.from_dict(raw) elif isinstance(raw, dict) and raw.get("type") == "exception": result = Exception(str(raw.get("message", "node failed"))) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 8da8314eab..50f6d78992 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -66,7 +66,8 @@ class GraphState: Attributes: status: Current execution status of the graph. - completed_nodes: Set of nodes that have completed execution. + completed_nodes: Set of nodes whose execution is settled — either completed normally or skipped via cancel_node. + Both statuses satisfy downstream readiness checks; inspect node.execution_status to distinguish them. failed_nodes: Set of nodes that failed during execution. interrupted_nodes: Set of nodes that user interrupted during execution. execution_order: List of nodes in the order they were executed. @@ -132,6 +133,7 @@ class GraphResult(MultiAgentResult): total_nodes: int = 0 completed_nodes: int = 0 + skipped_nodes: int = 0 failed_nodes: int = 0 interrupted_nodes: int = 0 execution_order: list["GraphNode"] = field(default_factory=list) @@ -899,9 +901,25 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) cancel_message = ( before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user" ) - logger.debug("reason=<%s> | cancelling execution", cancel_message) + logger.debug("reason=<%s> | node skipped, graph continues", cancel_message) yield MultiAgentNodeCancelEvent(node.node_id, cancel_message) - raise RuntimeError(cancel_message) + node_result = NodeResult( + result=None, + execution_time=0, + status=Status.SKIPPED, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=0), + execution_count=0, + ) + node.result = node_result + node.execution_time = 0 + node.execution_status = Status.SKIPPED + self.state.completed_nodes.add(node) + self.state.results[node.node_id] = node_result + self.state.execution_order.append(node) + self._accumulate_metrics(node_result) + yield MultiAgentNodeStopEvent(node_id=node.node_id, node_result=node_result) + return # Build node input from satisfied dependencies node_input = self._build_node_input(node) @@ -1086,7 +1104,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: return node_responses - # Get satisfied dependencies + # Get satisfied dependencies, excluding skipped nodes (they produced no output) dependency_results = {} for edge in self.edges: if ( @@ -1095,7 +1113,9 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: and edge.from_node.node_id in self.state.results ): if edge.should_traverse(self.state): - dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id] + nr = self.state.results[edge.from_node.node_id] + if nr.status != Status.SKIPPED: + dependency_results[edge.from_node.node_id] = nr if not dependency_results: # No dependencies - return task as ContentBlocks @@ -1146,7 +1166,12 @@ def _build_result(self, interrupts: list[Interrupt]) -> GraphResult: execution_count=self.state.execution_count, execution_time=self.state.execution_time, total_nodes=self.state.total_nodes, - completed_nodes=len(self.state.completed_nodes), + completed_nodes=sum( + 1 for n in self.state.completed_nodes if n.execution_status == Status.COMPLETED + ), + skipped_nodes=sum( + 1 for n in self.state.completed_nodes if n.execution_status == Status.SKIPPED + ), failed_nodes=len(self.state.failed_nodes), interrupted_nodes=len(self.state.interrupted_nodes), execution_order=self.state.execution_order, @@ -1251,7 +1276,8 @@ def _from_dict(self, payload: dict[str, Any]) -> None: self.nodes[node_id] for node_id in (payload.get("completed_nodes") or []) if node_id in self.nodes ) for node in self.state.completed_nodes: - node.execution_status = Status.COMPLETED + nr = results.get(node.node_id) + node.execution_status = Status.SKIPPED if (nr and nr.status == Status.SKIPPED) else Status.COMPLETED # Execution order (only nodes that still exist) order_node_ids = payload.get("execution_order") or [] diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index a6085627c3..bb74c308da 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2113,20 +2113,76 @@ def cancel_callback(event): graph = builder.build() graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback) - stream = graph.stream_async("test task") - tru_cancel_event = None - with pytest.raises(RuntimeError, match=cancel_message): - async for event in stream: - if event.get("type") == "multiagent_node_cancel": - tru_cancel_event = event + async for event in graph.stream_async("test task"): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_event = event exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message) assert tru_cancel_event == exp_cancel_event - tru_status = graph.state.status - exp_status = Status.FAILED - assert tru_status == exp_status + assert graph.state.status == Status.COMPLETED + assert any(n.node_id == "test_agent" for n in graph.state.completed_nodes) + assert "test_agent" in graph.state.results + assert graph.state.results["test_agent"].status == Status.SKIPPED + skipped_node = next(n for n in graph.state.completed_nodes if n.node_id == "test_agent") + assert skipped_node.execution_status == Status.SKIPPED + agent.__call__.assert_not_called() + + +@pytest.mark.asyncio +async def test_graph_cancel_node_downstream_executes(): + """Downstream nodes must run after an upstream node is skipped via cancel_node.""" + cancelled_nodes: list[str] = [] + + def cancel_step_a(event): + if event.node_id == "step_a": + event.cancel_node = "step_a skipped" + return event + + step_a = create_mock_agent("step_a", "Should not run") + step_b = create_mock_agent("step_b", "Step B completed") + + builder = GraphBuilder() + builder.add_node(step_a, "step_a") + builder.add_node(step_b, "step_b") + builder.add_edge("step_a", "step_b") + builder.set_entry_point("step_a") + graph = builder.build() + graph.hooks.add_callback(BeforeNodeCallEvent, cancel_step_a) + + graph_result = None + async for event in graph.stream_async("test task"): + if event.get("type") == "multiagent_node_cancel": + cancelled_nodes.append(event["node_id"]) + elif event.get("type") == "multiagent_result": + graph_result = event["result"] + + assert cancelled_nodes == ["step_a"] + assert graph.state.status == Status.COMPLETED + step_a.__call__.assert_not_called() + step_b.stream_async.assert_called_once() + + assert any(n.node_id == "step_a" for n in graph.state.completed_nodes) + assert any(n.node_id == "step_b" for n in graph.state.completed_nodes) + assert "step_a" in graph.state.results + assert "step_b" in graph.state.results + + # step_a was skipped — its NodeResult must carry Status.SKIPPED, not COMPLETED + assert graph.state.results["step_a"].status == Status.SKIPPED + assert graph.state.results["step_b"].status == Status.COMPLETED + skipped_node = next(n for n in graph.state.completed_nodes if n.node_id == "step_a") + assert skipped_node.execution_status == Status.SKIPPED + + # step_b must receive only the original task — no orphaned "From step_a:" header + step_b_input = step_b.stream_async.call_args.args[0] + assert len(step_b_input) == 1 + assert step_b_input[0]["text"] == "test task" + + # GraphResult counters must separate skipped from completed + assert graph_result is not None + assert graph_result.completed_nodes == 1 # only step_b ran + assert graph_result.skipped_nodes == 1 # only step_a was skipped def test_graph_interrupt_on_before_node_call_event(interrupt_hook):