Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ class BeforeNodeCallEvent(BaseHookEvent, _Interruptible):
source: The multi-agent orchestrator instance
node_id: ID of the node about to execute
invocation_state: Configuration that user passes in
cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
node using a default cancel message.
cancel_node: A user defined message that when set, will skip the node execution and mark it as skipped,
allowing downstream nodes to continue executing. The message will be emitted under a MultiAgentNodeCancel
event. If set to `True`, Strands will skip the node using a default cancel message.
"""

source: "MultiAgentBase"
Expand Down
22 changes: 14 additions & 8 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ class Status(Enum):
PENDING: Task has not started execution yet.
EXECUTING: Task is currently running.
COMPLETED: Task finished successfully.
SKIPPED: Task was intentionally bypassed via cancel_node; downstream nodes still execute.
FAILED: Task encountered an error and could not complete.
INTERRUPTED: Task was interrupted by user.
"""

PENDING = "pending"
EXECUTING = "executing"
COMPLETED = "completed"
SKIPPED = "skipped"
FAILED = "failed"
INTERRUPTED = "interrupted"

Expand All @@ -43,8 +45,8 @@ class Status(Enum):
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""

# Core result data - single AgentResult, nested MultiAgentResult, or Exception
result: Union[AgentResult, "MultiAgentResult", Exception]
# Core result data - single AgentResult, nested MultiAgentResult, Exception, or None (skipped)
result: Union[AgentResult, "MultiAgentResult", Exception, None]

# Execution metadata
execution_time: int = 0
Expand All @@ -58,8 +60,8 @@ class NodeResult:

def get_agent_results(self) -> list[AgentResult]:
"""Get all AgentResult objects from this node, flattened if nested."""
if isinstance(self.result, Exception):
return [] # No agent results for exceptions
if self.result is None or isinstance(self.result, Exception):
return []
elif isinstance(self.result, AgentResult):
return [self.result]
else:
Expand All @@ -71,8 +73,10 @@ def get_agent_results(self) -> list[AgentResult]:

def to_dict(self) -> dict[str, Any]:
"""Convert NodeResult to JSON-serializable dict, ignoring state field."""
if isinstance(self.result, Exception):
result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)}
if self.result is None:
result_data: dict[str, Any] = {"type": "skipped"}
elif isinstance(self.result, Exception):
result_data = {"type": "exception", "message": str(self.result)}
elif isinstance(self.result, AgentResult):
result_data = self.result.to_dict()
else:
Expand All @@ -96,8 +100,10 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
raise TypeError("NodeResult.from_dict: missing 'result'")
raw = data["result"]

result: AgentResult | MultiAgentResult | Exception
if isinstance(raw, dict) and raw.get("type") == "agent_result":
result: AgentResult | MultiAgentResult | Exception | None
if isinstance(raw, dict) and raw.get("type") == "skipped":
result = None
elif isinstance(raw, dict) and raw.get("type") == "agent_result":
result = AgentResult.from_dict(raw)
elif isinstance(raw, dict) and raw.get("type") == "exception":
result = Exception(str(raw.get("message", "node failed")))
Expand Down
40 changes: 33 additions & 7 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ class GraphState:

Attributes:
status: Current execution status of the graph.
completed_nodes: Set of nodes that have completed execution.
completed_nodes: Set of nodes whose execution is settled — either completed normally or skipped via cancel_node.
Both statuses satisfy downstream readiness checks; inspect node.execution_status to distinguish them.
failed_nodes: Set of nodes that failed during execution.
interrupted_nodes: Set of nodes that user interrupted during execution.
execution_order: List of nodes in the order they were executed.
Expand Down Expand Up @@ -132,6 +133,7 @@ class GraphResult(MultiAgentResult):

total_nodes: int = 0
completed_nodes: int = 0
skipped_nodes: int = 0
failed_nodes: int = 0
interrupted_nodes: int = 0
execution_order: list["GraphNode"] = field(default_factory=list)
Expand Down Expand Up @@ -899,9 +901,25 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
cancel_message = (
before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user"
)
logger.debug("reason=<%s> | cancelling execution", cancel_message)
logger.debug("reason=<%s> | node skipped, graph continues", cancel_message)
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
raise RuntimeError(cancel_message)
node_result = NodeResult(
result=None,
execution_time=0,
status=Status.SKIPPED,
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
accumulated_metrics=Metrics(latencyMs=0),
execution_count=0,
)
node.result = node_result
node.execution_time = 0
node.execution_status = Status.SKIPPED
self.state.completed_nodes.add(node)
self.state.results[node.node_id] = node_result
self.state.execution_order.append(node)
self._accumulate_metrics(node_result)
yield MultiAgentNodeStopEvent(node_id=node.node_id, node_result=node_result)
return

# Build node input from satisfied dependencies
node_input = self._build_node_input(node)
Expand Down Expand Up @@ -1086,7 +1104,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:

return node_responses

