Skip to content
3 changes: 2 additions & 1 deletion src/strands/tools/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .mcp_agent_tool import MCPAgentTool
from .mcp_client import MCPClient, ToolFilters
from .mcp_tasks import TasksConfig
from .mcp_types import MCPTransport

__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"]
__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"]
303 changes: 285 additions & 18 deletions src/strands/tools/mcp/mcp_client.py

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions src/strands/tools/mcp/mcp_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Task-augmented tool execution configuration for MCP.

This module provides configuration types and defaults for the experimental MCP Tasks feature.
"""

from datetime import timedelta

from typing_extensions import TypedDict


class TasksConfig(TypedDict, total=False):
"""Configuration for MCP Tasks (task-augmented tool execution).

When enabled, supported tool calls use the MCP task workflow:
create task -> poll for completion -> get result.

Warning:
This is an experimental feature in the 2025-11-25 MCP specification and
both the specification and the Strands Agents implementation of this
feature are subject to change.

Attributes:
ttl: Task time-to-live. Defaults to 1 minute.
poll_timeout: Timeout for polling task completion. Defaults to 5 minutes.
"""

ttl: timedelta
poll_timeout: timedelta


DEFAULT_TASK_TTL = timedelta(minutes=1)
DEFAULT_TASK_POLL_TIMEOUT = timedelta(minutes=5)
DEFAULT_TASK_CONFIG = TasksConfig(ttl=DEFAULT_TASK_TTL, poll_timeout=DEFAULT_TASK_POLL_TIMEOUT)
59 changes: 59 additions & 0 deletions tests/strands/tools/mcp/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Shared fixtures and helpers for MCP client tests."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest


@pytest.fixture
def mock_transport():
"""Create a mock MCP transport."""
mock_read_stream = AsyncMock()
mock_write_stream = AsyncMock()
mock_transport_cm = AsyncMock()
mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream)
mock_transport_callable = MagicMock(return_value=mock_transport_cm)

return {
"read_stream": mock_read_stream,
"write_stream": mock_write_stream,
"transport_cm": mock_transport_cm,
"transport_callable": mock_transport_callable,
}


@pytest.fixture
def mock_session():
"""Create a mock MCP session."""
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
# Default: no task support (get_server_capabilities is sync, not async!)
mock_session.get_server_capabilities = MagicMock(return_value=None)

# Create a mock context manager for ClientSession
mock_session_cm = AsyncMock()
mock_session_cm.__aenter__.return_value = mock_session

# Patch ClientSession to return our mock session
with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm):
yield mock_session


def create_server_capabilities(has_task_support: bool) -> MagicMock:
"""Create mock server capabilities.

Args:
has_task_support: Whether the server should advertise task support.

