From ac9577b9f42f1fe777ae94b00d7e3e7ffd514f70 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 13 Jan 2026 12:53:52 -0800 Subject: [PATCH 01/10] feat(mcp): Implement basic support for Tasks Implements client support for MCP Tasks via an adapter model that handles both tool execution modes identically. This can later be hooked into event handlers in a more intelligent way, but this unblocks support for simply invoking task-augmented tools. Keep error handling and edge case tests (timeout, failure status, config). Also remove unused create_tool_with_task_support helper and trim task_echo_server. Reduces PR diff from 1433 to 969 lines (under 1000 limit). --- src/strands/tools/mcp/mcp_client.py | 310 ++++++++++++++++-- tests/strands/tools/mcp/conftest.py | 59 ++++ tests/strands/tools/mcp/test_mcp_client.py | 32 +- .../tools/mcp/test_mcp_client_tasks.py | 223 +++++++++++++ tests_integ/mcp/task_echo_server.py | 139 ++++++++ tests_integ/mcp/test_mcp_client_tasks.py | 188 +++++++++++ 6 files changed, 900 insertions(+), 51 deletions(-) create mode 100644 tests/strands/tools/mcp/conftest.py create mode 100644 tests/strands/tools/mcp/test_mcp_client_tasks.py create mode 100644 tests_integ/mcp/task_echo_server.py create mode 100644 tests_integ/mcp/test_mcp_client_tasks.py diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 1aff22a1e..2ca284407 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -120,6 +120,8 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, + default_task_ttl_ms: int = 60000, + default_task_poll_timeout_seconds: float = 300.0, ) -> None: """Initialize a new MCP Server connection. @@ -130,6 +132,10 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. 
elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + default_task_ttl_ms: Default time-to-live in milliseconds for task-augmented tool calls. + Defaults to 60000 (1 minute). + default_task_poll_timeout_seconds: Default timeout in seconds for polling task completion. + Defaults to 300.0 (5 minutes). """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters @@ -154,6 +160,12 @@ def __init__( self._tool_provider_started = False self._consumers: set[Any] = set() + # Task support caching + self._default_task_ttl_ms = default_task_ttl_ms + self._default_task_poll_timeout_seconds = default_task_poll_timeout_seconds + self._server_task_capable: bool | None = None + self._tool_task_support_cache: dict[str, str | None] = {} + def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -358,6 +370,8 @@ async def _set_close_event() -> None: self._loaded_tools = None self._tool_provider_started = False self._consumers = set() + self._server_task_capable = None + self._tool_task_support_cache = {} if self._close_exception: exception = self._close_exception @@ -392,14 +406,37 @@ def list_tools_sync( effective_prefix = self._prefix if prefix is None else prefix effective_filters = self._tool_filters if tool_filters is None else tool_filters - async def _list_tools_async() -> ListToolsResult: - return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) + async def _list_tools_and_cache_capabilities_async() -> ListToolsResult: + session = cast(ClientSession, self._background_thread_session) + list_tools_result = await session.list_tools(cursor=pagination_token) + + # Cache server task capability while we have an active session + # This avoids needing a separate async call later during call_tool_* + if self._server_task_capable is None: + caps = session.get_server_capabilities() + self._server_task_capable = ( + caps 
is not None + and caps.tasks is not None + and caps.tasks.requests is not None + and caps.tasks.requests.tools is not None + and caps.tasks.requests.tools.call is not None + ) + + return list_tools_result - list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() + list_tools_response: ListToolsResult = self._invoke_on_background_thread( + _list_tools_and_cache_capabilities_async() + ).result() self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) mcp_tools = [] for tool in list_tools_response.tools: + # Cache taskSupport for task-augmented execution decisions + task_support = None + if tool.execution is not None and tool.execution.taskSupport is not None: + task_support = tool.execution.taskSupport + self._tool_task_support_cache[tool.name] = task_support + # Apply prefix if specified if effective_prefix: prefixed_name = f"{effective_prefix}_{tool.name}" @@ -539,6 +576,45 @@ async def _list_resource_templates_async() -> ListResourceTemplatesResult: return list_resource_templates_result + def _create_call_tool_coroutine( + self, + name: str, + arguments: dict[str, Any] | None, + read_timeout_seconds: timedelta | None, + ) -> Coroutine[Any, Any, MCPCallToolResult]: + """Create the appropriate coroutine for calling a tool. + + This method encapsulates the decision logic for whether to use task-augmented + execution or direct call_tool, returning the appropriate coroutine. + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + read_timeout_seconds: Optional timeout for the tool call. + + Returns: + A coroutine that will execute the tool call. 
+ """ + use_task = self._should_use_task(name) + + if use_task: + self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) + poll_timeout = self._convert_timeout_for_polling(read_timeout_seconds) + + async def _call_as_task() -> MCPCallToolResult: + return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout_seconds=poll_timeout) + + return _call_as_task() + else: + self._log_debug_with_thread("tool=<%s> | using direct call_tool", name) + + async def _call_tool_direct() -> MCPCallToolResult: + return await cast(ClientSession, self._background_thread_session).call_tool( + name, arguments, read_timeout_seconds + ) + + return _call_tool_direct() + def call_tool_sync( self, tool_use_id: str, @@ -548,10 +624,8 @@ def call_tool_sync( ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. If the MCP tool returns - structured content, it will be included as the last item in the content array - of the returned ToolResult. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. 
Args: tool_use_id: Unique identifier for this tool use @@ -566,13 +640,9 @@ def call_tool_sync( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result() + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: logger.exception("tool execution failed") @@ -587,8 +657,8 @@ async def call_tool_async( ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the MCPToolResult format. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. 
Args: tool_use_id: Unique identifier for this tool use @@ -603,13 +673,9 @@ async def call_tool_async( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - future = self._invoke_on_background_thread(_call_tool_async()) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -898,3 +964,205 @@ def _is_session_active(self) -> bool: return False return True + + def _has_server_task_support(self) -> bool: + """Check if the MCP server supports task-augmented tool calls. + + Returns the cached capability value that was populated during list_tools_sync(). + If list_tools_sync() hasn't been called yet, returns False (conservative default). + + The capability is cached during list_tools_sync() to avoid needing a separate + async call during call_tool_*() operations. + + Returns: + True if server supports task-augmented tool calls, False otherwise. + """ + # Return cached value, defaulting to False if not yet populated + # The cache is populated during list_tools_sync() + return self._server_task_capable or False + + def _get_tool_task_support(self, tool_name: str) -> str | None: + """Get the taskSupport setting for a tool. + + Returns the cached taskSupport value for the given tool name. + The cache is populated during list_tools_sync(). + + Args: + tool_name: Name of the tool to look up. + + Returns: + The taskSupport value ('required', 'optional', 'forbidden') or None if not cached. 
+ """ + return self._tool_task_support_cache.get(tool_name) + + def _should_use_task(self, tool_name: str) -> bool: + """Determine if task-augmented execution should be used for a tool. + + Implements the MCP spec decision matrix: + - If server doesn't support tasks: MUST NOT use tasks (returns False) + - If tool taskSupport is None or 'forbidden': MUST NOT use tasks (returns False) + - If tool taskSupport is 'required' and server supports: use tasks (returns True) + - If tool taskSupport is 'optional' and server supports: prefer tasks (returns True) + + Per MCP spec, server capability check takes precedence over tool-level settings. + + Args: + tool_name: Name of the tool to check. + + Returns: + True if task-augmented execution should be used, False otherwise. + """ + # Server capability check comes first (per MCP spec) + if not self._has_server_task_support(): + return False + + task_support = self._get_tool_task_support(tool_name) + + # Use tasks for 'required' or 'optional' when server supports + if task_support == "required" or task_support == "optional": + return True + + # Default: 'forbidden', None, or unknown -> don't use tasks + return False + + def _convert_timeout_for_polling(self, read_timeout_seconds: timedelta | None) -> float | None: + """Convert a timedelta timeout to seconds for task polling. + + When task-augmented execution is used, the read_timeout_seconds parameter + (which is a timedelta) needs to be converted to a float for the polling timeout. + + Args: + read_timeout_seconds: Optional timedelta timeout from the call_tool API. + + Returns: + Float seconds if timeout was specified, None to use default. + """ + return read_timeout_seconds.total_seconds() if read_timeout_seconds else None + + def _create_task_error_result(self, message: str) -> MCPCallToolResult: + """Create an error MCPCallToolResult with consistent formatting. + + This helper reduces duplication in task error handling paths. 
+ + Args: + message: The error message to include in the result. + + Returns: + MCPCallToolResult with isError=True and the message as text content. + """ + return MCPCallToolResult( + isError=True, + content=[MCPTextContent(type="text", text=message)], + ) + + # ================================================================================== + # Task-Augmented Tool Execution + # ================================================================================== + # + # The MCP spec defines task-augmented execution for long-running tools. The flow is: + # + # 1. Check server capability (tasks.requests.tools.call) and tool setting (taskSupport) + # 2. If using tasks: call_tool_as_task() -> poll_task() -> get_task_result() + # 3. If not using tasks: call_tool() directly + # + # See: https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks + # ================================================================================== + + async def _call_tool_as_task_and_poll_async( + self, + name: str, + arguments: dict[str, Any] | None = None, + ttl_ms: int | None = None, + poll_timeout_seconds: float | None = None, + ) -> MCPCallToolResult: + """Call a tool using task-augmented execution and poll until completion. + + This method implements the MCP task workflow: + 1. Creates a task via call_tool_as_task + 2. Polls using poll_task until terminal status (with timeout protection) + 3. Gets the final result using get_task_result + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + ttl_ms: Task time-to-live in milliseconds. Uses default_task_ttl_ms if not specified. + poll_timeout_seconds: Timeout for polling in seconds. Uses default_task_poll_timeout_seconds if not + specified. + + Returns: + MCPCallToolResult: The final tool result after task completion. 
+ """ + session = cast(ClientSession, self._background_thread_session) + ttl = ttl_ms or self._default_task_ttl_ms + timeout = poll_timeout_seconds or self._default_task_poll_timeout_seconds + + # Step 1: Create the task + self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl) + create_result = await session.experimental.call_tool_as_task( + name=name, + arguments=arguments, + ttl=ttl, + ) + task_id = create_result.task.taskId + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) + + # Step 2: Poll until terminal status (with timeout protection) + # Note: Using asyncio.wait_for() instead of asyncio.timeout() for Python 3.10 compatibility + async def _poll_until_terminal() -> Any: + """Inner function to poll task status until terminal state.""" + final = None + async for status in session.experimental.poll_task(task_id): + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | task status update", + name, + task_id, + status.status, + ) + final = status + return final + + try: + final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout) + except asyncio.TimeoutError: + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, timeout=<%s> | task polling timed out", name, task_id, timeout + ) + return self._create_task_error_result(f"Task {task_id} polling timed out after {timeout} seconds") + + # Step 3: Handle terminal status + if final_status is None: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | polling completed without status", name, task_id) + return self._create_task_error_result(f"Task {task_id} polling completed without status") + + if final_status.status == "failed": + error_msg = final_status.statusMessage or "Task failed" + self._log_debug_with_thread("tool=<%s>, task_id=<%s>, error=<%s> | task failed", name, task_id, error_msg) + return self._create_task_error_result(error_msg) + + if final_status.status == "cancelled": + 
self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task was cancelled", name, task_id) + return self._create_task_error_result("Task was cancelled") + + # Step 4: Get the actual result for completed tasks (with error handling for race conditions) + if final_status.status == "completed": + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task completed, fetching result", name, task_id) + try: + result = await session.experimental.get_task_result(task_id, MCPCallToolResult) + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task result retrieved", name, task_id) + return result + except Exception as e: + # Handle race condition: task completed but result retrieval failed + # (e.g., result expired, network error, server restarted) + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, error=<%s> | failed to retrieve task result", name, task_id, str(e) + ) + return self._create_task_error_result(f"Task completed but result retrieval failed: {str(e)}") + + # Unexpected status - return as error + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | unexpected task status", + name, + task_id, + final_status.status, + ) + return self._create_task_error_result(f"Unexpected task status: {final_status.status}") diff --git a/tests/strands/tools/mcp/conftest.py b/tests/strands/tools/mcp/conftest.py new file mode 100644 index 000000000..0cfce470a --- /dev/null +++ b/tests/strands/tools/mcp/conftest.py @@ -0,0 +1,59 @@ +"""Shared fixtures and helpers for MCP client tests.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_transport(): + """Create a mock MCP transport.""" + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_transport_cm = AsyncMock() + mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_transport_callable = MagicMock(return_value=mock_transport_cm) + + return { + "read_stream": mock_read_stream, + "write_stream": 
mock_write_stream, + "transport_cm": mock_transport_cm, + "transport_callable": mock_transport_callable, + } + + +@pytest.fixture +def mock_session(): + """Create a mock MCP session.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + # Default: no task support (get_server_capabilities is sync, not async!) + mock_session.get_server_capabilities = MagicMock(return_value=None) + + # Create a mock context manager for ClientSession + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__.return_value = mock_session + + # Patch ClientSession to return our mock session + with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + yield mock_session + + +def create_server_capabilities(has_task_support: bool) -> MagicMock: + """Create mock server capabilities. + + Args: + has_task_support: Whether the server should advertise task support. + + Returns: + MagicMock representing server capabilities. + """ + caps = MagicMock() + if has_task_support: + caps.tasks = MagicMock() + caps.tasks.requests = MagicMock() + caps.tasks.requests.tools = MagicMock() + caps.tasks.requests.tools.call = MagicMock() + else: + caps.tasks = None + return caps diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f784da414..4c9ca6752 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,6 +1,6 @@ import base64 import time -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from mcp import ListToolsResult @@ -25,35 +25,7 @@ from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError - -@pytest.fixture -def mock_transport(): - mock_read_stream = AsyncMock() - mock_write_stream = AsyncMock() - mock_transport_cm = AsyncMock() - mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) - 
mock_transport_callable = MagicMock(return_value=mock_transport_cm) - - return { - "read_stream": mock_read_stream, - "write_stream": mock_write_stream, - "transport_cm": mock_transport_cm, - "transport_callable": mock_transport_callable, - } - - -@pytest.fixture -def mock_session(): - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - - # Create a mock context manager for ClientSession - mock_session_cm = AsyncMock() - mock_session_cm.__aenter__.return_value = mock_session - - # Patch ClientSession to return our mock session - with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): - yield mock_session +# Fixtures mock_transport and mock_session are imported from conftest.py @pytest.fixture diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..6ce53f292 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -0,0 +1,223 @@ +"""Tests for MCP task-augmented execution support in MCPClient. + +These unit tests focus on error handling and edge cases that are not easily +testable through integration tests. Happy-path flows are covered by +integration tests in tests_integ/mcp/test_mcp_client_tasks.py. 
+""" + +import asyncio +from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest +from mcp import ListToolsResult +from mcp.types import Tool as MCPTool +from mcp.types import ToolExecution + +from strands.tools.mcp import MCPClient + +from .conftest import create_server_capabilities + + +class TestTaskExecutionFailures: + """Tests for task execution failure handling.""" + + @pytest.mark.parametrize( + "status,status_message,expected_text", + [ + ("failed", "Something went wrong", "Something went wrong"), + ("cancelled", None, "cancelled"), + ], + ) + def test_task_execution_terminal_status(self, mock_transport, mock_session, status, status_message, expected_text): + """Test handling of terminal task statuses (failed, cancelled).""" + mock_create_result = MagicMock() + mock_create_result.task.taskId = f"task-{status}" + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + mock_status = MagicMock() + mock_status.status = status + mock_status.statusMessage = status_message + + async def mock_poll_task(task_id): + yield mock_status + + mock_session.experimental.poll_task = mock_poll_task + + with MCPClient(mock_transport["transport_callable"]) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) + + assert result["status"] == "error" + assert expected_text.lower() in result["content"][0].get("text", "").lower() + + +class TestStopResetCache: + """Tests for cache reset in stop().""" + + def test_stop_resets_task_caches(self, mock_transport, mock_session): + """Test that stop() resets the task support caches.""" + with MCPClient(mock_transport["transport_callable"]) as client: + client._server_task_capable = True + client._tool_task_support_cache["tool1"] = "required" + + assert client._server_task_capable is None + assert 
client._tool_task_support_cache == {} + + +class TestTaskConfiguration: + """Tests for task-related configuration options.""" + + def test_default_task_config_values(self, mock_transport, mock_session): + """Test default configuration values.""" + with MCPClient(mock_transport["transport_callable"]) as client: + assert client._default_task_ttl_ms == 60000 + assert client._default_task_poll_timeout_seconds == 300.0 + + def test_custom_task_config_values(self, mock_transport, mock_session): + """Test custom configuration values.""" + with MCPClient( + mock_transport["transport_callable"], + default_task_ttl_ms=120000, + default_task_poll_timeout_seconds=60.0, + ) as client: + assert client._default_task_ttl_ms == 120000 + assert client._default_task_poll_timeout_seconds == 60.0 + + +class TestTaskExecutionTimeout: + """Tests for task execution timeout and error handling.""" + + def _setup_task_tool(self, mock_session, tool_name: str) -> None: + """Helper to set up a mock task-enabled tool.""" + mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True)) + mock_tool = MCPTool( + name=tool_name, + description="A test tool", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport="optional"), + ) + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) + + mock_create_result = MagicMock() + mock_create_result.task.taskId = "test-task-id" + mock_session.experimental = MagicMock() + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + @pytest.mark.asyncio + async def test_task_polling_timeout(self, mock_transport, mock_session): + """Test that task polling times out properly.""" + self._setup_task_tool(mock_session, "slow_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + with 
MCPClient(mock_transport["transport_callable"], default_task_poll_timeout_seconds=0.1) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="test-123", name="slow_tool", arguments={}) + + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_task_result_retrieval_failure(self, mock_transport, mock_session): + """Test that get_task_result failures are handled gracefully.""" + self._setup_task_tool(mock_session, "failing_tool") + + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = successful_poll + mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) + + with MCPClient(mock_transport["transport_callable"]) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="test-456", name="failing_tool", arguments={}) + + assert result["status"] == "error" + assert "result retrieval failed" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_explicit_timeout_overrides_default(self, mock_transport, mock_session): + """Test that read_timeout_seconds overrides the default poll timeout.""" + self._setup_task_tool(mock_session, "timeout_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + # Long default timeout, but short explicit timeout + with MCPClient(mock_transport["transport_callable"], default_task_poll_timeout_seconds=300.0) as client: + client.list_tools_sync() + result = await client.call_tool_async( + tool_use_id="test-timeout", + name="timeout_tool", + arguments={}, + read_timeout_seconds=timedelta(seconds=0.1), + ) + + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + 
+ @pytest.mark.asyncio + async def test_task_polling_yields_no_status(self, mock_transport, mock_session): + """Test handling when poll_task yields nothing (final_status is None).""" + self._setup_task_tool(mock_session, "empty_poll_tool") + + async def empty_poll(task_id): + return + yield # noqa: B901 - makes this an async generator + + mock_session.experimental.poll_task = empty_poll + + with MCPClient(mock_transport["transport_callable"]) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) + assert result["status"] == "error" + assert "without status" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_task_unexpected_terminal_status(self, mock_transport, mock_session): + """Test handling of unexpected task status (not completed/failed/cancelled).""" + self._setup_task_tool(mock_session, "weird_tool") + + async def poll(task_id): + yield MagicMock(status="unknown_status", statusMessage=None) + + mock_session.experimental.poll_task = poll + + with MCPClient(mock_transport["transport_callable"]) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="weird_tool", arguments={}) + assert result["status"] == "error" + assert "unexpected task status" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_task_successful_completion(self, mock_transport, mock_session): + """Test successful task completion with result retrieval (happy path).""" + from mcp.types import CallToolResult as MCPCallToolResult + from mcp.types import TextContent as MCPTextContent + + self._setup_task_tool(mock_session, "success_tool") + + async def poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = poll + mock_session.experimental.get_task_result = AsyncMock( + return_value=MCPCallToolResult(content=[MCPTextContent(type="text", 
text="Done")], isError=False) + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) + assert result["status"] == "success" + assert "Done" in result["content"][0].get("text", "") diff --git a/tests_integ/mcp/task_echo_server.py b/tests_integ/mcp/task_echo_server.py new file mode 100644 index 000000000..4a8edc97d --- /dev/null +++ b/tests_integ/mcp/task_echo_server.py @@ -0,0 +1,139 @@ +"""MCP server with task-augmented tool execution support for integration testing.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import click +import mcp.types as types +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount + + +def create_task_server() -> Server: + """Create and configure the task-supporting MCP server.""" + server = Server("task-echo-server") + server.experimental.enable_tasks() + + # Workaround: MCP Python SDK's enable_tasks() doesn't properly set tasks.requests.tools.call capability + original_update_capabilities = server.experimental.update_capabilities + + def patched_update_capabilities(capabilities: types.ServerCapabilities) -> None: + original_update_capabilities(capabilities) + if capabilities.tasks and capabilities.tasks.requests and capabilities.tasks.requests.tools: + capabilities.tasks.requests.tools.call = types.TasksCallCapability() + + server.experimental.update_capabilities = patched_update_capabilities # type: ignore[method-assign] + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="task_required_echo", + description="Echo that requires task-augmented execution", + 
inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + types.Tool( + name="task_optional_echo", + description="Echo that optionally supports task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_OPTIONAL), + ), + types.Tool( + name="task_forbidden_echo", + description="Echo that does not support task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_FORBIDDEN), + ), + types.Tool( + name="echo", + description="Simple echo without task support setting", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + ), + ] + + async def handle_task_required_echo(arguments: dict[str, Any]) -> types.CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + message = arguments.get("message", "") + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing echo...") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Task echo: {message}")]) + + return await ctx.experimental.run_task(work) + + async def handle_task_optional_echo(arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + ctx = server.request_context + message = arguments.get("message", "") + + if ctx.experimental.is_task: + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing optional task echo...") + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Task optional echo: {message}")] + ) + + return await ctx.experimental.run_task(work) + else: + return 
types.CallToolResult( + content=[types.TextContent(type="text", text=f"Direct optional echo: {message}")] + ) + + async def handle_task_forbidden_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Forbidden echo: {message}")]) + + async def handle_simple_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Simple echo: {message}")]) + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + handlers = { + "task_required_echo": handle_task_required_echo, + "task_optional_echo": handle_task_optional_echo, + "task_forbidden_echo": handle_task_forbidden_echo, + "echo": handle_simple_echo, + } + if name in handlers: + return await handlers[name](arguments) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], isError=True + ) + + return server + + +def create_starlette_app(port: int) -> tuple[Starlette, StreamableHTTPSessionManager]: + """Create the Starlette app with MCP session manager.""" + server = create_task_server() + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + return Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=app_lifespan), session_manager + + +@click.command() +@click.option("--port", default=8010, help="Port to listen on") +def main(port: int) -> int: + """Start the task echo server.""" + import uvicorn + + starlette_app, _ = create_starlette_app(port) + print(f"Starting task echo server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 + + +if 
__name__ == "__main__": + main() diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..a294246f4 --- /dev/null +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -0,0 +1,188 @@ +"""Integration tests for MCP task-augmented tool execution. + +These tests verify that our MCPClient correctly handles tools with taskSupport settings +and integrates with MCP servers that support task-augmented execution. + +The test server (task_echo_server.py) includes a workaround for an MCP Python SDK bug +where `enable_tasks()` doesn't properly set `tasks.requests.tools.call` capability. +""" + +import os +import socket +import threading +import time +from typing import Any + +import pytest +from mcp.client.streamable_http import streamablehttp_client + +from strands.tools.mcp.mcp_client import MCPClient +from strands.tools.mcp.mcp_types import MCPTransport + + +def _find_available_port() -> int: + """Find an available port by binding to port 0 and letting the OS assign one.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + +def start_task_server(port: int) -> None: + """Start the task echo server in a thread.""" + import uvicorn + + from tests_integ.mcp.task_echo_server import create_starlette_app + + starlette_app, _ = create_starlette_app(port) + uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="warning") + + +@pytest.fixture(scope="module") +def task_server_port() -> int: + """Get a dynamically allocated port for the task server.""" + return _find_available_port() + + +@pytest.fixture(scope="module") +def task_server(task_server_port: int) -> Any: + """Start the task server for the test module.""" + server_thread = threading.Thread(target=start_task_server, kwargs={"port": task_server_port}, daemon=True) + server_thread.start() + time.sleep(2) # Wait for server to start + yield + # 
Server thread is daemon, will be cleaned up automatically + + +@pytest.fixture +def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client connected to the task server.""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback) + + +@pytest.mark.skipif( + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions", +) +class TestMCPTaskSupport: + """Integration tests for MCP task-augmented execution. + + These tests verify our client correctly: + 1. Detects server task capability and uses task-augmented execution when appropriate + 2. Caches taskSupport settings from tools + 3. Falls back to direct call_tool for tools that don't support tasks + 4. Handles the full task workflow (call_tool_as_task -> poll_task -> get_task_result) + """ + + def test_task_forbidden_tool_uses_direct_call(self, task_mcp_client: MCPClient) -> None: + """Test that a tool with taskSupport='forbidden' uses direct call_tool.""" + with task_mcp_client: + tools = task_mcp_client.list_tools_sync() + assert "task_forbidden_echo" in [t.tool_name for t in tools] + + result = task_mcp_client.call_tool_sync( + tool_use_id="test-1", name="task_forbidden_echo", arguments={"message": "Hello forbidden!"} + ) + assert result["status"] == "success" + assert "Forbidden echo: Hello forbidden!" 
in result["content"][0].get("text", "") + + def test_tool_without_task_support_uses_direct_call(self, task_mcp_client: MCPClient) -> None: + """Test that a tool without taskSupport setting uses direct call_tool.""" + with task_mcp_client: + tools = task_mcp_client.list_tools_sync() + assert "echo" in [t.tool_name for t in tools] + + result = task_mcp_client.call_tool_sync( + tool_use_id="test-2", name="echo", arguments={"message": "Hello simple!"} + ) + assert result["status"] == "success" + assert "Simple echo: Hello simple!" in result["content"][0].get("text", "") + + def test_tool_task_support_caching(self, task_mcp_client: MCPClient) -> None: + """Test that tool taskSupport values are cached during list_tools.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + assert task_mcp_client._get_tool_task_support("task_required_echo") == "required" + assert task_mcp_client._get_tool_task_support("task_optional_echo") == "optional" + assert task_mcp_client._get_tool_task_support("task_forbidden_echo") == "forbidden" + assert task_mcp_client._get_tool_task_support("echo") is None + + def test_server_capabilities_advertised(self, task_mcp_client: MCPClient) -> None: + """Test that server properly advertises task capabilities.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + session = task_mcp_client._background_thread_session + if session: + caps = session.get_server_capabilities() + assert caps is not None and caps.tasks is not None + assert caps.tasks.requests is not None and caps.tasks.requests.tools is not None + assert caps.tasks.requests.tools.call is not None + assert task_mcp_client._has_server_task_support() is True + + def test_task_required_tool_uses_task_execution(self, task_mcp_client: MCPClient) -> None: + """Test that task-required tools use task-augmented execution.""" + with task_mcp_client: + tools = task_mcp_client.list_tools_sync() + assert "task_required_echo" in [t.tool_name for t in tools] + + result = 
task_mcp_client.call_tool_sync( + tool_use_id="test-3", name="task_required_echo", arguments={"message": "Hello from task!"} + ) + assert result["status"] == "success" + assert "Task echo: Hello from task!" in result["content"][0].get("text", "") + + def test_task_optional_tool_uses_task_execution(self, task_mcp_client: MCPClient) -> None: + """Test that task-optional tools use task-augmented execution when server supports it.""" + with task_mcp_client: + tools = task_mcp_client.list_tools_sync() + assert "task_optional_echo" in [t.tool_name for t in tools] + + result = task_mcp_client.call_tool_sync( + tool_use_id="test-4", name="task_optional_echo", arguments={"message": "Hello optional task!"} + ) + assert result["status"] == "success" + assert "Task optional echo: Hello optional task!" in result["content"][0].get("text", "") + + def test_should_use_task_logic_with_server_support(self, task_mcp_client: MCPClient) -> None: + """Test that _should_use_task returns correct values based on tool taskSupport.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + assert task_mcp_client._should_use_task("task_required_echo") is True + assert task_mcp_client._should_use_task("task_optional_echo") is True + assert task_mcp_client._should_use_task("task_forbidden_echo") is False + assert task_mcp_client._should_use_task("echo") is False + + def test_multiple_tool_calls_in_sequence(self, task_mcp_client: MCPClient) -> None: + """Test calling multiple tools in sequence with different task modes.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + r1 = task_mcp_client.call_tool_sync( + tool_use_id="s1", name="task_forbidden_echo", arguments={"message": "1"} + ) + assert r1["status"] == "success" and "Forbidden echo: 1" in r1["content"][0].get("text", "") + + r2 = task_mcp_client.call_tool_sync(tool_use_id="s2", name="echo", arguments={"message": "2"}) + assert r2["status"] == "success" and "Simple echo: 2" in r2["content"][0].get("text", "") + + r3 = 
task_mcp_client.call_tool_sync(tool_use_id="s3", name="task_optional_echo", arguments={"message": "3"}) + assert r3["status"] == "success" and "Task optional echo: 3" in r3["content"][0].get("text", "") + + r4 = task_mcp_client.call_tool_sync(tool_use_id="s4", name="task_required_echo", arguments={"message": "4"}) + assert r4["status"] == "success" and "Task echo: 4" in r4["content"][0].get("text", "") + + @pytest.mark.asyncio + async def test_async_tool_calls(self, task_mcp_client: MCPClient) -> None: + """Test async tool calls work correctly.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + result = await task_mcp_client.call_tool_async( + tool_use_id="test-async", name="task_forbidden_echo", arguments={"message": "Async hello!"} + ) + assert result["status"] == "success" + assert "Forbidden echo: Async hello!" in result["content"][0].get("text", "") From a4a5ac7aedf7ea3e5607ddef9c097b917c5248e5 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 22 Jan 2026 14:47:15 -0800 Subject: [PATCH 02/10] chore: cache server task capability immediately --- src/strands/tools/mcp/mcp_client.py | 48 ++++++++----------- .../tools/mcp/test_mcp_client_contextvar.py | 2 + 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 2ca284407..ac16d4800 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -406,27 +406,10 @@ def list_tools_sync( effective_prefix = self._prefix if prefix is None else prefix effective_filters = self._tool_filters if tool_filters is None else tool_filters - async def _list_tools_and_cache_capabilities_async() -> ListToolsResult: - session = cast(ClientSession, self._background_thread_session) - list_tools_result = await session.list_tools(cursor=pagination_token) - - # Cache server task capability while we have an active session - # This avoids needing a separate async call later during call_tool_* - if 
self._server_task_capable is None: - caps = session.get_server_capabilities() - self._server_task_capable = ( - caps is not None - and caps.tasks is not None - and caps.tasks.requests is not None - and caps.tasks.requests.tools is not None - and caps.tasks.requests.tools.call is not None - ) - - return list_tools_result + async def _list_tools_async() -> ListToolsResult: + return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) - list_tools_response: ListToolsResult = self._invoke_on_background_thread( - _list_tools_and_cache_capabilities_async() - ).result() + list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) mcp_tools = [] @@ -753,6 +736,21 @@ async def _async_background_thread(self) -> None: self._log_debug_with_thread("session initialized successfully") # Store the session for use while we await the close event self._background_thread_session = session + + # Cache server task capability immediately after initialization + # Capabilities are exchanged during session.initialize(), so this is available now + caps = session.get_server_capabilities() + self._server_task_capable = ( + caps is not None + and caps.tasks is not None + and caps.tasks.requests is not None + and caps.tasks.requests.tools is not None + and caps.tasks.requests.tools.call is not None + ) + self._log_debug_with_thread( + "server_task_capable=<%s> | cached server task capability", self._server_task_capable + ) + # Signal that the session has been created and is ready for use self._init_future.set_result(None) @@ -968,17 +966,13 @@ def _is_session_active(self) -> bool: def _has_server_task_support(self) -> bool: """Check if the MCP server supports task-augmented tool calls. - Returns the cached capability value that was populated during list_tools_sync(). 
- If list_tools_sync() hasn't been called yet, returns False (conservative default). - - The capability is cached during list_tools_sync() to avoid needing a separate - async call during call_tool_*() operations. + Returns the capability value that was cached immediately after session initialization. + Server capabilities are exchanged during the MCP handshake, so this is available + as soon as start() completes. Returns: True if server supports task-augmented tool calls, False otherwise. """ - # Return cached value, defaulting to False if not yet populated - # The cache is populated during list_tools_sync() return self._server_task_capable or False def _get_tool_task_support(self, tool_name: str) -> str | None: diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/tools/mcp/test_mcp_client_contextvar.py index d95929b02..739796366 100644 --- a/tests/strands/tools/mcp/test_mcp_client_contextvar.py +++ b/tests/strands/tools/mcp/test_mcp_client_contextvar.py @@ -37,6 +37,8 @@ def mock_session(): """Create mock MCP session.""" mock_session = AsyncMock() mock_session.initialize = AsyncMock() + # get_server_capabilities is sync, not async + mock_session.get_server_capabilities = MagicMock(return_value=None) mock_session_cm = AsyncMock() mock_session_cm.__aenter__.return_value = mock_session From 6801cdf5da9026c4b742b2e494e28f5af2758ec3 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 23 Jan 2026 14:24:40 -0800 Subject: [PATCH 03/10] chore: add experimental.tasks feature gate --- src/strands/tools/mcp/mcp_client.py | 107 ++++++++++++++---- .../tools/mcp/test_mcp_client_tasks.py | 74 +++++++++--- tests_integ/mcp/test_mcp_client_tasks.py | 32 +++++- 3 files changed, 174 insertions(+), 39 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index ac16d4800..0cfaae790 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -73,6 +73,39 @@ class 
ToolFilters(TypedDict, total=False): rejected: list[_ToolMatcher] +class TasksConfig(TypedDict, total=False): + """Configuration for MCP Tasks (task-augmented tool execution). + + If this config is provided (not None), task-augmented execution is enabled. + When enabled, long-running tool calls use the MCP task workflow: + create task -> poll for completion -> get result. + + Attributes: + ttl_ms: Task time-to-live in milliseconds. Defaults to 60000 (1 minute). + poll_timeout_seconds: Timeout for polling task completion in seconds. + Defaults to 300.0 (5 minutes). + """ + + ttl_ms: int + poll_timeout_seconds: float + + +class ExperimentalConfig(TypedDict, total=False): + """Configuration for experimental MCPClient features. + + Warning: + Features under this configuration are experimental and subject to change + in future revisions without notice. + + Attributes: + tasks: Configuration for MCP Tasks (task-augmented tool execution). + If provided (not None), enables task-augmented execution for tools + that support it. + """ + + tasks: TasksConfig | None + + MIME_TO_FORMAT: dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -120,8 +153,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, - default_task_ttl_ms: int = 60000, - default_task_poll_timeout_seconds: float = 300.0, + experimental: ExperimentalConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -132,10 +164,10 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. - default_task_ttl_ms: Default time-to-live in milliseconds for task-augmented tool calls. - Defaults to 60000 (1 minute). - default_task_poll_timeout_seconds: Default timeout in seconds for polling task completion. - Defaults to 300.0 (5 minutes). 
+ experimental: Configuration for experimental features. Currently supports: + - tasks: Enable MCP task-augmented execution for long-running tools. + If provided (not None), enables task-augmented execution for tools + that support it. See ExperimentalConfig and TasksConfig for details. """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters @@ -160,9 +192,8 @@ def __init__( self._tool_provider_started = False self._consumers: set[Any] = set() - # Task support caching - self._default_task_ttl_ms = default_task_ttl_ms - self._default_task_poll_timeout_seconds = default_task_poll_timeout_seconds + # Task support configuration and caching + self._experimental = experimental or {} self._server_task_capable: bool | None = None self._tool_task_support_cache: dict[str, str | None] = {} @@ -963,6 +994,38 @@ def _is_session_active(self) -> bool: return True + def _is_tasks_enabled(self) -> bool: + """Check if experimental tasks feature is enabled. + + Tasks are enabled if experimental.tasks is defined and not None. + + Returns: + True if task-augmented execution is enabled, False otherwise. + """ + return self._experimental.get("tasks") is not None + + def _get_task_ttl_ms(self) -> int: + """Get task TTL in milliseconds. + + Returns: + Task TTL from config, or default of 60000 (1 minute). + """ + tasks_config = self._experimental.get("tasks") + if tasks_config is None: + return 60000 + return tasks_config.get("ttl_ms", 60000) + + def _get_task_poll_timeout_seconds(self) -> float: + """Get task polling timeout in seconds. + + Returns: + Polling timeout from config, or default of 300.0 (5 minutes). + """ + tasks_config = self._experimental.get("tasks") + if tasks_config is None: + return 300.0 + return tasks_config.get("poll_timeout_seconds", 300.0) + def _has_server_task_support(self) -> bool: """Check if the MCP server supports task-augmented tool calls. 
@@ -992,13 +1055,10 @@ def _get_tool_task_support(self, tool_name: str) -> str | None: def _should_use_task(self, tool_name: str) -> bool: """Determine if task-augmented execution should be used for a tool. - Implements the MCP spec decision matrix: - - If server doesn't support tasks: MUST NOT use tasks (returns False) - - If tool taskSupport is None or 'forbidden': MUST NOT use tasks (returns False) - - If tool taskSupport is 'required' and server supports: use tasks (returns True) - - If tool taskSupport is 'optional' and server supports: prefer tasks (returns True) - - Per MCP spec, server capability check takes precedence over tool-level settings. + Task-augmented execution requires: + 1. experimental.tasks is enabled (opt-in check) + 2. Server supports tasks (capability check) + 3. Tool taskSupport is 'required' or 'optional' Args: tool_name: Name of the tool to check. @@ -1006,7 +1066,11 @@ def _should_use_task(self, tool_name: str) -> bool: Returns: True if task-augmented execution should be used, False otherwise. """ - # Server capability check comes first (per MCP spec) + # Opt-in check: tasks must be explicitly enabled via experimental.tasks + if not self._is_tasks_enabled(): + return False + + # Server capability check (per MCP spec) if not self._has_server_task_support(): return False @@ -1079,16 +1143,15 @@ async def _call_tool_as_task_and_poll_async( Args: name: Name of the tool to call. arguments: Optional arguments to pass to the tool. - ttl_ms: Task time-to-live in milliseconds. Uses default_task_ttl_ms if not specified. - poll_timeout_seconds: Timeout for polling in seconds. Uses default_task_poll_timeout_seconds if not - specified. + ttl_ms: Task time-to-live in milliseconds. Uses configured value if not specified. + poll_timeout_seconds: Timeout for polling in seconds. Uses configured value if not specified. Returns: MCPCallToolResult: The final tool result after task completion. 
""" session = cast(ClientSession, self._background_thread_session) - ttl = ttl_ms or self._default_task_ttl_ms - timeout = poll_timeout_seconds or self._default_task_poll_timeout_seconds + ttl = ttl_ms or self._get_task_ttl_ms() + timeout = poll_timeout_seconds or self._get_task_poll_timeout_seconds() # Step 1: Create the task self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl) diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 6ce53f292..2dd4908a5 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -19,6 +19,38 @@ from .conftest import create_server_capabilities +class TestTasksDisabledByDefault: + """Tests that tasks are disabled by default.""" + + def test_tasks_disabled_when_no_experimental_config(self, mock_transport, mock_session): + """Test that _should_use_task returns False when experimental.tasks is not configured.""" + with MCPClient(mock_transport["transport_callable"]) as client: + # Even with server capability and tool support, tasks should be disabled + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + + assert client._is_tasks_enabled() is False + assert client._should_use_task("test_tool") is False + + def test_tasks_disabled_when_experimental_tasks_is_none(self, mock_transport, mock_session): + """Test that _should_use_task returns False when experimental.tasks is explicitly None.""" + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": None}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + + assert client._is_tasks_enabled() is False + assert client._should_use_task("test_tool") is False + + def test_tasks_enabled_when_experimental_tasks_is_empty_dict(self, mock_transport, mock_session): + """Test that tasks are enabled when experimental.tasks 
is an empty dict.""" + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + + assert client._is_tasks_enabled() is True + assert client._should_use_task("test_tool") is True + + class TestTaskExecutionFailures: """Tests for task execution failure handling.""" @@ -44,7 +76,7 @@ async def mock_poll_task(task_id): mock_session.experimental.poll_task = mock_poll_task - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) @@ -58,7 +90,7 @@ class TestStopResetCache: def test_stop_resets_task_caches(self, mock_transport, mock_session): """Test that stop() resets the task support caches.""" - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client._server_task_capable = True client._tool_task_support_cache["tool1"] = "required" @@ -71,19 +103,27 @@ class TestTaskConfiguration: def test_default_task_config_values(self, mock_transport, mock_session): """Test default configuration values.""" - with MCPClient(mock_transport["transport_callable"]) as client: - assert client._default_task_ttl_ms == 60000 - assert client._default_task_poll_timeout_seconds == 300.0 + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + assert client._get_task_ttl_ms() == 60000 + assert client._get_task_poll_timeout_seconds() == 300.0 def test_custom_task_config_values(self, mock_transport, mock_session): """Test custom configuration values.""" with MCPClient( mock_transport["transport_callable"], - default_task_ttl_ms=120000, - 
default_task_poll_timeout_seconds=60.0, + experimental={"tasks": {"ttl_ms": 120000, "poll_timeout_seconds": 60.0}}, + ) as client: + assert client._get_task_ttl_ms() == 120000 + assert client._get_task_poll_timeout_seconds() == 60.0 + + def test_partial_task_config_uses_defaults(self, mock_transport, mock_session): + """Test that partial config uses defaults for unspecified values.""" + with MCPClient( + mock_transport["transport_callable"], + experimental={"tasks": {"ttl_ms": 120000}}, ) as client: - assert client._default_task_ttl_ms == 120000 - assert client._default_task_poll_timeout_seconds == 60.0 + assert client._get_task_ttl_ms() == 120000 + assert client._get_task_poll_timeout_seconds() == 300.0 # default class TestTaskExecutionTimeout: @@ -117,7 +157,9 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll - with MCPClient(mock_transport["transport_callable"], default_task_poll_timeout_seconds=0.1) as client: + with MCPClient( + mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 0.1}} + ) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="test-123", name="slow_tool", arguments={}) @@ -135,7 +177,7 @@ async def successful_poll(task_id): mock_session.experimental.poll_task = successful_poll mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="test-456", name="failing_tool", arguments={}) @@ -155,7 +197,9 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll # Long default timeout, but short explicit timeout - with MCPClient(mock_transport["transport_callable"], default_task_poll_timeout_seconds=300.0) as client: + with MCPClient( + 
mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 300.0}} + ) as client: client.list_tools_sync() result = await client.call_tool_async( tool_use_id="test-timeout", @@ -178,7 +222,7 @@ async def empty_poll(task_id): mock_session.experimental.poll_task = empty_poll - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) assert result["status"] == "error" @@ -194,7 +238,7 @@ async def poll(task_id): mock_session.experimental.poll_task = poll - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="weird_tool", arguments={}) assert result["status"] == "error" @@ -216,7 +260,7 @@ async def poll(task_id): return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) ) - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) assert result["status"] == "success" diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py index a294246f4..892dfb90e 100644 --- a/tests_integ/mcp/test_mcp_client_tasks.py +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -57,12 +57,22 @@ def task_server(task_server_port: int) -> Any: @pytest.fixture def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: - """Create an MCP client connected to the task server.""" + """Create an MCP client connected to the task server with tasks enabled.""" def 
transport_callback() -> MCPTransport: return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") - return MCPClient(transport_callback) + return MCPClient(transport_callback, experimental={"tasks": {}}) + + +@pytest.fixture +def task_mcp_client_disabled(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client connected to the task server with tasks disabled (default).""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback) # No experimental config - tasks disabled @pytest.mark.skipif( @@ -186,3 +196,21 @@ async def test_async_tool_calls(self, task_mcp_client: MCPClient) -> None: ) assert result["status"] == "success" assert "Forbidden echo: Async hello!" in result["content"][0].get("text", "") + + def test_tasks_disabled_by_default(self, task_mcp_client_disabled: MCPClient) -> None: + """Test that tasks are disabled when experimental.tasks is not configured.""" + with task_mcp_client_disabled: + task_mcp_client_disabled.list_tools_sync() + + # Even though server supports tasks and tool has taskSupport='required', + # tasks should NOT be used because experimental.tasks is not configured + assert task_mcp_client_disabled._is_tasks_enabled() is False + assert task_mcp_client_disabled._should_use_task("task_required_echo") is False + assert task_mcp_client_disabled._should_use_task("task_optional_echo") is False + + # Tool calls should still work via direct call_tool + result = task_mcp_client_disabled.call_tool_sync( + tool_use_id="test-disabled", name="task_required_echo", arguments={"message": "Direct call!"} + ) + assert result["status"] == "success" + assert "Task echo: Direct call!" 
in result["content"][0].get("text", "") From ceaea6ac5fd54fe3376c82ecd3ec8719e280eaa3 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 23 Jan 2026 14:32:11 -0800 Subject: [PATCH 04/10] chore: parameterize tests --- .../tools/mcp/test_mcp_client_tasks.py | 223 +++++++----------- tests_integ/mcp/test_mcp_client_tasks.py | 178 +++++--------- 2 files changed, 143 insertions(+), 258 deletions(-) diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 2dd4908a5..629163695 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -1,9 +1,4 @@ -"""Tests for MCP task-augmented execution support in MCPClient. - -These unit tests focus on error handling and edge cases that are not easily -testable through integration tests. Happy-path flows are covered by -integration tests in tests_integ/mcp/test_mcp_client_tasks.py. -""" +"""Tests for MCP task-augmented execution support in MCPClient.""" import asyncio from datetime import timedelta @@ -11,6 +6,8 @@ import pytest from mcp import ListToolsResult +from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import TextContent as MCPTextContent from mcp.types import Tool as MCPTool from mcp.types import ToolExecution @@ -19,115 +16,66 @@ from .conftest import create_server_capabilities -class TestTasksDisabledByDefault: - """Tests that tasks are disabled by default.""" +class TestTasksOptIn: + """Tests for task opt-in behavior via experimental.tasks.""" - def test_tasks_disabled_when_no_experimental_config(self, mock_transport, mock_session): - """Test that _should_use_task returns False when experimental.tasks is not configured.""" - with MCPClient(mock_transport["transport_callable"]) as client: - # Even with server capability and tool support, tasks should be disabled - client._server_task_capable = True - client._tool_task_support_cache["test_tool"] = "required" - - assert 
client._is_tasks_enabled() is False - assert client._should_use_task("test_tool") is False + @pytest.mark.parametrize( + "experimental,expected_enabled", + [ + (None, False), + ({}, False), + ({"tasks": None}, False), + ({"tasks": {}}, True), + ({"tasks": {"ttl_ms": 1000}}, True), + ], + ) + def test_tasks_enabled_state(self, mock_transport, mock_session, experimental, expected_enabled): + """Test _is_tasks_enabled based on experimental config.""" + with MCPClient(mock_transport["transport_callable"], experimental=experimental) as client: + assert client._is_tasks_enabled() is expected_enabled - def test_tasks_disabled_when_experimental_tasks_is_none(self, mock_transport, mock_session): - """Test that _should_use_task returns False when experimental.tasks is explicitly None.""" - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": None}) as client: + def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): + """Test that _should_use_task returns False without opt-in even with server/tool support.""" + with MCPClient(mock_transport["transport_callable"]) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" - - assert client._is_tasks_enabled() is False assert client._should_use_task("test_tool") is False - def test_tasks_enabled_when_experimental_tasks_is_empty_dict(self, mock_transport, mock_session): - """Test that tasks are enabled when experimental.tasks is an empty dict.""" with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" - - assert client._is_tasks_enabled() is True assert client._should_use_task("test_tool") is True -class TestTaskExecutionFailures: - """Tests for task execution failure handling.""" +class TestTaskConfiguration: + """Tests for task-related configuration options.""" @pytest.mark.parametrize( - 
"status,status_message,expected_text", + "config,expected_ttl,expected_timeout", [ - ("failed", "Something went wrong", "Something went wrong"), - ("cancelled", None, "cancelled"), + ({}, 60000, 300.0), + ({"ttl_ms": 120000}, 120000, 300.0), + ({"poll_timeout_seconds": 60.0}, 60000, 60.0), + ({"ttl_ms": 120000, "poll_timeout_seconds": 60.0}, 120000, 60.0), ], ) - def test_task_execution_terminal_status(self, mock_transport, mock_session, status, status_message, expected_text): - """Test handling of terminal task statuses (failed, cancelled).""" - mock_create_result = MagicMock() - mock_create_result.task.taskId = f"task-{status}" - mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) - - mock_status = MagicMock() - mock_status.status = status - mock_status.statusMessage = status_message - - async def mock_poll_task(task_id): - yield mock_status - - mock_session.experimental.poll_task = mock_poll_task - - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: - client._server_task_capable = True - client._tool_task_support_cache["test_tool"] = "required" - result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) - - assert result["status"] == "error" - assert expected_text.lower() in result["content"][0].get("text", "").lower() - - -class TestStopResetCache: - """Tests for cache reset in stop().""" + def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): + """Test task configuration values with various configs.""" + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": config}) as client: + assert client._get_task_ttl_ms() == expected_ttl + assert client._get_task_poll_timeout_seconds() == expected_timeout def test_stop_resets_task_caches(self, mock_transport, mock_session): """Test that stop() resets the task support caches.""" with MCPClient(mock_transport["transport_callable"], 
experimental={"tasks": {}}) as client: client._server_task_capable = True client._tool_task_support_cache["tool1"] = "required" - assert client._server_task_capable is None assert client._tool_task_support_cache == {} -class TestTaskConfiguration: - """Tests for task-related configuration options.""" - - def test_default_task_config_values(self, mock_transport, mock_session): - """Test default configuration values.""" - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: - assert client._get_task_ttl_ms() == 60000 - assert client._get_task_poll_timeout_seconds() == 300.0 - - def test_custom_task_config_values(self, mock_transport, mock_session): - """Test custom configuration values.""" - with MCPClient( - mock_transport["transport_callable"], - experimental={"tasks": {"ttl_ms": 120000, "poll_timeout_seconds": 60.0}}, - ) as client: - assert client._get_task_ttl_ms() == 120000 - assert client._get_task_poll_timeout_seconds() == 60.0 - - def test_partial_task_config_uses_defaults(self, mock_transport, mock_session): - """Test that partial config uses defaults for unspecified values.""" - with MCPClient( - mock_transport["transport_callable"], - experimental={"tasks": {"ttl_ms": 120000}}, - ) as client: - assert client._get_task_ttl_ms() == 120000 - assert client._get_task_poll_timeout_seconds() == 300.0 # default - - -class TestTaskExecutionTimeout: - """Tests for task execution timeout and error handling.""" +class TestTaskExecution: + """Tests for task execution and error handling.""" def _setup_task_tool(self, mock_session, tool_name: str) -> None: """Helper to set up a mock task-enabled tool.""" @@ -139,14 +87,39 @@ def _setup_task_tool(self, mock_session, tool_name: str) -> None: execution=ToolExecution(taskSupport="optional"), ) mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) - mock_create_result = MagicMock() mock_create_result.task.taskId = "test-task-id" 
mock_session.experimental = MagicMock() mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + @pytest.mark.parametrize( + "status,status_message,expected_text", + [ + ("failed", "Something went wrong", "Something went wrong"), + ("cancelled", None, "cancelled"), + ("unknown_status", None, "unexpected task status"), + ], + ) + def test_terminal_status_handling(self, mock_transport, mock_session, status, status_message, expected_text): + """Test handling of terminal task statuses.""" + mock_create_result = MagicMock() + mock_create_result.task.taskId = f"task-{status}" + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + async def mock_poll_task(task_id): + yield MagicMock(status=status, statusMessage=status_message) + + mock_session.experimental.poll_task = mock_poll_task + + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) + assert result["status"] == "error" + assert expected_text.lower() in result["content"][0].get("text", "").lower() + @pytest.mark.asyncio - async def test_task_polling_timeout(self, mock_transport, mock_session): + async def test_polling_timeout(self, mock_transport, mock_session): """Test that task polling times out properly.""" self._setup_task_tool(mock_session, "slow_tool") @@ -161,29 +134,10 @@ async def infinite_poll(task_id): mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 0.1}} ) as client: client.list_tools_sync() - result = await client.call_tool_async(tool_use_id="test-123", name="slow_tool", arguments={}) - + result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={}) assert result["status"] == "error" assert "timed out" in result["content"][0].get("text", "").lower() 
- @pytest.mark.asyncio - async def test_task_result_retrieval_failure(self, mock_transport, mock_session): - """Test that get_task_result failures are handled gracefully.""" - self._setup_task_tool(mock_session, "failing_tool") - - async def successful_poll(task_id): - yield MagicMock(status="completed", statusMessage=None) - - mock_session.experimental.poll_task = successful_poll - mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) - - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: - client.list_tools_sync() - result = await client.call_tool_async(tool_use_id="test-456", name="failing_tool", arguments={}) - - assert result["status"] == "error" - assert "result retrieval failed" in result["content"][0].get("text", "").lower() - @pytest.mark.asyncio async def test_explicit_timeout_overrides_default(self, mock_transport, mock_session): """Test that read_timeout_seconds overrides the default poll timeout.""" @@ -196,60 +150,53 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll - # Long default timeout, but short explicit timeout with MCPClient( mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 300.0}} ) as client: client.list_tools_sync() result = await client.call_tool_async( - tool_use_id="test-timeout", - name="timeout_tool", - arguments={}, - read_timeout_seconds=timedelta(seconds=0.1), + tool_use_id="t", name="timeout_tool", arguments={}, read_timeout_seconds=timedelta(seconds=0.1) ) - assert result["status"] == "error" assert "timed out" in result["content"][0].get("text", "").lower() @pytest.mark.asyncio - async def test_task_polling_yields_no_status(self, mock_transport, mock_session): - """Test handling when poll_task yields nothing (final_status is None).""" - self._setup_task_tool(mock_session, "empty_poll_tool") + async def test_result_retrieval_failure(self, mock_transport, mock_session): + 
"""Test that get_task_result failures are handled gracefully.""" + self._setup_task_tool(mock_session, "failing_tool") - async def empty_poll(task_id): - return - yield # noqa: B901 - makes this an async generator + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) - mock_session.experimental.poll_task = empty_poll + mock_session.experimental.poll_task = successful_poll + mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() - result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) + result = await client.call_tool_async(tool_use_id="t", name="failing_tool", arguments={}) assert result["status"] == "error" - assert "without status" in result["content"][0].get("text", "").lower() + assert "result retrieval failed" in result["content"][0].get("text", "").lower() @pytest.mark.asyncio - async def test_task_unexpected_terminal_status(self, mock_transport, mock_session): - """Test handling of unexpected task status (not completed/failed/cancelled).""" - self._setup_task_tool(mock_session, "weird_tool") + async def test_empty_poll_result(self, mock_transport, mock_session): + """Test handling when poll_task yields nothing.""" + self._setup_task_tool(mock_session, "empty_poll_tool") - async def poll(task_id): - yield MagicMock(status="unknown_status", statusMessage=None) + async def empty_poll(task_id): + return + yield # noqa: B901 - mock_session.experimental.poll_task = poll + mock_session.experimental.poll_task = empty_poll with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() - result = await client.call_tool_async(tool_use_id="t", name="weird_tool", arguments={}) + result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) assert 
result["status"] == "error" - assert "unexpected task status" in result["content"][0].get("text", "").lower() + assert "without status" in result["content"][0].get("text", "").lower() @pytest.mark.asyncio - async def test_task_successful_completion(self, mock_transport, mock_session): - """Test successful task completion with result retrieval (happy path).""" - from mcp.types import CallToolResult as MCPCallToolResult - from mcp.types import TextContent as MCPTextContent - + async def test_successful_completion(self, mock_transport, mock_session): + """Test successful task completion.""" self._setup_task_tool(mock_session, "success_tool") async def poll(task_id): diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py index 892dfb90e..5e398f6de 100644 --- a/tests_integ/mcp/test_mcp_client_tasks.py +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -1,11 +1,4 @@ -"""Integration tests for MCP task-augmented tool execution. - -These tests verify that our MCPClient correctly handles tools with taskSupport settings -and integrates with MCP servers that support task-augmented execution. - -The test server (task_echo_server.py) includes a workaround for an MCP Python SDK bug -where `enable_tasks()` doesn't properly set `tasks.requests.tools.call` capability. 
-""" +"""Integration tests for MCP task-augmented tool execution.""" import os import socket @@ -21,12 +14,11 @@ def _find_available_port() -> int: - """Find an available port by binding to port 0 and letting the OS assign one.""" + """Find an available port.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) s.listen(1) - port = s.getsockname()[1] - return port + return s.getsockname()[1] def start_task_server(port: int) -> None: @@ -41,7 +33,6 @@ def start_task_server(port: int) -> None: @pytest.fixture(scope="module") def task_server_port() -> int: - """Get a dynamically allocated port for the task server.""" return _find_available_port() @@ -50,14 +41,13 @@ def task_server(task_server_port: int) -> Any: """Start the task server for the test module.""" server_thread = threading.Thread(target=start_task_server, kwargs={"port": task_server_port}, daemon=True) server_thread.start() - time.sleep(2) # Wait for server to start + time.sleep(2) yield - # Server thread is daemon, will be cleaned up automatically @pytest.fixture def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: - """Create an MCP client connected to the task server with tasks enabled.""" + """Create an MCP client with tasks enabled.""" def transport_callback() -> MCPTransport: return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") @@ -67,150 +57,98 @@ def transport_callback() -> MCPTransport: @pytest.fixture def task_mcp_client_disabled(task_server: Any, task_server_port: int) -> MCPClient: - """Create an MCP client connected to the task server with tasks disabled (default).""" + """Create an MCP client with tasks disabled (default).""" def transport_callback() -> MCPTransport: return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") - return MCPClient(transport_callback) # No experimental config - tasks disabled + return MCPClient(transport_callback) -@pytest.mark.skipif( - 
condition=os.environ.get("GITHUB_ACTIONS") == "true", - reason="streamable transport is failing in GitHub actions", -) +@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport failing in CI") class TestMCPTaskSupport: - """Integration tests for MCP task-augmented execution. + """Integration tests for MCP task-augmented execution.""" - These tests verify our client correctly: - 1. Detects server task capability and uses task-augmented execution when appropriate - 2. Caches taskSupport settings from tools - 3. Falls back to direct call_tool for tools that don't support tasks - 4. Handles the full task workflow (call_tool_as_task -> poll_task -> get_task_result) - """ - - def test_task_forbidden_tool_uses_direct_call(self, task_mcp_client: MCPClient) -> None: - """Test that a tool with taskSupport='forbidden' uses direct call_tool.""" + def test_direct_call_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use direct call_tool (forbidden or no taskSupport).""" with task_mcp_client: - tools = task_mcp_client.list_tools_sync() - assert "task_forbidden_echo" in [t.tool_name for t in tools] + task_mcp_client.list_tools_sync() - result = task_mcp_client.call_tool_sync( - tool_use_id="test-1", name="task_forbidden_echo", arguments={"message": "Hello forbidden!"} + # Tool with taskSupport='forbidden' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_forbidden_echo", arguments={"message": "Hello!"} ) - assert result["status"] == "success" - assert "Forbidden echo: Hello forbidden!" in result["content"][0].get("text", "") + assert r1["status"] == "success" + assert "Forbidden echo: Hello!" in r1["content"][0].get("text", "") + + # Tool without taskSupport + r2 = task_mcp_client.call_tool_sync(tool_use_id="t2", name="echo", arguments={"message": "Simple!"}) + assert r2["status"] == "success" + assert "Simple echo: Simple!" 
in r2["content"][0].get("text", "") - def test_tool_without_task_support_uses_direct_call(self, task_mcp_client: MCPClient) -> None: - """Test that a tool without taskSupport setting uses direct call_tool.""" + def test_task_augmented_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use task-augmented execution (required or optional).""" with task_mcp_client: - tools = task_mcp_client.list_tools_sync() - assert "echo" in [t.tool_name for t in tools] + task_mcp_client.list_tools_sync() - result = task_mcp_client.call_tool_sync( - tool_use_id="test-2", name="echo", arguments={"message": "Hello simple!"} + # Tool with taskSupport='required' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_required_echo", arguments={"message": "Required!"} ) - assert result["status"] == "success" - assert "Simple echo: Hello simple!" in result["content"][0].get("text", "") + assert r1["status"] == "success" + assert "Task echo: Required!" in r1["content"][0].get("text", "") - def test_tool_task_support_caching(self, task_mcp_client: MCPClient) -> None: - """Test that tool taskSupport values are cached during list_tools.""" + # Tool with taskSupport='optional' + r2 = task_mcp_client.call_tool_sync( + tool_use_id="t2", name="task_optional_echo", arguments={"message": "Optional!"} + ) + assert r2["status"] == "success" + assert "Task optional echo: Optional!" 
in r2["content"][0].get("text", "") + + def test_task_support_caching_and_decision(self, task_mcp_client: MCPClient) -> None: + """Test taskSupport caching and _should_use_task decision logic.""" with task_mcp_client: task_mcp_client.list_tools_sync() + + # Verify cached values assert task_mcp_client._get_tool_task_support("task_required_echo") == "required" assert task_mcp_client._get_tool_task_support("task_optional_echo") == "optional" assert task_mcp_client._get_tool_task_support("task_forbidden_echo") == "forbidden" assert task_mcp_client._get_tool_task_support("echo") is None - def test_server_capabilities_advertised(self, task_mcp_client: MCPClient) -> None: - """Test that server properly advertises task capabilities.""" - with task_mcp_client: - task_mcp_client.list_tools_sync() - session = task_mcp_client._background_thread_session - if session: - caps = session.get_server_capabilities() - assert caps is not None and caps.tasks is not None - assert caps.tasks.requests is not None and caps.tasks.requests.tools is not None - assert caps.tasks.requests.tools.call is not None - assert task_mcp_client._has_server_task_support() is True - - def test_task_required_tool_uses_task_execution(self, task_mcp_client: MCPClient) -> None: - """Test that task-required tools use task-augmented execution.""" - with task_mcp_client: - tools = task_mcp_client.list_tools_sync() - assert "task_required_echo" in [t.tool_name for t in tools] - - result = task_mcp_client.call_tool_sync( - tool_use_id="test-3", name="task_required_echo", arguments={"message": "Hello from task!"} - ) - assert result["status"] == "success" - assert "Task echo: Hello from task!" 
in result["content"][0].get("text", "") - - def test_task_optional_tool_uses_task_execution(self, task_mcp_client: MCPClient) -> None: - """Test that task-optional tools use task-augmented execution when server supports it.""" - with task_mcp_client: - tools = task_mcp_client.list_tools_sync() - assert "task_optional_echo" in [t.tool_name for t in tools] - - result = task_mcp_client.call_tool_sync( - tool_use_id="test-4", name="task_optional_echo", arguments={"message": "Hello optional task!"} - ) - assert result["status"] == "success" - assert "Task optional echo: Hello optional task!" in result["content"][0].get("text", "") - - def test_should_use_task_logic_with_server_support(self, task_mcp_client: MCPClient) -> None: - """Test that _should_use_task returns correct values based on tool taskSupport.""" - with task_mcp_client: - task_mcp_client.list_tools_sync() + # Verify decision logic assert task_mcp_client._should_use_task("task_required_echo") is True assert task_mcp_client._should_use_task("task_optional_echo") is True assert task_mcp_client._should_use_task("task_forbidden_echo") is False assert task_mcp_client._should_use_task("echo") is False - def test_multiple_tool_calls_in_sequence(self, task_mcp_client: MCPClient) -> None: - """Test calling multiple tools in sequence with different task modes.""" + def test_server_capabilities(self, task_mcp_client: MCPClient) -> None: + """Test server task capability detection.""" with task_mcp_client: task_mcp_client.list_tools_sync() - - r1 = task_mcp_client.call_tool_sync( - tool_use_id="s1", name="task_forbidden_echo", arguments={"message": "1"} - ) - assert r1["status"] == "success" and "Forbidden echo: 1" in r1["content"][0].get("text", "") - - r2 = task_mcp_client.call_tool_sync(tool_use_id="s2", name="echo", arguments={"message": "2"}) - assert r2["status"] == "success" and "Simple echo: 2" in r2["content"][0].get("text", "") - - r3 = task_mcp_client.call_tool_sync(tool_use_id="s3", 
name="task_optional_echo", arguments={"message": "3"}) - assert r3["status"] == "success" and "Task optional echo: 3" in r3["content"][0].get("text", "") - - r4 = task_mcp_client.call_tool_sync(tool_use_id="s4", name="task_required_echo", arguments={"message": "4"}) - assert r4["status"] == "success" and "Task echo: 4" in r4["content"][0].get("text", "") - - @pytest.mark.asyncio - async def test_async_tool_calls(self, task_mcp_client: MCPClient) -> None: - """Test async tool calls work correctly.""" - with task_mcp_client: - task_mcp_client.list_tools_sync() - result = await task_mcp_client.call_tool_async( - tool_use_id="test-async", name="task_forbidden_echo", arguments={"message": "Async hello!"} - ) - assert result["status"] == "success" - assert "Forbidden echo: Async hello!" in result["content"][0].get("text", "") + assert task_mcp_client._has_server_task_support() is True def test_tasks_disabled_by_default(self, task_mcp_client_disabled: MCPClient) -> None: """Test that tasks are disabled when experimental.tasks is not configured.""" with task_mcp_client_disabled: task_mcp_client_disabled.list_tools_sync() - # Even though server supports tasks and tool has taskSupport='required', - # tasks should NOT be used because experimental.tasks is not configured assert task_mcp_client_disabled._is_tasks_enabled() is False assert task_mcp_client_disabled._should_use_task("task_required_echo") is False - assert task_mcp_client_disabled._should_use_task("task_optional_echo") is False - # Tool calls should still work via direct call_tool + # Tool calls still work via direct call_tool result = task_mcp_client_disabled.call_tool_sync( - tool_use_id="test-disabled", name="task_required_echo", arguments={"message": "Direct call!"} + tool_use_id="t", name="task_required_echo", arguments={"message": "Direct!"} + ) + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_async_tool_call(self, task_mcp_client: MCPClient) -> None: + """Test async tool 
calls.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + result = await task_mcp_client.call_tool_async( + tool_use_id="t", name="task_forbidden_echo", arguments={"message": "Async!"} ) assert result["status"] == "success" - assert "Task echo: Direct call!" in result["content"][0].get("text", "") + assert "Forbidden echo: Async!" in result["content"][0].get("text", "") From 3ed9bbbd73757d68f2886127091d77bde575086e Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 3 Feb 2026 11:24:14 -0800 Subject: [PATCH 05/10] chore: reuse Task types from MCP SDK --- src/strands/tools/mcp/mcp_client.py | 59 +++++++++---------- .../tools/mcp/test_mcp_client_tasks.py | 1 - 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 0cfaae790..7d152ea78 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -195,7 +195,12 @@ def __init__( # Task support configuration and caching self._experimental = experimental or {} self._server_task_capable: bool | None = None - self._tool_task_support_cache: dict[str, str | None] = {} + + # Conditionally set up the task support cache (old SDK versions don't expose TaskExecutionMode) + if self._is_tasks_enabled(): + from mcp.types import TaskExecutionMode + + self._tool_task_support_cache: dict[str, TaskExecutionMode] = {} def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. 
@@ -445,11 +450,12 @@ async def _list_tools_async() -> ListToolsResult: mcp_tools = [] for tool in list_tools_response.tools: - # Cache taskSupport for task-augmented execution decisions - task_support = None - if tool.execution is not None and tool.execution.taskSupport is not None: - task_support = tool.execution.taskSupport - self._tool_task_support_cache[tool.name] = task_support + if self._is_tasks_enabled(): + # Cache taskSupport for task-augmented execution decisions + task_support = None + if tool.execution is not None and tool.execution.taskSupport is not None: + task_support = tool.execution.taskSupport + self._tool_task_support_cache[tool.name] = task_support or "forbidden" # Apply prefix if specified if effective_prefix: @@ -1038,20 +1044,6 @@ def _has_server_task_support(self) -> bool: """ return self._server_task_capable or False - def _get_tool_task_support(self, tool_name: str) -> str | None: - """Get the taskSupport setting for a tool. - - Returns the cached taskSupport value for the given tool name. - The cache is populated during list_tools_sync(). - - Args: - tool_name: Name of the tool to look up. - - Returns: - The taskSupport value ('required', 'optional', 'forbidden') or None if not cached. - """ - return self._tool_task_support_cache.get(tool_name) - def _should_use_task(self, tool_name: str) -> bool: """Determine if task-augmented execution should be used for a tool. 
@@ -1070,14 +1062,18 @@ def _should_use_task(self, tool_name: str) -> bool: if not self._is_tasks_enabled(): return False + # Local import to avoid errors on old SDK versions that don't support Tasks + from mcp.types import TASK_OPTIONAL, TASK_REQUIRED + # Server capability check (per MCP spec) if not self._has_server_task_support(): return False - task_support = self._get_tool_task_support(tool_name) + # Tool-level capability check (cached during list_tools_sync) + task_support = self._tool_task_support_cache.get(tool_name) - # Use tasks for 'required' or 'optional' when server supports - if task_support == "required" or task_support == "optional": + # Use tasks for TASK_REQUIRED or TASK_OPTIONAL when server supports + if task_support == TASK_REQUIRED or task_support == TASK_OPTIONAL: return True # Default: 'forbidden', None, or unknown -> don't use tasks @@ -1149,6 +1145,9 @@ async def _call_tool_as_task_and_poll_async( Returns: MCPCallToolResult: The final tool result after task completion. 
""" + # Local import to avoid errors on old SDK versions that don't support Tasks + from mcp.types import TASK_STATUS_CANCELLED, TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, GetTaskResult + session = cast(ClientSession, self._background_thread_session) ttl = ttl_ms or self._get_task_ttl_ms() timeout = poll_timeout_seconds or self._get_task_poll_timeout_seconds() @@ -1165,17 +1164,17 @@ async def _call_tool_as_task_and_poll_async( # Step 2: Poll until terminal status (with timeout protection) # Note: Using asyncio.wait_for() instead of asyncio.timeout() for Python 3.10 compatibility - async def _poll_until_terminal() -> Any: + async def _poll_until_terminal() -> GetTaskResult | None: """Inner function to poll task status until terminal state.""" final = None - async for status in session.experimental.poll_task(task_id): + async for task in session.experimental.poll_task(task_id): self._log_debug_with_thread( "tool=<%s>, task_id=<%s>, status=<%s> | task status update", name, task_id, - status.status, + task.status, ) - final = status + final = task return final try: @@ -1191,17 +1190,17 @@ async def _poll_until_terminal() -> Any: self._log_debug_with_thread("tool=<%s>, task_id=<%s> | polling completed without status", name, task_id) return self._create_task_error_result(f"Task {task_id} polling completed without status") - if final_status.status == "failed": + if final_status.status == TASK_STATUS_FAILED: error_msg = final_status.statusMessage or "Task failed" self._log_debug_with_thread("tool=<%s>, task_id=<%s>, error=<%s> | task failed", name, task_id, error_msg) return self._create_task_error_result(error_msg) - if final_status.status == "cancelled": + if final_status.status == TASK_STATUS_CANCELLED: self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task was cancelled", name, task_id) return self._create_task_error_result("Task was cancelled") # Step 4: Get the actual result for completed tasks (with error handling for race conditions) - if final_status.status 
== "completed": + if final_status.status == TASK_STATUS_COMPLETED: self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task completed, fetching result", name, task_id) try: result = await session.experimental.get_task_result(task_id, MCPCallToolResult) diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 629163695..75201b5d6 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -38,7 +38,6 @@ def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): """Test that _should_use_task returns False without opt-in even with server/tool support.""" with MCPClient(mock_transport["transport_callable"]) as client: client._server_task_capable = True - client._tool_task_support_cache["test_tool"] = "required" assert client._should_use_task("test_tool") is False with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: From af20bf0110d65d279daec904e34cef6fdb8fc9b7 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 3 Feb 2026 11:43:10 -0800 Subject: [PATCH 06/10] chore: inline trivial task methods --- src/strands/tools/mcp/mcp_client.py | 59 +++++++------------ .../tools/mcp/test_mcp_client_tasks.py | 12 ++-- 2 files changed, 28 insertions(+), 43 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 7d152ea78..9ad987d7c 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -128,6 +128,10 @@ class ExperimentalConfig(TypedDict, total=False): "unknown request id", ] +DEFAULT_TASK_TTL_MS = 60000 +DEFAULT_TASK_POLL_TIMEOUT_SECONDS = 300.0 +DEFAULT_TASK_CONFIG = TasksConfig(ttl_ms=DEFAULT_TASK_TTL_MS, poll_timeout_seconds=DEFAULT_TASK_POLL_TIMEOUT_SECONDS) + class MCPClient(ToolProvider): """Represents a connection to a Model Context Protocol (MCP) server. 
@@ -619,7 +623,10 @@ def _create_call_tool_coroutine( if use_task: self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) - poll_timeout = self._convert_timeout_for_polling(read_timeout_seconds) + + # When task-augmented execution is used, the read_timeout_seconds parameter + # (which is a timedelta) needs to be converted to a float for the polling timeout. + poll_timeout = read_timeout_seconds.total_seconds() if read_timeout_seconds else None async def _call_as_task() -> MCPCallToolResult: return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout_seconds=poll_timeout) @@ -1010,27 +1017,13 @@ def _is_tasks_enabled(self) -> bool: """ return self._experimental.get("tasks") is not None - def _get_task_ttl_ms(self) -> int: - """Get task TTL in milliseconds. - - Returns: - Task TTL from config, or default of 60000 (1 minute). - """ - tasks_config = self._experimental.get("tasks") - if tasks_config is None: - return 60000 - return tasks_config.get("ttl_ms", 60000) - - def _get_task_poll_timeout_seconds(self) -> float: - """Get task polling timeout in seconds. - - Returns: - Polling timeout from config, or default of 300.0 (5 minutes). - """ - tasks_config = self._experimental.get("tasks") - if tasks_config is None: - return 300.0 - return tasks_config.get("poll_timeout_seconds", 300.0) + def _get_task_config(self) -> TasksConfig: + """Returns the task execution configuration, configured with defaults if not specified.""" + task_config = self._experimental.get("tasks") or DEFAULT_TASK_CONFIG + return TasksConfig( + ttl_ms=task_config.get("ttl_ms", DEFAULT_TASK_TTL_MS), + poll_timeout_seconds=task_config.get("poll_timeout_seconds", DEFAULT_TASK_POLL_TIMEOUT_SECONDS), + ) def _has_server_task_support(self) -> bool: """Check if the MCP server supports task-augmented tool calls. 
@@ -1079,20 +1072,6 @@ def _should_use_task(self, tool_name: str) -> bool: # Default: 'forbidden', None, or unknown -> don't use tasks return False - def _convert_timeout_for_polling(self, read_timeout_seconds: timedelta | None) -> float | None: - """Convert a timedelta timeout to seconds for task polling. - - When task-augmented execution is used, the read_timeout_seconds parameter - (which is a timedelta) needs to be converted to a float for the polling timeout. - - Args: - read_timeout_seconds: Optional timedelta timeout from the call_tool API. - - Returns: - Float seconds if timeout was specified, None to use default. - """ - return read_timeout_seconds.total_seconds() if read_timeout_seconds else None - def _create_task_error_result(self, message: str) -> MCPCallToolResult: """Create an error MCPCallToolResult with consistent formatting. @@ -1149,8 +1128,12 @@ async def _call_tool_as_task_and_poll_async( from mcp.types import TASK_STATUS_CANCELLED, TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, GetTaskResult session = cast(ClientSession, self._background_thread_session) - ttl = ttl_ms or self._get_task_ttl_ms() - timeout = poll_timeout_seconds or self._get_task_poll_timeout_seconds() + + # Precedence: arg > config > default + timeout = poll_timeout_seconds or self._get_task_config().get( + "poll_timeout_seconds", DEFAULT_TASK_POLL_TIMEOUT_SECONDS + ) + ttl = ttl_ms or self._get_task_config().get("ttl_ms", DEFAULT_TASK_TTL_MS) # Step 1: Create the task self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl) diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 75201b5d6..b633089b1 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -12,6 +12,7 @@ from mcp.types import ToolExecution from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_client import DEFAULT_TASK_POLL_TIMEOUT_SECONDS, 
DEFAULT_TASK_TTL_MS from .conftest import create_server_capabilities @@ -52,17 +53,18 @@ class TestTaskConfiguration: @pytest.mark.parametrize( "config,expected_ttl,expected_timeout", [ - ({}, 60000, 300.0), - ({"ttl_ms": 120000}, 120000, 300.0), - ({"poll_timeout_seconds": 60.0}, 60000, 60.0), + ({}, DEFAULT_TASK_TTL_MS, DEFAULT_TASK_POLL_TIMEOUT_SECONDS), + ({"ttl_ms": 120000}, 120000, DEFAULT_TASK_POLL_TIMEOUT_SECONDS), + ({"poll_timeout_seconds": 60.0}, DEFAULT_TASK_TTL_MS, 60.0), ({"ttl_ms": 120000, "poll_timeout_seconds": 60.0}, 120000, 60.0), ], ) def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): """Test task configuration values with various configs.""" with MCPClient(mock_transport["transport_callable"], experimental={"tasks": config}) as client: - assert client._get_task_ttl_ms() == expected_ttl - assert client._get_task_poll_timeout_seconds() == expected_timeout + config_actual = client._get_task_config() + assert config_actual.get("ttl_ms") == expected_ttl + assert config_actual.get("poll_timeout_seconds") == expected_timeout def test_stop_resets_task_caches(self, mock_transport, mock_session): """Test that stop() resets the task support caches.""" From 814218e55be01ec3be4cb9f151512bef6e40b755 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 3 Feb 2026 11:53:08 -0800 Subject: [PATCH 07/10] chore: use timedelta for task-related durations --- src/strands/tools/mcp/mcp_client.py | 50 +++++++++---------- .../tools/mcp/test_mcp_client_tasks.py | 18 +++---- 2 files changed, 32 insertions(+), 36 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 9ad987d7c..29115733f 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -81,13 +81,12 @@ class TasksConfig(TypedDict, total=False): create task -> poll for completion -> get result. Attributes: - ttl_ms: Task time-to-live in milliseconds. 
Defaults to 60000 (1 minute). - poll_timeout_seconds: Timeout for polling task completion in seconds. - Defaults to 300.0 (5 minutes). + ttl: Task time-to-live. Defaults to 1 minute. + poll_timeout: Timeout for polling task completion. Defaults to 5 minutes. """ - ttl_ms: int - poll_timeout_seconds: float + ttl: timedelta + poll_timeout: timedelta class ExperimentalConfig(TypedDict, total=False): @@ -103,7 +102,7 @@ class ExperimentalConfig(TypedDict, total=False): that support it. """ - tasks: TasksConfig | None + tasks: TasksConfig MIME_TO_FORMAT: dict[str, ImageFormat] = { @@ -128,9 +127,9 @@ class ExperimentalConfig(TypedDict, total=False): "unknown request id", ] -DEFAULT_TASK_TTL_MS = 60000 -DEFAULT_TASK_POLL_TIMEOUT_SECONDS = 300.0 -DEFAULT_TASK_CONFIG = TasksConfig(ttl_ms=DEFAULT_TASK_TTL_MS, poll_timeout_seconds=DEFAULT_TASK_POLL_TIMEOUT_SECONDS) +DEFAULT_TASK_TTL = timedelta(minutes=1) +DEFAULT_TASK_POLL_TIMEOUT = timedelta(minutes=5) +DEFAULT_TASK_CONFIG = TasksConfig(ttl=DEFAULT_TASK_TTL, poll_timeout=DEFAULT_TASK_POLL_TIMEOUT) class MCPClient(ToolProvider): @@ -624,12 +623,10 @@ def _create_call_tool_coroutine( if use_task: self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) - # When task-augmented execution is used, the read_timeout_seconds parameter - # (which is a timedelta) needs to be converted to a float for the polling timeout. - poll_timeout = read_timeout_seconds.total_seconds() if read_timeout_seconds else None - async def _call_as_task() -> MCPCallToolResult: - return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout_seconds=poll_timeout) + # When task-augmented execution is used, use the read_timeout_seconds parameter + # (which is a timedelta) for the polling timeout. 
+ return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout=read_timeout_seconds) return _call_as_task() else: @@ -1021,8 +1018,8 @@ def _get_task_config(self) -> TasksConfig: """Returns the task execution configuration, configured with defaults if not specified.""" task_config = self._experimental.get("tasks") or DEFAULT_TASK_CONFIG return TasksConfig( - ttl_ms=task_config.get("ttl_ms", DEFAULT_TASK_TTL_MS), - poll_timeout_seconds=task_config.get("poll_timeout_seconds", DEFAULT_TASK_POLL_TIMEOUT_SECONDS), + ttl=task_config.get("ttl", DEFAULT_TASK_TTL), + poll_timeout=task_config.get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT), ) def _has_server_task_support(self) -> bool: @@ -1105,8 +1102,8 @@ async def _call_tool_as_task_and_poll_async( self, name: str, arguments: dict[str, Any] | None = None, - ttl_ms: int | None = None, - poll_timeout_seconds: float | None = None, + ttl: timedelta | None = None, + poll_timeout: timedelta | None = None, ) -> MCPCallToolResult: """Call a tool using task-augmented execution and poll until completion. @@ -1118,8 +1115,8 @@ async def _call_tool_as_task_and_poll_async( Args: name: Name of the tool to call. arguments: Optional arguments to pass to the tool. - ttl_ms: Task time-to-live in milliseconds. Uses configured value if not specified. - poll_timeout_seconds: Timeout for polling in seconds. Uses configured value if not specified. + ttl: Task time-to-live. Uses configured value if not specified. + poll_timeout: Timeout for polling. Uses configured value if not specified. Returns: MCPCallToolResult: The final tool result after task completion. 
@@ -1130,17 +1127,16 @@ async def _call_tool_as_task_and_poll_async( session = cast(ClientSession, self._background_thread_session) # Precedence: arg > config > default - timeout = poll_timeout_seconds or self._get_task_config().get( - "poll_timeout_seconds", DEFAULT_TASK_POLL_TIMEOUT_SECONDS - ) - ttl = ttl_ms or self._get_task_config().get("ttl_ms", DEFAULT_TASK_TTL_MS) + timeout = poll_timeout or self._get_task_config().get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT) + ttl = ttl or self._get_task_config().get("ttl", DEFAULT_TASK_TTL) + ttl_ms = int(ttl.total_seconds() * 1000) # Step 1: Create the task - self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl) + self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl_ms) create_result = await session.experimental.call_tool_as_task( name=name, arguments=arguments, - ttl=ttl, + ttl=ttl_ms, ) task_id = create_result.task.taskId self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) @@ -1161,7 +1157,7 @@ async def _poll_until_terminal() -> GetTaskResult | None: return final try: - final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout) + final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout.total_seconds()) except asyncio.TimeoutError: self._log_debug_with_thread( "tool=<%s>, task_id=<%s>, timeout=<%s> | task polling timed out", name, task_id, timeout diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index b633089b1..0598a4bcd 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -12,7 +12,7 @@ from mcp.types import ToolExecution from strands.tools.mcp import MCPClient -from strands.tools.mcp.mcp_client import DEFAULT_TASK_POLL_TIMEOUT_SECONDS, DEFAULT_TASK_TTL_MS +from strands.tools.mcp.mcp_client import DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL 
from .conftest import create_server_capabilities @@ -53,18 +53,18 @@ class TestTaskConfiguration: @pytest.mark.parametrize( "config,expected_ttl,expected_timeout", [ - ({}, DEFAULT_TASK_TTL_MS, DEFAULT_TASK_POLL_TIMEOUT_SECONDS), - ({"ttl_ms": 120000}, 120000, DEFAULT_TASK_POLL_TIMEOUT_SECONDS), - ({"poll_timeout_seconds": 60.0}, DEFAULT_TASK_TTL_MS, 60.0), - ({"ttl_ms": 120000, "poll_timeout_seconds": 60.0}, 120000, 60.0), + ({}, DEFAULT_TASK_TTL, DEFAULT_TASK_POLL_TIMEOUT), + ({"ttl": 120000}, 120000, DEFAULT_TASK_POLL_TIMEOUT), + ({"poll_timeout": 60.0}, DEFAULT_TASK_TTL, 60.0), + ({"ttl": 120000, "poll_timeout": 60.0}, 120000, 60.0), ], ) def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): """Test task configuration values with various configs.""" with MCPClient(mock_transport["transport_callable"], experimental={"tasks": config}) as client: config_actual = client._get_task_config() - assert config_actual.get("ttl_ms") == expected_ttl - assert config_actual.get("poll_timeout_seconds") == expected_timeout + assert config_actual.get("ttl") == expected_ttl + assert config_actual.get("poll_timeout") == expected_timeout def test_stop_resets_task_caches(self, mock_transport, mock_session): """Test that stop() resets the task support caches.""" @@ -132,7 +132,7 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll with MCPClient( - mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 0.1}} + mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout": timedelta(seconds=0.1)}} ) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={}) @@ -152,7 +152,7 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll with MCPClient( - mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 300.0}} + 
mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout": timedelta(minutes=5)}} ) as client: client.list_tools_sync() result = await client.call_tool_async( From 4ff7d1cfab3b8589b5dc3f4d0f5e9a51b6ce8689 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 4 Feb 2026 13:40:12 -0800 Subject: [PATCH 08/10] refactor(mcp): simplify tasks API and move TasksConfig to dedicated module - Create mcp_tasks.py module with TasksConfig and task-related constants - Remove ExperimentalConfig wrapper class from mcp_client.py - Change MCPClient parameter from experimental to tasks (direct TasksConfig) - Export TasksConfig at module level in __init__.py - Update all unit and integration tests to use simplified API --- src/strands/tools/mcp/__init__.py | 3 +- src/strands/tools/mcp/mcp_client.py | 60 ++++--------------- src/strands/tools/mcp/mcp_tasks.py | 32 ++++++++++ .../tools/mcp/test_mcp_client_tasks.py | 47 ++++++++------- tests_integ/mcp/test_mcp_client_tasks.py | 15 ++--- 5 files changed, 74 insertions(+), 83 deletions(-) create mode 100644 src/strands/tools/mcp/mcp_tasks.py diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index cfa841c46..8d2c1daa2 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -8,6 +8,7 @@ from .mcp_agent_tool import MCPAgentTool from .mcp_client import MCPClient, ToolFilters +from .mcp_tasks import TasksConfig from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 29115733f..c1c22f787 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -47,6 +47,7 @@ from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation 
import mcp_instrumentation +from .mcp_tasks import DEFAULT_TASK_CONFIG, DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL, TasksConfig from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -73,38 +74,6 @@ class ToolFilters(TypedDict, total=False): rejected: list[_ToolMatcher] -class TasksConfig(TypedDict, total=False): - """Configuration for MCP Tasks (task-augmented tool execution). - - If this config is provided (not None), task-augmented execution is enabled. - When enabled, long-running tool calls use the MCP task workflow: - create task -> poll for completion -> get result. - - Attributes: - ttl: Task time-to-live. Defaults to 1 minute. - poll_timeout: Timeout for polling task completion. Defaults to 5 minutes. - """ - - ttl: timedelta - poll_timeout: timedelta - - -class ExperimentalConfig(TypedDict, total=False): - """Configuration for experimental MCPClient features. - - Warning: - Features under this configuration are experimental and subject to change - in future revisions without notice. - - Attributes: - tasks: Configuration for MCP Tasks (task-augmented tool execution). - If provided (not None), enables task-augmented execution for tools - that support it. - """ - - tasks: TasksConfig - - MIME_TO_FORMAT: dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -127,10 +96,6 @@ class ExperimentalConfig(TypedDict, total=False): "unknown request id", ] -DEFAULT_TASK_TTL = timedelta(minutes=1) -DEFAULT_TASK_POLL_TIMEOUT = timedelta(minutes=5) -DEFAULT_TASK_CONFIG = TasksConfig(ttl=DEFAULT_TASK_TTL, poll_timeout=DEFAULT_TASK_POLL_TIMEOUT) - class MCPClient(ToolProvider): """Represents a connection to a Model Context Protocol (MCP) server. 
@@ -156,7 +121,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, - experimental: ExperimentalConfig | None = None, + tasks: TasksConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -167,10 +132,9 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. - experimental: Configuration for experimental features. Currently supports: - - tasks: Enable MCP task-augmented execution for long-running tools. - If provided (not None), enables task-augmented execution for tools - that support it. See ExperimentalConfig and TasksConfig for details. + tasks: Configuration for MCP task-augmented execution for long-running tools. + If provided (not None), enables task-augmented execution for tools that support it. + See TasksConfig for details. This feature is experimental and subject to change. """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters @@ -196,7 +160,7 @@ def __init__( self._consumers: set[Any] = set() # Task support configuration and caching - self._experimental = experimental or {} + self._tasks = tasks self._server_task_capable: bool | None = None # Conditionally set up the task support cache (old SDK versions don't expose TaskExecutionMode) @@ -1005,18 +969,18 @@ def _is_session_active(self) -> bool: return True def _is_tasks_enabled(self) -> bool: - """Check if experimental tasks feature is enabled. + """Check if tasks feature is enabled. - Tasks are enabled if experimental.tasks is defined and not None. + Tasks are enabled if tasks config is defined and not None. Returns: True if task-augmented execution is enabled, False otherwise. 
""" - return self._experimental.get("tasks") is not None + return self._tasks is not None def _get_task_config(self) -> TasksConfig: """Returns the task execution configuration, configured with defaults if not specified.""" - task_config = self._experimental.get("tasks") or DEFAULT_TASK_CONFIG + task_config = self._tasks or DEFAULT_TASK_CONFIG return TasksConfig( ttl=task_config.get("ttl", DEFAULT_TASK_TTL), poll_timeout=task_config.get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT), @@ -1038,7 +1002,7 @@ def _should_use_task(self, tool_name: str) -> bool: """Determine if task-augmented execution should be used for a tool. Task-augmented execution requires: - 1. experimental.tasks is enabled (opt-in check) + 1. tasks config is enabled (opt-in check) 2. Server supports tasks (capability check) 3. Tool taskSupport is 'required' or 'optional' @@ -1048,7 +1012,7 @@ def _should_use_task(self, tool_name: str) -> bool: Returns: True if task-augmented execution should be used, False otherwise. """ - # Opt-in check: tasks must be explicitly enabled via experimental.tasks + # Opt-in check: tasks must be explicitly enabled via tasks config if not self._is_tasks_enabled(): return False diff --git a/src/strands/tools/mcp/mcp_tasks.py b/src/strands/tools/mcp/mcp_tasks.py new file mode 100644 index 000000000..42aee387d --- /dev/null +++ b/src/strands/tools/mcp/mcp_tasks.py @@ -0,0 +1,32 @@ +"""Task-augmented tool execution configuration for MCP. + +This module provides configuration types and defaults for the experimental MCP Tasks feature. +""" + +from datetime import timedelta + +from typing_extensions import TypedDict + + +class TasksConfig(TypedDict, total=False): + """Configuration for MCP Tasks (task-augmented tool execution). + + If this config is provided (not None), task-augmented execution is enabled. + When enabled, supported tool calls use the MCP task workflow: + create task -> poll for completion -> get result. 
+ + Warning: + This feature is experimental and subject to change in future revisions without notice. + + Attributes: + ttl: Task time-to-live. Defaults to 1 minute. + poll_timeout: Timeout for polling task completion. Defaults to 5 minutes. + """ + + ttl: timedelta + poll_timeout: timedelta + + +DEFAULT_TASK_TTL = timedelta(minutes=1) +DEFAULT_TASK_POLL_TIMEOUT = timedelta(minutes=5) +DEFAULT_TASK_CONFIG = TasksConfig(ttl=DEFAULT_TASK_TTL, poll_timeout=DEFAULT_TASK_POLL_TIMEOUT) diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 0598a4bcd..4072daeaa 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -11,28 +11,25 @@ from mcp.types import Tool as MCPTool from mcp.types import ToolExecution -from strands.tools.mcp import MCPClient -from strands.tools.mcp.mcp_client import DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL +from strands.tools.mcp import MCPClient, TasksConfig +from strands.tools.mcp.mcp_tasks import DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL from .conftest import create_server_capabilities class TestTasksOptIn: - """Tests for task opt-in behavior via experimental.tasks.""" + """Tests for task opt-in behavior via tasks config.""" @pytest.mark.parametrize( - "experimental,expected_enabled", + "tasks,expected_enabled", [ (None, False), - ({}, False), - ({"tasks": None}, False), - ({"tasks": {}}, True), - ({"tasks": {"ttl_ms": 1000}}, True), + ({}, True), ], ) - def test_tasks_enabled_state(self, mock_transport, mock_session, experimental, expected_enabled): - """Test _is_tasks_enabled based on experimental config.""" - with MCPClient(mock_transport["transport_callable"], experimental=experimental) as client: + def test_tasks_enabled_state(self, mock_transport, mock_session, tasks, expected_enabled): + """Test _is_tasks_enabled based on tasks config.""" + with MCPClient(mock_transport["transport_callable"], tasks=tasks) as 
client: assert client._is_tasks_enabled() is expected_enabled def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): @@ -41,7 +38,7 @@ def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): client._server_task_capable = True assert client._should_use_task("test_tool") is False - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + with MCPClient(mock_transport["transport_callable"], tasks={}) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" assert client._should_use_task("test_tool") is True @@ -54,21 +51,25 @@ class TestTaskConfiguration: "config,expected_ttl,expected_timeout", [ ({}, DEFAULT_TASK_TTL, DEFAULT_TASK_POLL_TIMEOUT), - ({"ttl": 120000}, 120000, DEFAULT_TASK_POLL_TIMEOUT), - ({"poll_timeout": 60.0}, DEFAULT_TASK_TTL, 60.0), - ({"ttl": 120000, "poll_timeout": 60.0}, 120000, 60.0), + ({"ttl": timedelta(seconds=120)}, timedelta(seconds=120), DEFAULT_TASK_POLL_TIMEOUT), + ({"poll_timeout": timedelta(seconds=60)}, DEFAULT_TASK_TTL, timedelta(seconds=60)), + ( + {"ttl": timedelta(seconds=120), "poll_timeout": timedelta(seconds=60)}, + timedelta(seconds=120), + timedelta(seconds=60), + ), ], ) def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): """Test task configuration values with various configs.""" - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": config}) as client: + with MCPClient(mock_transport["transport_callable"], tasks=config) as client: config_actual = client._get_task_config() assert config_actual.get("ttl") == expected_ttl assert config_actual.get("poll_timeout") == expected_timeout def test_stop_resets_task_caches(self, mock_transport, mock_session): """Test that stop() resets the task support caches.""" - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + with 
MCPClient(mock_transport["transport_callable"], tasks={}) as client: client._server_task_capable = True client._tool_task_support_cache["tool1"] = "required" assert client._server_task_capable is None @@ -112,7 +113,7 @@ async def mock_poll_task(task_id): mock_session.experimental.poll_task = mock_poll_task - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) @@ -132,7 +133,7 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll with MCPClient( - mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout": timedelta(seconds=0.1)}} + mock_transport["transport_callable"], tasks=TasksConfig(poll_timeout=timedelta(seconds=0.1)) ) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={}) @@ -152,7 +153,7 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll with MCPClient( - mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout": timedelta(minutes=5)}} + mock_transport["transport_callable"], tasks=TasksConfig(poll_timeout=timedelta(minutes=5)) ) as client: client.list_tools_sync() result = await client.call_tool_async( @@ -172,7 +173,7 @@ async def successful_poll(task_id): mock_session.experimental.poll_task = successful_poll mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="failing_tool", 
arguments={}) assert result["status"] == "error" @@ -189,7 +190,7 @@ async def empty_poll(task_id): mock_session.experimental.poll_task = empty_poll - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) assert result["status"] == "error" @@ -208,7 +209,7 @@ async def poll(task_id): return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) ) - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) assert result["status"] == "success" diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py index 5e398f6de..1b453d6e4 100644 --- a/tests_integ/mcp/test_mcp_client_tasks.py +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -9,8 +9,7 @@ import pytest from mcp.client.streamable_http import streamablehttp_client -from strands.tools.mcp.mcp_client import MCPClient -from strands.tools.mcp.mcp_types import MCPTransport +from strands.tools.mcp import MCPClient, MCPTransport, TasksConfig def _find_available_port() -> int: @@ -52,7 +51,7 @@ def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: def transport_callback() -> MCPTransport: return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") - return MCPClient(transport_callback, experimental={"tasks": {}}) + return MCPClient(transport_callback, tasks=TasksConfig()) @pytest.fixture @@ -105,17 +104,11 @@ def test_task_augmented_tools(self, task_mcp_client: MCPClient) -> None: assert r2["status"] == "success" assert "Task optional echo: 
Optional!" in r2["content"][0].get("text", "") - def test_task_support_caching_and_decision(self, task_mcp_client: MCPClient) -> None: - """Test taskSupport caching and _should_use_task decision logic.""" + def test_task_support_tool_detection(self, task_mcp_client: MCPClient) -> None: + """Test tool-level task support detection.""" with task_mcp_client: task_mcp_client.list_tools_sync() - # Verify cached values - assert task_mcp_client._get_tool_task_support("task_required_echo") == "required" - assert task_mcp_client._get_tool_task_support("task_optional_echo") == "optional" - assert task_mcp_client._get_tool_task_support("task_forbidden_echo") == "forbidden" - assert task_mcp_client._get_tool_task_support("echo") is None - # Verify decision logic assert task_mcp_client._should_use_task("task_required_echo") is True assert task_mcp_client._should_use_task("task_optional_echo") is True From 7e1180381c85b6c85539a831f8b8d1640b98b16d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 4 Feb 2026 13:45:46 -0800 Subject: [PATCH 09/10] chore: tweak TasksConfig docs --- src/strands/tools/mcp/mcp_tasks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/mcp/mcp_tasks.py b/src/strands/tools/mcp/mcp_tasks.py index 42aee387d..36537f7df 100644 --- a/src/strands/tools/mcp/mcp_tasks.py +++ b/src/strands/tools/mcp/mcp_tasks.py @@ -11,12 +11,13 @@ class TasksConfig(TypedDict, total=False): """Configuration for MCP Tasks (task-augmented tool execution). - If this config is provided (not None), task-augmented execution is enabled. When enabled, supported tool calls use the MCP task workflow: create task -> poll for completion -> get result. Warning: - This feature is experimental and subject to change in future revisions without notice. + This is an experimental feature in the 2025-11-25 MCP specification and + both the specification and the Strands Agents implementation of this + feature are subject to change. 
Attributes: ttl: Task time-to-live. Defaults to 1 minute. From 490f7086d93809d2e1e622d978d5351c919d3856 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 9 Feb 2026 14:49:48 -0800 Subject: [PATCH 10/10] refactor(mcp): rename tasks to tasks_config - Rename MCPClient parameter for clarity - Add MCP Tasks documentation to AGENTS.md - Fix timeout logging to use total_seconds() --- AGENTS.md | 55 +++++++++++++++++++ src/strands/tools/mcp/mcp_client.py | 19 ++++--- .../tools/mcp/test_mcp_client_tasks.py | 24 ++++---- tests_integ/mcp/test_mcp_client_tasks.py | 2 +- 4 files changed, 80 insertions(+), 20 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 8b4394cc5..aa9beed1a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -69,6 +69,7 @@ strands-agents/ │ │ │ ├── mcp_client.py # MCP client implementation │ │ │ ├── mcp_agent_tool.py # MCP tool wrapper │ │ │ ├── mcp_types.py # MCP type definitions +│ │ │ ├── mcp_tasks.py # Task-augmented execution config │ │ │ └── mcp_instrumentation.py # MCP telemetry │ │ └── structured_output/ # Structured output handling │ │ ├── structured_output_tool.py @@ -408,6 +409,60 @@ hatch test --all # Test all Python versions (3.10-3.13) - Use `pytest.mark.asyncio` for async tests - Keep tests focused and independent +## MCP Tasks (Experimental) + +The SDK supports MCP task-augmented execution for long-running tools. This feature is experimental and aligns with the MCP specification 2025-11-25. + +### Overview + +Task-augmented execution allows tools to run asynchronously with a workflow: +1. Create task via `call_tool_as_task` +2. Poll for completion via `poll_task` +3. 
Get result via `get_task_result` + +### Configuration + +Enable tasks by passing a `TasksConfig` to `MCPClient`: + +```python +from datetime import timedelta +from strands.tools.mcp import MCPClient, TasksConfig + +# Enable with defaults (ttl=1min, poll_timeout=5min) +client = MCPClient(transport, tasks_config={}) + +# Or configure explicitly +client = MCPClient( + transport, + tasks_config=TasksConfig( + ttl=timedelta(minutes=2), # Task time-to-live + poll_timeout=timedelta(minutes=10), # Polling timeout + ), +) +``` + +### Tool Support Levels + +MCP tools declare their task support via `execution.taskSupport`: +- `TASK_REQUIRED` (`required`): Tool must use task-augmented execution +- `TASK_OPTIONAL` (`optional`): Tool can use tasks if client opts in +- `TASK_FORBIDDEN` (`forbidden`): Tool does not support tasks (default) + +### Decision Logic + +Task-augmented execution is used when ALL conditions are met: +1. Client opts in via `tasks_config` (not None) +2. Server advertises task capability (`tasks.requests.tools.call`) +3. 
Tool's `taskSupport` is `required` or `optional` + +### Key Files + +- `src/strands/tools/mcp/mcp_tasks.py` - `TasksConfig` and defaults +- `src/strands/tools/mcp/mcp_client.py` - Task execution logic (`_call_tool_as_task_and_poll_async`) +- `tests/strands/tools/mcp/test_mcp_client_tasks.py` - Unit tests +- `tests_integ/mcp/test_mcp_client_tasks.py` - Integration tests +- `tests_integ/mcp/task_echo_server.py` - Test server with task support + ## Things to Do - Use explicit return types for all functions diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index c1c22f787..f64774477 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -121,7 +121,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, - tasks: TasksConfig | None = None, + tasks_config: TasksConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -132,7 +132,7 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. - tasks: Configuration for MCP task-augmented execution for long-running tools. + tasks_config: Configuration for MCP task-augmented execution for long-running tools. If provided (not None), enables task-augmented execution for tools that support it. See TasksConfig for details. This feature is experimental and subject to change. 
""" @@ -160,7 +160,7 @@ def __init__( self._consumers: set[Any] = set() # Task support configuration and caching - self._tasks = tasks + self._tasks_config = tasks_config self._server_task_capable: bool | None = None # Conditionally set up the task support cache (old SDK versions don't expose TaskExecutionMode) @@ -976,11 +976,11 @@ def _is_tasks_enabled(self) -> bool: Returns: True if task-augmented execution is enabled, False otherwise. """ - return self._tasks is not None + return self._tasks_config is not None def _get_task_config(self) -> TasksConfig: """Returns the task execution configuration, configured with defaults if not specified.""" - task_config = self._tasks or DEFAULT_TASK_CONFIG + task_config = self._tasks_config or DEFAULT_TASK_CONFIG return TasksConfig( ttl=task_config.get("ttl", DEFAULT_TASK_TTL), poll_timeout=task_config.get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT), @@ -1124,9 +1124,14 @@ async def _poll_until_terminal() -> GetTaskResult | None: final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout.total_seconds()) except asyncio.TimeoutError: self._log_debug_with_thread( - "tool=<%s>, task_id=<%s>, timeout=<%s> | task polling timed out", name, task_id, timeout + "tool=<%s>, task_id=<%s>, timeout_seconds=<%s> | task polling timed out", + name, + task_id, + timeout.total_seconds(), + ) + return self._create_task_error_result( + f"Task {task_id} polling timed out after {timeout.total_seconds()} seconds" ) - return self._create_task_error_result(f"Task {task_id} polling timed out after {timeout} seconds") # Step 3: Handle terminal status if final_status is None: diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 4072daeaa..01d3b2763 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -21,15 +21,15 @@ class TestTasksOptIn: """Tests for task opt-in behavior via tasks config.""" 
@pytest.mark.parametrize( - "tasks,expected_enabled", + "tasks_config,expected_enabled", [ (None, False), ({}, True), ], ) - def test_tasks_enabled_state(self, mock_transport, mock_session, tasks, expected_enabled): + def test_tasks_enabled_state(self, mock_transport, mock_session, tasks_config, expected_enabled): """Test _is_tasks_enabled based on tasks config.""" - with MCPClient(mock_transport["transport_callable"], tasks=tasks) as client: + with MCPClient(mock_transport["transport_callable"], tasks_config=tasks_config) as client: assert client._is_tasks_enabled() is expected_enabled def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): @@ -38,7 +38,7 @@ def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): client._server_task_capable = True assert client._should_use_task("test_tool") is False - with MCPClient(mock_transport["transport_callable"], tasks={}) as client: + with MCPClient(mock_transport["transport_callable"], tasks_config={}) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" assert client._should_use_task("test_tool") is True @@ -62,14 +62,14 @@ class TestTaskConfiguration: ) def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): """Test task configuration values with various configs.""" - with MCPClient(mock_transport["transport_callable"], tasks=config) as client: + with MCPClient(mock_transport["transport_callable"], tasks_config=config) as client: config_actual = client._get_task_config() assert config_actual.get("ttl") == expected_ttl assert config_actual.get("poll_timeout") == expected_timeout def test_stop_resets_task_caches(self, mock_transport, mock_session): """Test that stop() resets the task support caches.""" - with MCPClient(mock_transport["transport_callable"], tasks={}) as client: + with MCPClient(mock_transport["transport_callable"], tasks_config={}) as client: 
client._server_task_capable = True client._tool_task_support_cache["tool1"] = "required" assert client._server_task_capable is None @@ -113,7 +113,7 @@ async def mock_poll_task(task_id): mock_session.experimental.poll_task = mock_poll_task - with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client: + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) @@ -133,7 +133,7 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll with MCPClient( - mock_transport["transport_callable"], tasks=TasksConfig(poll_timeout=timedelta(seconds=0.1)) + mock_transport["transport_callable"], tasks_config=TasksConfig(poll_timeout=timedelta(seconds=0.1)) ) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={}) @@ -153,7 +153,7 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll with MCPClient( - mock_transport["transport_callable"], tasks=TasksConfig(poll_timeout=timedelta(minutes=5)) + mock_transport["transport_callable"], tasks_config=TasksConfig(poll_timeout=timedelta(minutes=5)) ) as client: client.list_tools_sync() result = await client.call_tool_async( @@ -173,7 +173,7 @@ async def successful_poll(task_id): mock_session.experimental.poll_task = successful_poll mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) - with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client: + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="failing_tool", arguments={}) assert result["status"] == "error" @@ -190,7 +190,7 @@ async 
def empty_poll(task_id): mock_session.experimental.poll_task = empty_poll - with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client: + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) assert result["status"] == "error" @@ -209,7 +209,7 @@ async def poll(task_id): return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) ) - with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client: + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) assert result["status"] == "success" diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py index 1b453d6e4..b2623c6a1 100644 --- a/tests_integ/mcp/test_mcp_client_tasks.py +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -51,7 +51,7 @@ def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: def transport_callback() -> MCPTransport: return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") - return MCPClient(transport_callback, tasks=TasksConfig()) + return MCPClient(transport_callback, tasks_config=TasksConfig()) @pytest.fixture