Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
"status": "error",
"content": [{"text": f"Error: {error_msg}"}],
},
exception=e,
)
except Exception as e:
# Return error result with exception details for any other error
Expand All @@ -632,14 +633,15 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
"status": "error",
"content": [{"text": f"Error: {error_type} - {error_msg}"}],
},
exception=e,
)

def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent:
def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception | None = None) -> ToolResultEvent:
# FORMAT THE RESULT for Strands Agent
if isinstance(result, dict) and "status" in result and "content" in result:
# Result is already in the expected format, just add toolUseId
result["toolUseId"] = tool_use_d
return ToolResultEvent(cast(ToolResult, result))
return ToolResultEvent(cast(ToolResult, result), exception=exception)
else:
# Wrap any other return value in the standard format
# Always include at least one content item for consistency
Expand All @@ -648,7 +650,8 @@ def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent:
"toolUseId": tool_use_d,
"status": "success",
"content": [{"text": str(result)}],
}
},
exception=exception,
)

@property
Expand Down
7 changes: 6 additions & 1 deletion src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ async def _stream(
return
if structured_output_context.is_enabled:
kwargs["structured_output_context"] = structured_output_context

exception: Exception | None = None

async for event in selected_tool.stream(tool_use, invocation_state, **kwargs):
# Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream()
# so that we don't needlessly yield ToolStreamEvents for non-generator callbacks.
Expand All @@ -227,6 +230,8 @@ async def _stream(
return

if isinstance(event, ToolResultEvent):
# Preserve exception from decorated tools before extracting tool_result
exception = event.exception
# below the last "event" must point to the tool_result
event = event.tool_result
break
Expand All @@ -239,7 +244,7 @@ async def _stream(
result = cast(ToolResult, event)

after_event, _ = await ToolExecutor._invoke_after_tool_call_hook(
agent, selected_tool, tool_use, invocation_state, result
agent, selected_tool, tool_use, invocation_state, result, exception=exception
)

# Check if retry requested (getattr for BidiAfterToolCallEvent compatibility)
Expand Down
15 changes: 10 additions & 5 deletions src/strands/types/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,18 @@ def prepare(self, invocation_state: dict) -> None:
class ToolResultEvent(TypedEvent):
"""Event emitted when a tool execution completes."""

def __init__(self, tool_result: ToolResult) -> None:
"""Initialize with the completed tool result.
def __init__(self, tool_result: ToolResult, exception: Exception | None = None) -> None:
"""Initialize tool result event."""
super().__init__({"type": "tool_result", "tool_result": tool_result})
self._exception = exception

Args:
tool_result: Final result from the tool execution
@property
def exception(self) -> Exception | None:
"""The original exception that occurred, if any.

Can be used for re-raising or type-based error handling.
"""
super().__init__({"type": "tool_result", "tool_result": tool_result})
return self._exception

@property
def tool_use_id(self) -> str:
Expand Down
92 changes: 92 additions & 0 deletions tests/strands/tools/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,98 @@ async def test_executor_stream_updates_invocation_state_with_agent(
assert empty_invocation_state["agent"] is agent


@pytest.mark.asyncio
async def test_executor_stream_decorated_tool_exception_in_hook(
executor, agent, tool_results, invocation_state, hook_events, alist
):
"""Test that exceptions from @tool-decorated functions reach AfterToolCallEvent."""
exception = ValueError("decorated tool error")

@strands.tool(name="decorated_error_tool")
def failing_tool():
"""A tool that raises an exception."""
raise exception

agent.tool_registry.register_tool(failing_tool)
tool_use = {"name": "decorated_error_tool", "toolUseId": "1", "input": {}}

stream = executor._stream(agent, tool_use, tool_results, invocation_state)
await alist(stream)

after_event = hook_events[-1]
assert isinstance(after_event, AfterToolCallEvent)
assert after_event.exception is exception


@pytest.mark.asyncio
async def test_executor_stream_decorated_tool_runtime_error_in_hook(
executor, agent, tool_results, invocation_state, hook_events, alist
):
"""Test that RuntimeError from @tool-decorated functions reach AfterToolCallEvent."""
exception = RuntimeError("runtime error from decorated tool")

@strands.tool(name="runtime_error_tool")
def runtime_error_tool():
"""A tool that raises a RuntimeError."""
raise exception

agent.tool_registry.register_tool(runtime_error_tool)
tool_use = {"name": "runtime_error_tool", "toolUseId": "1", "input": {}}

stream = executor._stream(agent, tool_use, tool_results, invocation_state)
await alist(stream)

after_event = hook_events[-1]
assert isinstance(after_event, AfterToolCallEvent)
assert after_event.exception is exception


@pytest.mark.asyncio
async def test_executor_stream_decorated_tool_no_exception_on_success(
executor, agent, tool_results, invocation_state, hook_events, alist
):
"""Test that AfterToolCallEvent.exception is None when decorated tool succeeds."""

@strands.tool(name="success_decorated_tool")
def success_tool():
"""A tool that succeeds."""
return "success"

agent.tool_registry.register_tool(success_tool)
tool_use = {"name": "success_decorated_tool", "toolUseId": "1", "input": {}}

stream = executor._stream(agent, tool_use, tool_results, invocation_state)
await alist(stream)

after_event = hook_events[-1]
assert isinstance(after_event, AfterToolCallEvent)
assert after_event.exception is None
assert after_event.result["status"] == "success"


@pytest.mark.asyncio
async def test_executor_stream_decorated_tool_error_result_without_exception(
executor, agent, tool_results, invocation_state, hook_events, alist
):
"""Test that exception is None when a tool returns an error result without throwing."""

@strands.tool(name="error_result_tool")
def error_result_tool():
"""A tool that returns an error result dict without raising."""
return {"status": "error", "content": [{"text": "something went wrong"}]}

agent.tool_registry.register_tool(error_result_tool)
tool_use = {"name": "error_result_tool", "toolUseId": "1", "input": {}}

stream = executor._stream(agent, tool_use, tool_results, invocation_state)
await alist(stream)

after_event = hook_events[-1]
assert isinstance(after_event, AfterToolCallEvent)
assert after_event.exception is None
assert after_event.result["status"] == "error"


@pytest.mark.asyncio
async def test_executor_stream_no_retry_set(executor, agent, tool_results, invocation_state, alist):
"""Test default behavior when retry is not set - tool executes once."""
Expand Down
77 changes: 77 additions & 0 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1825,6 +1825,83 @@ def inner_default_tool(name: str, level: Annotated[int, Field(description="A lev
return f"{name} is at level {level}"


@pytest.mark.asyncio
async def test_tool_result_event_carries_exception_runtime_error(alist):
"""Test that ToolResultEvent carries exception when tool raises RuntimeError."""

@strands.tool
def error_tool():
"""Tool that raises a RuntimeError."""
raise RuntimeError("test runtime error")

tool_use = {"toolUseId": "test-id", "input": {}}
events = await alist(error_tool.stream(tool_use, {}))

result_event = events[-1]
assert isinstance(result_event, ToolResultEvent)
assert hasattr(result_event, "exception")
assert isinstance(result_event.exception, RuntimeError)
assert str(result_event.exception) == "test runtime error"
assert result_event.tool_result["status"] == "error"


@pytest.mark.asyncio
async def test_tool_result_event_carries_exception_value_error(alist):
"""Test that ToolResultEvent carries exception when tool raises ValueError."""

@strands.tool
def validation_error_tool():
"""Tool that raises a ValueError."""
raise ValueError("validation failed")

tool_use = {"toolUseId": "test-id", "input": {}}
events = await alist(validation_error_tool.stream(tool_use, {}))

result_event = events[-1]
assert isinstance(result_event, ToolResultEvent)
assert hasattr(result_event, "exception")
assert isinstance(result_event.exception, ValueError)
assert str(result_event.exception) == "validation failed"
assert result_event.tool_result["status"] == "error"


@pytest.mark.asyncio
async def test_tool_result_event_no_exception_on_success(alist):
"""Test that ToolResultEvent.exception is None when tool succeeds."""

@strands.tool
def success_tool():
"""Tool that succeeds."""
return "success"

tool_use = {"toolUseId": "test-id", "input": {}}
events = await alist(success_tool.stream(tool_use, {}))

result_event = events[-1]
assert isinstance(result_event, ToolResultEvent)
assert result_event.exception is None
assert result_event.tool_result["status"] == "success"


@pytest.mark.asyncio
async def test_tool_result_event_carries_exception_assertion_error(alist):
"""Test that ToolResultEvent carries AssertionError for unexpected failures."""

@strands.tool
def assertion_error_tool():
"""Tool that raises an AssertionError."""
raise AssertionError("unexpected assertion failure")

tool_use = {"toolUseId": "test-id", "input": {}}
events = await alist(assertion_error_tool.stream(tool_use, {}))

result_event = events[-1]
assert isinstance(result_event, ToolResultEvent)
assert isinstance(result_event.exception, AssertionError)
assert "unexpected assertion failure" in str(result_event.exception)
assert result_event.tool_result["status"] == "error"


def test_tool_nullable_required_field_preserves_anyof():
"""Test that a required nullable field preserves anyOf so the model can pass null.

Expand Down
Loading