diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 97435ad4a..4cce2ac43 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -663,11 +663,14 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato ] return + # Check if execution was cancelled - stop execution gracefully + if self.state.status == Status.FAILED: + return + self._interrupt_state.deactivate() # Find newly ready nodes after batch execution - # We add all nodes in current batch as completed batch, - # because a failure would throw exception and code would not make it here + # Only nodes that completed successfully are considered for downstream execution newly_ready = self._find_newly_ready_nodes(current_batch) # Emit handoff event for batch transition if there are nodes to transition to @@ -868,7 +871,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) logger.debug("reason=<%s> | cancelling execution", cancel_message) yield MultiAgentNodeCancelEvent(node.node_id, cancel_message) - raise RuntimeError(cancel_message) + self.state.status = Status.FAILED + return # Build node input from satisfied dependencies node_input = self._build_node_input(node) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index ab2d86e70..0f28261c1 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2080,14 +2080,22 @@ def cancel_callback(event): stream = graph.stream_async("test task") tru_cancel_event = None - with pytest.raises(RuntimeError, match=cancel_message): - async for event in stream: - if event.get("type") == "multiagent_node_cancel": - tru_cancel_event = event - + tru_result_event = None + async for event in stream: + if event.get("type") == "multiagent_node_cancel": + tru_cancel_event = event + elif event.get("type") == "multiagent_result": + tru_result_event = event + + # Verify cancel event was emitted exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message) assert tru_cancel_event == exp_cancel_event + # Verify result event was yielded (no exception raised) + assert tru_result_event is not None + assert tru_result_event["result"].status == Status.FAILED + + # Verify graph state tru_status = graph.state.status exp_status = Status.FAILED assert tru_status == exp_status diff --git a/tests_integ/hooks/multiagent/test_cancel.py b/tests_integ/hooks/multiagent/test_cancel.py index 9267330b7..5eb547ada 100644 --- a/tests_integ/hooks/multiagent/test_cancel.py +++ b/tests_integ/hooks/multiagent/test_cancel.py @@ -73,16 +73,30 @@ async def test_swarm_cancel_node(swarm): @pytest.mark.asyncio async def test_graph_cancel_node(graph): tru_cancel_event = None - with pytest.raises(RuntimeError, match="test cancel"): - async for event in graph.stream_async("What is the weather"): - if event.get("type") == "multiagent_node_cancel": - tru_cancel_event = event + tru_result_event = None + async for event in graph.stream_async("What is the weather"): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_event = event + elif event.get("type") == "multiagent_result": + tru_result_event = event exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel") assert tru_cancel_event == exp_cancel_event - state = graph.state + # Verify result was yielded (no exception raised) + assert tru_result_event is not None + multiagent_result = tru_result_event["result"] - tru_status = state.status + tru_status = multiagent_result.status exp_status = Status.FAILED assert tru_status == exp_status + + state = graph.state + + tru_state_status = state.status + exp_state_status = Status.FAILED + assert tru_state_status == exp_state_status + + # Verify the info node was executed but weather node was cancelled (not executed) + assert "info" in state.results + assert "weather" not in state.results