diff --git a/sentry_sdk/integrations/mcp.py b/sentry_sdk/integrations/mcp.py index 47fda272b7..b03926c6da 100644 --- a/sentry_sdk/integrations/mcp.py +++ b/sentry_sdk/integrations/mcp.py @@ -21,6 +21,7 @@ try: from mcp.server.lowlevel import Server # type: ignore[import-not-found] from mcp.server.lowlevel.server import request_ctx # type: ignore[import-not-found] + from mcp.server.streamable_http import StreamableHTTPServerTransport # type: ignore[import-not-found] except ImportError: raise DidNotEnable("MCP SDK not installed") @@ -31,7 +32,9 @@ if TYPE_CHECKING: - from typing import Any, Callable, Optional + from typing import Any, Callable, Optional, Tuple + + from starlette.types import Receive, Scope, Send # type: ignore[import-not-found] class MCPIntegration(Integration): @@ -54,11 +57,34 @@ def setup_once() -> None: Patches MCP server classes to instrument handler execution. """ _patch_lowlevel_server() + _patch_handle_request() if FastMCP is not None: _patch_fastmcp() +def _get_active_http_scopes() -> ( + "Optional[Tuple[Optional[sentry_sdk.Scope], Optional[sentry_sdk.Scope]]]" +): + try: + ctx = request_ctx.get() + except LookupError: + return None + + if ( + ctx is None + or not hasattr(ctx, "request") + or ctx.request is None + or "state" not in ctx.request.scope + ): + return None + + return ( + ctx.request.scope["state"].get("sentry_sdk.current_scope"), + ctx.request.scope["state"].get("sentry_sdk.isolation_scope"), + ) + + def _get_request_context_data() -> "tuple[Optional[str], Optional[str], str]": """ Extract request ID, session ID, and MCP transport type from the request context. @@ -381,56 +407,67 @@ async def _async_handler_wrapper( result_data_key, ) = _prepare_handler_data(handler_type, original_args, original_kwargs) - # Start span and execute - with get_start_span_function()( - op=OP.MCP_SERVER, - name=span_name, - origin=MCPIntegration.origin, - ) as span: - # Get request ID, session ID, and transport from context - request_id, session_id, mcp_transport = _get_request_context_data() - - # Set input span data - _set_span_input_data( - span, - handler_name, - span_data_key, - mcp_method_name, - arguments, - request_id, - session_id, - mcp_transport, - ) + scopes = _get_active_http_scopes() - # For resources, extract and set protocol - if handler_type == "resource": - if original_args: - uri = original_args[0] - else: - uri = original_kwargs.get("uri") + if scopes is None: + current_scope = None + isolation_scope = None + else: + current_scope, isolation_scope = scopes - protocol = None - if hasattr(uri, "scheme"): - protocol = uri.scheme - elif handler_name and "://" in handler_name: - protocol = handler_name.split("://")[0] - if protocol: - span.set_data(SPANDATA.MCP_RESOURCE_PROTOCOL, protocol) + # Get request ID, session ID, and transport from context + request_id, session_id, mcp_transport = _get_request_context_data() - try: - # Execute the async handler - if self is not None: - original_args = (self, *original_args) - result = await func(*original_args, **original_kwargs) - except Exception as e: - # Set error flag for tools - if handler_type == "tool": - span.set_data(SPANDATA.MCP_TOOL_RESULT_IS_ERROR, True) - sentry_sdk.capture_exception(e) - raise + # Start span and execute + with sentry_sdk.scope.use_isolation_scope(isolation_scope): + with sentry_sdk.scope.use_scope(current_scope): + with get_start_span_function()( + op=OP.MCP_SERVER, + name=span_name, + origin=MCPIntegration.origin, + ) as span: + # Set input span data + _set_span_input_data( + span, + handler_name, + span_data_key, + mcp_method_name, + arguments, + request_id, + session_id, + mcp_transport, + ) + + # For resources, extract and set protocol + if handler_type == "resource": + if original_args: + uri = original_args[0] + else: + uri = original_kwargs.get("uri") + + protocol = None + if hasattr(uri, "scheme"): + protocol = uri.scheme + elif handler_name and "://" in handler_name: + protocol = handler_name.split("://")[0] + if protocol: + span.set_data(SPANDATA.MCP_RESOURCE_PROTOCOL, protocol) + + try: + # Execute the async handler + if self is not None: + original_args = (self, *original_args) + result = await func(*original_args, **original_kwargs) + except Exception as e: + # Set error flag for tools + if handler_type == "tool": + span.set_data(SPANDATA.MCP_TOOL_RESULT_IS_ERROR, True) + sentry_sdk.capture_exception(e) + raise + + _set_span_output_data(span, result, result_data_key, handler_type) - _set_span_output_data(span, result, result_data_key, handler_type) - return result + return result def _sync_handler_wrapper( @@ -618,6 +655,25 @@ def patched_read_resource( Server.read_resource = patched_read_resource +def _patch_handle_request() -> None: + original_handle_request = StreamableHTTPServerTransport.handle_request + + @wraps(original_handle_request) + async def patched_handle_request( + self: "StreamableHTTPServerTransport", + scope: "Scope", + receive: "Receive", + send: "Send", + ) -> None: + scope.setdefault("state", {})["sentry_sdk.isolation_scope"] = ( + sentry_sdk.get_isolation_scope() + ) + scope["state"]["sentry_sdk.current_scope"] = sentry_sdk.get_current_scope() + await original_handle_request(self, scope, receive, send) + + StreamableHTTPServerTransport.handle_request = patched_handle_request + + def _patch_fastmcp() -> None: """ Patches the standalone fastmcp package's FastMCP class. diff --git a/sentry_sdk/scope.py b/sentry_sdk/scope.py index 6df26690c8..1e401dcfac 100644 --- a/sentry_sdk/scope.py +++ b/sentry_sdk/scope.py @@ -100,6 +100,7 @@ F = TypeVar("F", bound=Callable[..., Any]) T = TypeVar("T") + S = TypeVar("S", bound=Optional["Scope"]) # Holds data that will be added to **all** events sent by this process. @@ -1786,7 +1787,7 @@ def new_scope() -> "Generator[Scope, None, None]": @contextmanager -def use_scope(scope: "Scope") -> "Generator[Scope, None, None]": +def use_scope(scope: "S") -> "Generator[S, None, None]": """ .. versionadded:: 2.0.0 @@ -1808,6 +1809,10 @@ def use_scope(scope: "Scope") -> "Generator[Scope, None, None]": sentry_sdk.capture_message("hello, again") # will NOT include `color` tag. """ + if scope is None: + yield scope + return + # set given scope as current scope token = _current_scope.set(scope) @@ -1871,7 +1876,7 @@ def isolation_scope() -> "Generator[Scope, None, None]": @contextmanager -def use_isolation_scope(isolation_scope: "Scope") -> "Generator[Scope, None, None]": +def use_isolation_scope(isolation_scope: "S") -> "Generator[S, None, None]": """ .. versionadded:: 2.0.0 @@ -1892,6 +1897,10 @@ def use_isolation_scope(isolation_scope: "Scope") -> "Generator[Scope, None, Non sentry_sdk.capture_message("hello, again") # will NOT include `color` tag. """ + if isolation_scope is None: + yield isolation_scope + return + # fork current scope current_scope = Scope.get_current_scope() forked_current_scope = current_scope.fork() diff --git a/tests/integrations/fastmcp/test_fastmcp.py b/tests/integrations/fastmcp/test_fastmcp.py index ef2a1f9cb7..4f2d0e6916 100644 --- a/tests/integrations/fastmcp/test_fastmcp.py +++ b/tests/integrations/fastmcp/test_fastmcp.py @@ -262,6 +262,7 @@ class MockHTTPRequest: def __init__(self, session_id=None, transport="http"): self.headers = {} self.query_params = {} + self.scope = {} if transport == "sse": # SSE transport uses query parameter diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py index 4415467cd7..cb49e4c895 100644 --- a/tests/integrations/mcp/test_mcp.py +++ b/tests/integrations/mcp/test_mcp.py @@ -28,8 +28,15 @@ async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) +from mcp.types import GetPromptResult, PromptMessage, TextContent +from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel import Server from mcp.server.lowlevel.server import request_ctx +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + +from starlette.routing import Mount +from starlette.applications import Starlette +from starlette.testclient import TestClient try: from mcp.server.lowlevel.server import request_ctx @@ -41,6 +48,77 @@ async def __call__(self, *args, **kwargs): from sentry_sdk.integrations.mcp import MCPIntegration +def json_rpc(app, method: str, params, request_id: str | None = None): + if request_id is None: + request_id = "1" # arbitrary + + with TestClient(app) as client: + init_response = client.post( + "/mcp/", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-11-25", + "capabilities": {}, + }, + "id": request_id, + }, + ) + + session_id = init_response.headers["mcp-session-id"] + + # Notification response is mandatory. + # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle + client.post( + "/mcp/", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + "mcp-session-id": session_id, + }, + json={ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {}, + }, + ) + + response = client.post( + "/mcp/", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + "mcp-session-id": session_id, + }, + json={ + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": request_id, + }, + ) + + return session_id, response + + +def select_transactions_with_mcp_spans(events, method_name): + return [ + transaction + for transaction in events + if transaction["type"] == "transaction" + and any( + span["data"].get("mcp.method.name") == method_name + for span in transaction.get("spans", []) + ) + ] + + @pytest.fixture(autouse=True) def reset_request_ctx(): """Reset request context before and after each test""" @@ -221,22 +299,46 @@ async def test_tool_handler_async( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext( - request_id="req-456", session_id="session-789", transport="http" + session_manager = StreamableHTTPSessionManager( + app=server, + json_response=True, + ) + + app = Starlette( + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), ) - request_ctx.set(mock_ctx) @server.call_tool() async def test_tool_async(tool_name, arguments): - return {"status": "completed"} + return [ + TextContent( + type="text", + text=json.dumps({"status": "completed"}), + ) + ] - with start_transaction(name="mcp tx"): - result = await test_tool_async("process", {"data": "test"}) + session_id, result = json_rpc( + app, + method="tools/call", + params={ + "name": "process", + "arguments": { + "data": "test", + }, + }, + request_id="req-456", + ) + assert result.json()["result"]["content"][0]["text"] == json.dumps( + {"status": "completed"} + ) - assert result == {"status": "completed"} + transactions = select_transactions_with_mcp_spans(events, "tools/call") + assert len(transactions) == 1 + tx = transactions[0] - (tx,) = events assert tx["type"] == "transaction" assert len(tx["spans"]) == 1 @@ -250,13 +352,16 @@ async def test_tool_async(tool_name, arguments): assert span["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call" assert span["data"][SPANDATA.MCP_TRANSPORT] == "http" assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-456" - assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-789" + assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id assert span["data"]["mcp.request.argument.data"] == '"test"' # Check PII-sensitive data if send_default_pii and include_prompts: + # TODO: Investigate why tool result is double-serialized. assert span["data"][SPANDATA.MCP_TOOL_RESULT_CONTENT] == json.dumps( - {"status": "completed"} + json.dumps( + {"status": "completed"}, + ) ) else: assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"] @@ -385,27 +490,49 @@ async def test_prompt_handler_async( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext( - request_id="req-async-prompt", session_id="session-abc", transport="http" + session_manager = StreamableHTTPSessionManager( + app=server, + json_response=True, + ) + + app = Starlette( + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), ) - request_ctx.set(mock_ctx) @server.get_prompt() async def test_prompt_async(name, arguments): - return MockGetPromptResult( - [ - MockPromptMessage("system", "You are a helpful assistant"), - MockPromptMessage("user", "What is MCP?"), - ] + return GetPromptResult( + description="A helpful test prompt", + messages=[ + PromptMessage( + role="user", + content=TextContent( + type="text", text="You are a helpful assistant" + ), + ), + PromptMessage( + role="user", content=TextContent(type="text", text="What is MCP?") + ), + ], ) - with start_transaction(name="mcp tx"): - result = await test_prompt_async("mcp_info", {}) + _, result = json_rpc( + app, + method="prompts/get", + params={ + "name": "mcp_info", + "arguments": {}, + }, + ) + assert len(result.json()["result"]["messages"]) == 2 - assert len(result.messages) == 2 + transactions = select_transactions_with_mcp_spans(events, "prompts/get") + assert len(transactions) == 1 + tx = transactions[0] - (tx,) = events assert tx["type"] == "transaction" assert len(tx["spans"]) == 1 @@ -504,23 +631,42 @@ async def test_resource_handler_async(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext( - request_id="req-async-resource", session_id="session-res", transport="http" + session_manager = StreamableHTTPSessionManager( + app=server, + json_response=True, + ) + + app = Starlette( + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), ) - request_ctx.set(mock_ctx) @server.read_resource() async def test_resource_async(uri): - return {"data": "resource data"} + return [ + ReadResourceContents( + content=json.dumps({"data": "resource data"}), mime_type="text/plain" + ) + ] - with start_transaction(name="mcp tx"): - uri = MockURI("https://example.com/resource") - result = await test_resource_async(uri) + session_id, result = json_rpc( + app, + method="resources/read", + params={ + "uri": "https://example.com/resource", + }, + ) - assert result["data"] == "resource data" + assert result.json()["result"]["contents"][0]["text"] == json.dumps( + {"data": "resource data"} + ) + + transactions = select_transactions_with_mcp_spans(events, "resources/read") + assert len(transactions) == 1 + tx = transactions[0] - (tx,) = events assert tx["type"] == "transaction" assert len(tx["spans"]) == 1 @@ -530,7 +676,7 @@ async def test_resource_async(uri): assert span["data"][SPANDATA.MCP_RESOURCE_URI] == "https://example.com/resource" assert span["data"][SPANDATA.MCP_RESOURCE_PROTOCOL] == "https" - assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-res" + assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id def test_resource_handler_with_error(sentry_init, capture_events): @@ -964,28 +1110,49 @@ def test_streamable_http_transport_detection(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context with StreamableHTTP transport - mock_ctx = MockRequestContext( - request_id="req-http", session_id="session-http-456", transport="http" + session_manager = StreamableHTTPSessionManager( + app=server, + json_response=True, + ) + + app = Starlette( + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), ) - request_ctx.set(mock_ctx) @server.call_tool() - def test_tool(tool_name, arguments): - return {"result": "success"} + async def test_tool(tool_name, arguments): + return [ + TextContent( + type="text", + text=json.dumps({"status": "success"}), + ) + ] - with start_transaction(name="mcp tx"): - result = test_tool("http_tool", {}) + _, result = json_rpc( + app, + method="tools/call", + params={ + "name": "http_tool", + "arguments": {}, + }, + ) + assert result.json()["result"]["content"][0]["text"] == json.dumps( + {"status": "success"} + ) - assert result == {"result": "success"} + transactions = select_transactions_with_mcp_spans(events, "tools/call") + assert len(transactions) == 1 + tx = transactions[0] - (tx,) = events span = tx["spans"][0] # Check that HTTP transport is detected assert span["data"][SPANDATA.MCP_TRANSPORT] == "http" assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp" - assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-http-456" + assert len(span["data"][SPANDATA.MCP_SESSION_ID]) == 32 def test_stdio_transport_detection(sentry_init, capture_events):