# Get satisfied dependencies
# Get satisfied dependencies, excluding skipped nodes (they produced no output)
dependency_results = {}
for edge in self.edges:
if (
Expand All @@ -1095,7 +1113,9 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
and edge.from_node.node_id in self.state.results
):
if edge.should_traverse(self.state):
dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id]
nr = self.state.results[edge.from_node.node_id]
if nr.status != Status.SKIPPED:
dependency_results[edge.from_node.node_id] = nr

if not dependency_results:
# No dependencies - return task as ContentBlocks
Expand Down Expand Up @@ -1146,7 +1166,12 @@ def _build_result(self, interrupts: list[Interrupt]) -> GraphResult:
execution_count=self.state.execution_count,
execution_time=self.state.execution_time,
total_nodes=self.state.total_nodes,
completed_nodes=len(self.state.completed_nodes),
completed_nodes=sum(
1 for n in self.state.completed_nodes if n.execution_status == Status.COMPLETED
),
skipped_nodes=sum(
1 for n in self.state.completed_nodes if n.execution_status == Status.SKIPPED
),
failed_nodes=len(self.state.failed_nodes),
interrupted_nodes=len(self.state.interrupted_nodes),
execution_order=self.state.execution_order,
Expand Down Expand Up @@ -1251,7 +1276,8 @@ def _from_dict(self, payload: dict[str, Any]) -> None:
self.nodes[node_id] for node_id in (payload.get("completed_nodes") or []) if node_id in self.nodes
)
for node in self.state.completed_nodes:
node.execution_status = Status.COMPLETED
nr = results.get(node.node_id)
node.execution_status = Status.SKIPPED if (nr and nr.status == Status.SKIPPED) else Status.COMPLETED

# Execution order (only nodes that still exist)
order_node_ids = payload.get("execution_order") or []
Expand Down
74 changes: 65 additions & 9 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2113,20 +2113,76 @@ def cancel_callback(event):
graph = builder.build()
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)

stream = graph.stream_async("test task")

tru_cancel_event = None
with pytest.raises(RuntimeError, match=cancel_message):
async for event in stream:
if event.get("type") == "multiagent_node_cancel":
tru_cancel_event = event
async for event in graph.stream_async("test task"):
if event.get("type") == "multiagent_node_cancel":
tru_cancel_event = event

exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message)
assert tru_cancel_event == exp_cancel_event

tru_status = graph.state.status
exp_status = Status.FAILED
assert tru_status == exp_status
assert graph.state.status == Status.COMPLETED
assert any(n.node_id == "test_agent" for n in graph.state.completed_nodes)
assert "test_agent" in graph.state.results
assert graph.state.results["test_agent"].status == Status.SKIPPED
skipped_node = next(n for n in graph.state.completed_nodes if n.node_id == "test_agent")
assert skipped_node.execution_status == Status.SKIPPED
agent.__call__.assert_not_called()


@pytest.mark.asyncio
async def test_graph_cancel_node_downstream_executes():
"""Downstream nodes must run after an upstream node is skipped via cancel_node."""
cancelled_nodes: list[str] = []

def cancel_step_a(event):
if event.node_id == "step_a":
event.cancel_node = "step_a skipped"
return event

step_a = create_mock_agent("step_a", "Should not run")
step_b = create_mock_agent("step_b", "Step B completed")

builder = GraphBuilder()
builder.add_node(step_a, "step_a")
builder.add_node(step_b, "step_b")
builder.add_edge("step_a", "step_b")
builder.set_entry_point("step_a")
graph = builder.build()
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_step_a)

graph_result = None
async for event in graph.stream_async("test task"):
if event.get("type") == "multiagent_node_cancel":
cancelled_nodes.append(event["node_id"])
elif event.get("type") == "multiagent_result":
graph_result = event["result"]

assert cancelled_nodes == ["step_a"]
assert graph.state.status == Status.COMPLETED
step_a.__call__.assert_not_called()
step_b.stream_async.assert_called_once()

assert any(n.node_id == "step_a" for n in graph.state.completed_nodes)
assert any(n.node_id == "step_b" for n in graph.state.completed_nodes)
assert "step_a" in graph.state.results
assert "step_b" in graph.state.results

# step_a was skipped — its NodeResult must carry Status.SKIPPED, not COMPLETED
assert graph.state.results["step_a"].status == Status.SKIPPED
assert graph.state.results["step_b"].status == Status.COMPLETED
skipped_node = next(n for n in graph.state.completed_nodes if n.node_id == "step_a")
assert skipped_node.execution_status == Status.SKIPPED

# step_b must receive only the original task — no orphaned "From step_a:" header
step_b_input = step_b.stream_async.call_args.args[0]
assert len(step_b_input) == 1
assert step_b_input[0]["text"] == "test task"

# GraphResult counters must separate skipped from completed
assert graph_result is not None
assert graph_result.completed_nodes == 1 # only step_b ran
assert graph_result.skipped_nodes == 1 # only step_a was skipped


def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
Expand Down