diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py index 4415467cd7..8e7fdc4a1d 100644 --- a/tests/integrations/mcp/test_mcp.py +++ b/tests/integrations/mcp/test_mcp.py @@ -15,6 +15,7 @@ that the integration properly instruments MCP handlers with Sentry spans. """ +import anyio import pytest import json from unittest import mock @@ -30,6 +31,15 @@ async def __call__(self, *args, **kwargs): from mcp.server.lowlevel import Server from mcp.server.lowlevel.server import request_ctx +from mcp.types import ( + JSONRPCMessage, + JSONRPCRequest, + GetPromptResult, + PromptMessage, + TextContent, +) +from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.shared.message import SessionMessage try: from mcp.server.lowlevel.server import request_ctx @@ -41,6 +51,70 @@ async def __call__(self, *args, **kwargs): from sentry_sdk.integrations.mcp import MCPIntegration +def get_initialization_payload(request_id: str): + return SessionMessage( + message=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + ) + ) + ) + + +def get_mcp_command_payload(method: str, params, request_id: str): + return SessionMessage( + message=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method=method, + params=params, + ) + ) + ) + + +async def stdio(server, method: str, params, request_id: str | None = None): + if request_id is None: + request_id = "1" # arbitrary + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + result = {} + + async def run_server(): + await server.run( + read_stream, write_stream, server.create_initialization_options() + ) + + async def simulate_client(tg, result): + init_request = get_initialization_payload("1") + await read_stream_writer.send(init_request) + + await write_stream_reader.receive() + + request = get_mcp_command_payload(method, params=params, request_id=request_id) + await read_stream_writer.send(request) + + result["response"] = await write_stream_reader.receive() + + tg.cancel_scope.cancel() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(simulate_client, tg, result) + + return result["response"] + + @pytest.fixture(autouse=True) def reset_request_ctx(): """Reset request context before and after each test""" @@ -141,11 +215,12 @@ def test_integration_patches_server(sentry_init): assert Server.read_resource is not original_read_resource +@pytest.mark.asyncio @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], ) -def test_tool_handler_sync( +async def test_tool_handler_stdio( sentry_init, capture_events, send_default_pii, include_prompts ): """Test that synchronous tool handlers create proper spans""" @@ -158,19 +233,25 @@ def test_tool_handler_sync( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-123", transport="stdio") - request_ctx.set(mock_ctx) - @server.call_tool() - def test_tool(tool_name, arguments): + async def test_tool(tool_name, arguments): return {"result": "success", "value": 42} with start_transaction(name="mcp tx"): - # Call the tool handler - result = test_tool("calculate", {"x": 10, "y": 5}) + result = await stdio( + server, + method="tools/call", + params={ + "name": "calculate", + "arguments": {"x": 10, "y": 5}, + }, + request_id="req-123", + ) - assert result == {"result": "success", "value": 42} + assert result.message.root.result["content"][0]["text"] == json.dumps( + {"result": "success", "value": 42}, + indent=2, + ) (tx,) = events assert tx["type"] == "transaction" @@ -262,7 +343,8 @@ async def test_tool_async(tool_name, arguments): assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"] -def test_tool_handler_with_error(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_tool_handler_with_error(sentry_init, capture_events): """Test that tool handler errors are captured properly""" sentry_init( integrations=[MCPIntegration()], @@ -272,17 +354,23 @@ def test_tool_handler_with_error(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-error", transport="stdio") - request_ctx.set(mock_ctx) - @server.call_tool() - def failing_tool(tool_name, arguments): + async def failing_tool(tool_name, arguments): raise ValueError("Tool execution failed") with start_transaction(name="mcp tx"): - with pytest.raises(ValueError): - failing_tool("bad_tool", {}) + result = await stdio( + server, + method="tools/call", + params={ + "name": "bad_tool", + "arguments": {}, + }, + ) + + assert ( + result.message.root.result["content"][0]["text"] == "Tool execution failed" + ) # Should have error event and transaction assert len(events) == 2 @@ -304,11 +392,12 @@ def failing_tool(tool_name, arguments): assert span["tags"]["status"] == "internal_error" +@pytest.mark.asyncio @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], ) -def test_prompt_handler_sync( +async def test_prompt_handler_sync( sentry_init, capture_events, send_default_pii, include_prompts ): """Test that synchronous prompt handlers create proper spans""" @@ -321,19 +410,34 @@ def test_prompt_handler_sync( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-prompt", transport="stdio") - request_ctx.set(mock_ctx) - @server.get_prompt() - def test_prompt(name, arguments): - return MockGetPromptResult([MockPromptMessage("user", "Tell me about Python")]) + async def test_prompt(name, arguments): + return GetPromptResult( + description="A helpful test prompt", + messages=[ + PromptMessage( + role="user", + content=TextContent(type="text", text="Tell me about Python"), + ), + ], + ) with start_transaction(name="mcp tx"): - result = test_prompt("code_help", {"language": "python"}) + result = await stdio( + server, + method="prompts/get", + params={ + "name": "code_help", + "arguments": {"language": "python"}, + }, + request_id="req-prompt", + ) - assert result.messages[0].role == "user" - assert result.messages[0].content.text == "Tell me about Python" + assert result.message.root.result["messages"][0]["role"] == "user" + assert ( + result.message.root.result["messages"][0]["content"]["text"] + == "Tell me about Python" + ) (tx,) = events assert tx["type"] == "transaction" @@ -420,7 +524,8 @@ async def test_prompt_async(name, arguments): assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT not in span["data"] -def test_prompt_handler_with_error(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_prompt_handler_with_error(sentry_init, capture_events): """Test that prompt handler errors are captured""" sentry_init( integrations=[MCPIntegration()], @@ -430,17 +535,22 @@ def test_prompt_handler_with_error(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-error-prompt", transport="stdio") - request_ctx.set(mock_ctx) - @server.get_prompt() - def failing_prompt(name, arguments): + async def failing_prompt(name, arguments): raise RuntimeError("Prompt not found") with start_transaction(name="mcp tx"): - with pytest.raises(RuntimeError): - failing_prompt("missing_prompt", {}) + response = await stdio( + server, + method="prompts/get", + params={ + "name": "code_help", + "arguments": {"language": "python"}, + }, + request_id="req-prompt", + ) + + assert response.message.root.error.message == "Prompt not found" # Should have error event and transaction assert len(events) == 2 @@ -450,7 +560,8 @@ def failing_prompt(name, arguments): assert error_event["exception"]["values"][0]["type"] == "RuntimeError" -def test_resource_handler_sync(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_resource_handler_sync(sentry_init, capture_events): """Test that synchronous resource handlers create proper spans""" sentry_init( integrations=[MCPIntegration()], @@ -460,19 +571,27 @@ def test_resource_handler_sync(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-resource", transport="stdio") - request_ctx.set(mock_ctx) - @server.read_resource() - def test_resource(uri): - return {"content": "file contents", "mime_type": "text/plain"} + async def test_resource(uri): + return [ + ReadResourceContents( + content=json.dumps({"content": "file contents"}), mime_type="text/plain" + ) + ] with start_transaction(name="mcp tx"): - uri = MockURI("file:///path/to/file.txt") - result = test_resource(uri) + result = await stdio( + server, + method="resources/read", + params={ + "uri": "file:///path/to/file.txt", + }, + request_id="req-resource", + ) - assert result["content"] == "file contents" + assert result.message.root.result["contents"][0]["text"] == json.dumps( + {"content": "file contents"}, + ) (tx,) = events assert tx["type"] == "transaction" @@ -533,7 +652,8 @@ async def test_resource_async(uri): assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-res" -def test_resource_handler_with_error(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_resource_handler_with_error(sentry_init, capture_events): """Test that resource handler errors are captured""" sentry_init( integrations=[MCPIntegration()], @@ -543,18 +663,18 @@ def test_resource_handler_with_error(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-error-resource", transport="stdio") - request_ctx.set(mock_ctx) - @server.read_resource() - def failing_resource(uri): + async def failing_resource(uri): raise FileNotFoundError("Resource not found") with start_transaction(name="mcp tx"): - with pytest.raises(FileNotFoundError): - uri = MockURI("file:///missing.txt") - failing_resource(uri) + await stdio( + server, + method="resources/read", + params={ + "uri": "file:///missing.txt", + }, + ) # Should have error event and transaction assert len(events) == 2 @@ -564,11 +684,12 @@ def failing_resource(uri): assert error_event["exception"]["values"][0]["type"] == "FileNotFoundError" +@pytest.mark.asyncio @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (False, False)], ) -def test_tool_result_extraction_tuple( +async def test_tool_result_extraction_tuple( sentry_init, capture_events, send_default_pii, include_prompts ): """Test extraction of tool results from tuple format (UnstructuredContent, StructuredContent)""" @@ -581,19 +702,22 @@ def test_tool_result_extraction_tuple( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-tuple", transport="stdio") - request_ctx.set(mock_ctx) - @server.call_tool() - def test_tool_tuple(tool_name, arguments): + async def test_tool_tuple(tool_name, arguments): # Return CombinationContent: (UnstructuredContent, StructuredContent) unstructured = [MockTextContent("Result text")] structured = {"key": "value", "count": 5} return (unstructured, structured) with start_transaction(name="mcp tx"): - test_tool_tuple("combo_tool", {}) + await stdio( + server, + method="tools/call", + params={ + "name": "calculate", + "arguments": {"x": 10, "y": 5}, + }, + ) (tx,) = events span = tx["spans"][0] @@ -612,11 +736,12 @@ def test_tool_tuple(tool_name, arguments): assert SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT not in span["data"] +@pytest.mark.asyncio @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (False, False)], ) -def test_tool_result_extraction_unstructured( +async def test_tool_result_extraction_unstructured( sentry_init, capture_events, send_default_pii, include_prompts ): """Test extraction of tool results from UnstructuredContent (list of content blocks)""" @@ -629,12 +754,8 @@ def test_tool_result_extraction_unstructured( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-unstructured", transport="stdio") - request_ctx.set(mock_ctx) - @server.call_tool() - def test_tool_unstructured(tool_name, arguments): + async def test_tool_unstructured(tool_name, arguments): # Return UnstructuredContent as list of content blocks return [ MockTextContent("First part"), @@ -642,7 +763,14 @@ def test_tool_unstructured(tool_name, arguments): ] with start_transaction(name="mcp tx"): - test_tool_unstructured("text_tool", {}) + await stdio( + server, + method="tools/call", + params={ + "name": "calculate", + "arguments": {"x": 10, "y": 5}, + }, + ) (tx,) = events span = tx["spans"][0] @@ -693,7 +821,8 @@ def test_tool_no_ctx(tool_name, arguments): assert SPANDATA.MCP_SESSION_ID not in span["data"] -def test_span_origin(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_span_origin(sentry_init, capture_events): """Test that span origin is set correctly""" sentry_init( integrations=[MCPIntegration()], @@ -703,16 +832,19 @@ def test_span_origin(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-origin", transport="stdio") - request_ctx.set(mock_ctx) - @server.call_tool() - def test_tool(tool_name, arguments): + async def test_tool(tool_name, arguments): return {"result": "test"} with start_transaction(name="mcp tx"): - test_tool("origin_test", {}) + await stdio( + server, + method="tools/call", + params={ + "name": "calculate", + "arguments": {"x": 10, "y": 5}, + }, + ) (tx,) = events @@ -720,7 +852,8 @@ def test_tool(tool_name, arguments): assert tx["spans"][0]["origin"] == "auto.ai.mcp" -def test_multiple_handlers(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_multiple_handlers(sentry_init, capture_events): """Test that multiple handler calls create multiple spans""" sentry_init( integrations=[MCPIntegration()], @@ -730,26 +863,52 @@ def test_multiple_handlers(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-multi", transport="stdio") - request_ctx.set(mock_ctx) - @server.call_tool() - def tool1(tool_name, arguments): + async def tool1(tool_name, arguments): return {"result": "tool1"} @server.call_tool() - def tool2(tool_name, arguments): + async def tool2(tool_name, arguments): return {"result": "tool2"} @server.get_prompt() - def prompt1(name, arguments): - return MockGetPromptResult([MockPromptMessage("user", "Test prompt")]) + async def prompt1(name, arguments): + return GetPromptResult( + description="A test prompt", + messages=[ + PromptMessage( + role="user", content=TextContent(type="text", text="Test prompt") + ) + ], + ) with start_transaction(name="mcp tx"): - tool1("tool_a", {}) - tool2("tool_b", {}) - prompt1("prompt_a", {}) + await stdio( + server, + method="tools/call", + params={ + "name": "tool_a", + "arguments": {}, + }, + ) + + await stdio( + server, + method="tools/call", + params={ + "name": "tool_b", + "arguments": {}, + }, + ) + + await stdio( + server, + method="prompts/get", + params={ + "name": "prompt_a", + "arguments": {}, + }, + ) (tx,) = events assert tx["type"] == "transaction" @@ -765,11 +924,12 @@ def prompt1(name, arguments): assert "prompts/get prompt_a" in span_descriptions +@pytest.mark.asyncio @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (False, False)], ) -def test_prompt_with_dict_result( +async def test_prompt_with_dict_result( sentry_init, capture_events, send_default_pii, include_prompts ): """Test prompt handler with dict result instead of GetPromptResult object""" @@ -782,10 +942,6 @@ def test_prompt_with_dict_result( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-dict-prompt", transport="stdio") - request_ctx.set(mock_ctx) - @server.get_prompt() def test_prompt_dict(name, arguments): # Return dict format instead of GetPromptResult object @@ -796,7 +952,14 @@ def test_prompt_dict(name, arguments): } with start_transaction(name="mcp tx"): - test_prompt_dict("dict_prompt", {}) + await stdio( + server, + method="prompts/get", + params={ + "name": "dict_prompt", + "arguments": {}, + }, + ) (tx,) = events span = tx["spans"][0] @@ -816,7 +979,9 @@ def test_prompt_dict(name, arguments): assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT not in span["data"] -def test_resource_without_protocol(sentry_init, capture_events): +@pytest.mark.asyncio +@pytest.mark.skip +async def test_resource_without_protocol(sentry_init, capture_events): """Test resource handler with URI without protocol scheme""" sentry_init( integrations=[MCPIntegration()], @@ -826,17 +991,18 @@ def test_resource_without_protocol(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-no-proto", transport="stdio") - request_ctx.set(mock_ctx) - @server.read_resource() def test_resource(uri): return {"data": "test"} with start_transaction(name="mcp tx"): - # URI without protocol - test_resource("simple-path") + await stdio( + server, + method="resources/read", + params={ + "uri": "https://example.com/resource", + }, + ) (tx,) = events span = tx["spans"][0] @@ -846,7 +1012,8 @@ def test_resource(uri): assert SPANDATA.MCP_RESOURCE_PROTOCOL not in span["data"] -def test_tool_with_complex_arguments(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_tool_with_complex_arguments(sentry_init, capture_events): """Test tool handler with complex nested arguments""" sentry_init( integrations=[MCPIntegration()], @@ -856,12 +1023,8 @@ def test_tool_with_complex_arguments(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-complex", transport="stdio") - request_ctx.set(mock_ctx) - @server.call_tool() - def test_tool_complex(tool_name, arguments): + async def test_tool_complex(tool_name, arguments): return {"processed": True} with start_transaction(name="mcp tx"): @@ -870,7 +1033,14 @@ def test_tool_complex(tool_name, arguments): "string": "test", "number": 42, } - test_tool_complex("complex_tool", complex_args) + await stdio( + server, + method="tools/call", + params={ + "name": "complex_tool", + "arguments": complex_args, + }, + ) (tx,) = events span = tx["spans"][0] @@ -988,7 +1158,8 @@ def test_tool(tool_name, arguments): assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-http-456" -def test_stdio_transport_detection(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_stdio_transport_detection(sentry_init, capture_events): """Test that stdio transport is correctly detected when no HTTP request""" sentry_init( integrations=[MCPIntegration()], @@ -998,18 +1169,21 @@ def test_stdio_transport_detection(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context with stdio transport (no HTTP request) - mock_ctx = MockRequestContext(request_id="req-stdio", transport="stdio") - request_ctx.set(mock_ctx) - @server.call_tool() - def test_tool(tool_name, arguments): + async def test_tool(tool_name, arguments): return {"result": "success"} with start_transaction(name="mcp tx"): - result = test_tool("stdio_tool", {}) + result = await stdio( + server, + method="tools/call", + params={ + "name": "stdio_tool", + "arguments": {}, + }, + ) - assert result == {"result": "success"} + assert result.message.root.result["structuredContent"] == {"result": "success"} (tx,) = events span = tx["spans"][0]