Returns:
MagicMock representing server capabilities.
"""
caps = MagicMock()
if has_task_support:
caps.tasks = MagicMock()
caps.tasks.requests = MagicMock()
caps.tasks.requests.tools = MagicMock()
caps.tasks.requests.tools.call = MagicMock()
else:
caps.tasks = None
return caps
32 changes: 2 additions & 30 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
import time
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch

import pytest
from mcp import ListToolsResult
Expand All @@ -25,35 +25,7 @@
from strands.tools.mcp.mcp_types import MCPToolResult
from strands.types.exceptions import MCPClientInitializationError


@pytest.fixture
def mock_transport():
mock_read_stream = AsyncMock()
mock_write_stream = AsyncMock()
mock_transport_cm = AsyncMock()
mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream)
mock_transport_callable = MagicMock(return_value=mock_transport_cm)

return {
"read_stream": mock_read_stream,
"write_stream": mock_write_stream,
"transport_cm": mock_transport_cm,
"transport_callable": mock_transport_callable,
}


@pytest.fixture
def mock_session():
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()

# Create a mock context manager for ClientSession
mock_session_cm = AsyncMock()
mock_session_cm.__aenter__.return_value = mock_session

# Patch ClientSession to return our mock session
with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm):
yield mock_session
# Fixtures mock_transport and mock_session are imported from conftest.py


@pytest.fixture
Expand Down
2 changes: 2 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client_contextvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def mock_session():
"""Create mock MCP session."""
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
# get_server_capabilities is sync, not async
mock_session.get_server_capabilities = MagicMock(return_value=None)

mock_session_cm = AsyncMock()
mock_session_cm.__aenter__.return_value = mock_session
Expand Down
216 changes: 216 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
"""Tests for MCP task-augmented execution support in MCPClient."""

import asyncio
from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock

import pytest
from mcp import ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import TextContent as MCPTextContent
from mcp.types import Tool as MCPTool
from mcp.types import ToolExecution

from strands.tools.mcp import MCPClient, TasksConfig
from strands.tools.mcp.mcp_tasks import DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL

from .conftest import create_server_capabilities


class TestTasksOptIn:
"""Tests for task opt-in behavior via tasks config."""

@pytest.mark.parametrize(
"tasks,expected_enabled",
[
(None, False),
({}, True),
],
)
def test_tasks_enabled_state(self, mock_transport, mock_session, tasks, expected_enabled):
"""Test _is_tasks_enabled based on tasks config."""
with MCPClient(mock_transport["transport_callable"], tasks=tasks) as client:
assert client._is_tasks_enabled() is expected_enabled

def test_should_use_task_requires_opt_in(self, mock_transport, mock_session):
"""Test that _should_use_task returns False without opt-in even with server/tool support."""
with MCPClient(mock_transport["transport_callable"]) as client:
client._server_task_capable = True
assert client._should_use_task("test_tool") is False

with MCPClient(mock_transport["transport_callable"], tasks={}) as client:
client._server_task_capable = True
client._tool_task_support_cache["test_tool"] = "required"
assert client._should_use_task("test_tool") is True


class TestTaskConfiguration:
"""Tests for task-related configuration options."""

@pytest.mark.parametrize(
"config,expected_ttl,expected_timeout",
[
({}, DEFAULT_TASK_TTL, DEFAULT_TASK_POLL_TIMEOUT),
({"ttl": timedelta(seconds=120)}, timedelta(seconds=120), DEFAULT_TASK_POLL_TIMEOUT),
({"poll_timeout": timedelta(seconds=60)}, DEFAULT_TASK_TTL, timedelta(seconds=60)),
(
{"ttl": timedelta(seconds=120), "poll_timeout": timedelta(seconds=60)},
timedelta(seconds=120),
timedelta(seconds=60),
),
],
)
def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout):
"""Test task configuration values with various configs."""
with MCPClient(mock_transport["transport_callable"], tasks=config) as client:
config_actual = client._get_task_config()
assert config_actual.get("ttl") == expected_ttl
assert config_actual.get("poll_timeout") == expected_timeout

def test_stop_resets_task_caches(self, mock_transport, mock_session):
"""Test that stop() resets the task support caches."""
with MCPClient(mock_transport["transport_callable"], tasks={}) as client:
client._server_task_capable = True
client._tool_task_support_cache["tool1"] = "required"
assert client._server_task_capable is None
assert client._tool_task_support_cache == {}


class TestTaskExecution:
"""Tests for task execution and error handling."""

def _setup_task_tool(self, mock_session, tool_name: str) -> None:
"""Helper to set up a mock task-enabled tool."""
mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True))
mock_tool = MCPTool(
name=tool_name,
description="A test tool",
inputSchema={"type": "object"},
execution=ToolExecution(taskSupport="optional"),
)
mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None))
mock_create_result = MagicMock()
mock_create_result.task.taskId = "test-task-id"
mock_session.experimental = MagicMock()
mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result)

@pytest.mark.parametrize(
"status,status_message,expected_text",
[
("failed", "Something went wrong", "Something went wrong"),
("cancelled", None, "cancelled"),
("unknown_status", None, "unexpected task status"),
],
)
def test_terminal_status_handling(self, mock_transport, mock_session, status, status_message, expected_text):
"""Test handling of terminal task statuses."""
mock_create_result = MagicMock()
mock_create_result.task.taskId = f"task-{status}"
mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result)

async def mock_poll_task(task_id):
yield MagicMock(status=status, statusMessage=status_message)

mock_session.experimental.poll_task = mock_poll_task

with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client:
client._server_task_capable = True
client._tool_task_support_cache["test_tool"] = "required"
result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={})
assert result["status"] == "error"
assert expected_text.lower() in result["content"][0].get("text", "").lower()

@pytest.mark.asyncio
async def test_polling_timeout(self, mock_transport, mock_session):
"""Test that task polling times out properly."""
self._setup_task_tool(mock_session, "slow_tool")

async def infinite_poll(task_id):
while True:
await asyncio.sleep(1)
yield MagicMock(status="running")

mock_session.experimental.poll_task = infinite_poll

with MCPClient(
mock_transport["transport_callable"], tasks=TasksConfig(poll_timeout=timedelta(seconds=0.1))
) as client:
client.list_tools_sync()
result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={})
assert result["status"] == "error"
assert "timed out" in result["content"][0].get("text", "").lower()

@pytest.mark.asyncio
async def test_explicit_timeout_overrides_default(self, mock_transport, mock_session):
"""Test that read_timeout_seconds overrides the default poll timeout."""
self._setup_task_tool(mock_session, "timeout_tool")

async def infinite_poll(task_id):
while True:
await asyncio.sleep(1)
yield MagicMock(status="running")

mock_session.experimental.poll_task = infinite_poll

with MCPClient(
mock_transport["transport_callable"], tasks=TasksConfig(poll_timeout=timedelta(minutes=5))
) as client:
client.list_tools_sync()
result = await client.call_tool_async(
tool_use_id="t", name="timeout_tool", arguments={}, read_timeout_seconds=timedelta(seconds=0.1)
)
assert result["status"] == "error"
assert "timed out" in result["content"][0].get("text", "").lower()

@pytest.mark.asyncio
async def test_result_retrieval_failure(self, mock_transport, mock_session):
"""Test that get_task_result failures are handled gracefully."""
self._setup_task_tool(mock_session, "failing_tool")

async def successful_poll(task_id):
yield MagicMock(status="completed", statusMessage=None)

mock_session.experimental.poll_task = successful_poll
mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error"))

with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client:
client.list_tools_sync()
result = await client.call_tool_async(tool_use_id="t", name="failing_tool", arguments={})
assert result["status"] == "error"
assert "result retrieval failed" in result["content"][0].get("text", "").lower()

@pytest.mark.asyncio
async def test_empty_poll_result(self, mock_transport, mock_session):
"""Test handling when poll_task yields nothing."""
self._setup_task_tool(mock_session, "empty_poll_tool")

async def empty_poll(task_id):
return
yield # noqa: B901

mock_session.experimental.poll_task = empty_poll

with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client:
client.list_tools_sync()
result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={})
assert result["status"] == "error"
assert "without status" in result["content"][0].get("text", "").lower()

@pytest.mark.asyncio
async def test_successful_completion(self, mock_transport, mock_session):
"""Test successful task completion."""
self._setup_task_tool(mock_session, "success_tool")

async def poll(task_id):
yield MagicMock(status="completed", statusMessage=None)

mock_session.experimental.poll_task = poll
mock_session.experimental.get_task_result = AsyncMock(
return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False)
)

with MCPClient(mock_transport["transport_callable"], tasks=TasksConfig()) as client:
client.list_tools_sync()
result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={})
assert result["status"] == "success"
assert "Done" in result["content"][0].get("text", "")
Loading
Loading