Skip to content
2 changes: 1 addition & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def send_log_message(
related_request_id,
)

async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover
async def send_resource_updated(self, uri: str | AnyUrl) -> None:
"""Send a resource updated notification."""
await self.send_notification(
types.ResourceUpdatedNotification(
Expand Down
12 changes: 6 additions & 6 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")

@asynccontextmanager
async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover
if scope["type"] != "http":
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http": # pragma: no cover
logger.error("connect_sse received non-HTTP request")
raise ValueError("connect_sse can only handle HTTP requests")

Expand Down Expand Up @@ -195,7 +195,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send):
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)

async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
logger.debug("Handling POST message")
request = Request(scope, receive)

Expand All @@ -205,15 +205,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
return await error_response(scope, receive, send)

session_id_param = request.query_params.get("session_id")
if session_id_param is None:
if session_id_param is None: # pragma: no cover
logger.warning("Received request without session_id")
response = Response("session_id is required", status_code=400)
return await response(scope, receive, send)

try:
session_id = UUID(hex=session_id_param)
logger.debug(f"Parsed session ID: {session_id}")
except ValueError:
except ValueError: # pragma: no cover
logger.warning(f"Received invalid session ID: {session_id_param}")
response = Response("Invalid session ID", status_code=400)
return await response(scope, receive, send)
Expand All @@ -230,7 +230,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
try:
message = types.jsonrpc_message_adapter.validate_json(body, by_name=False)
logger.debug(f"Validated client message: {message}")
except ValidationError as err:
except ValidationError as err: # pragma: no cover
logger.exception("Failed to parse message")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
Expand Down
66 changes: 33 additions & 33 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def is_terminated(self) -> bool:
"""Check if this transport has been explicitly terminated."""
return self._terminated

def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover
def close_sse_stream(self, request_id: RequestId) -> None:
"""Close SSE connection for a specific request without terminating the stream.

This method closes the HTTP connection for the specified request, triggering
Expand All @@ -200,12 +200,12 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover
writer.close()

# Also close and remove request streams
if request_id in self._request_streams:
if request_id in self._request_streams: # pragma: no branch
send_stream, receive_stream = self._request_streams.pop(request_id)
send_stream.close()
receive_stream.close()

def close_standalone_sse_stream(self) -> None: # pragma: no cover
def close_standalone_sse_stream(self) -> None:
"""Close the standalone GET SSE stream, triggering client reconnection.

This method closes the HTTP connection for the standalone GET stream used
Expand Down Expand Up @@ -240,10 +240,10 @@ def _create_session_message(
# Only provide close callbacks when client supports resumability
if self._event_store and protocol_version >= "2025-11-25":

async def close_stream_callback() -> None: # pragma: no cover
async def close_stream_callback() -> None:
self.close_sse_stream(request_id)

async def close_standalone_stream_callback() -> None: # pragma: no cover
async def close_standalone_stream_callback() -> None:
self.close_standalone_sse_stream()

metadata = ServerMessageMetadata(
Expand Down Expand Up @@ -291,7 +291,7 @@ def _create_error_response(
) -> Response:
"""Create an error response with a simple string message."""
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
if headers: # pragma: no cover
if headers:
response_headers.update(headers)

if self.mcp_session_id:
Expand Down Expand Up @@ -342,7 +342,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
}

# If an event ID was provided, include it
if event_message.event_id: # pragma: no cover
if event_message.event_id:
event_data["id"] = event_message.event_id

return event_data
Expand Down Expand Up @@ -372,7 +372,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await error_response(scope, receive, send)
return

if self._terminated: # pragma: no cover
if self._terminated:
# If the session has been terminated, return 404 Not Found
response = self._create_error_response(
"Not Found: Session has been terminated",
Expand All @@ -387,7 +387,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await self._handle_get_request(request, send)
elif request.method == "DELETE":
await self._handle_delete_request(request, send)
else: # pragma: no cover
else:
await self._handle_unsupported_request(request, send)

def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
Expand Down Expand Up @@ -467,7 +467,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re

try:
message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False)
except ValidationError as e: # pragma: no cover
except ValidationError as e:
response = self._create_error_response(
f"Validation error: {str(e)}",
HTTPStatus.BAD_REQUEST,
Expand All @@ -493,7 +493,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
)
await response(scope, receive, send)
return
elif not await self._validate_request_headers(request, send): # pragma: no cover
elif not await self._validate_request_headers(request, send):
return

# For notifications and responses only, return 202 Accepted
Expand Down Expand Up @@ -659,19 +659,19 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
# Validate Accept header - must include text/event-stream
_, has_sse = self._check_accept_headers(request)

if not has_sse: # pragma: no cover
if not has_sse:
response = self._create_error_response(
"Not Acceptable: Client must accept text/event-stream",
HTTPStatus.NOT_ACCEPTABLE,
)
await response(request.scope, request.receive, send)
return

if not await self._validate_request_headers(request, send): # pragma: no cover
if not await self._validate_request_headers(request, send):
return

# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
await self._replay_events(last_event_id, request, send)
return

Expand All @@ -681,11 +681,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
"Content-Type": CONTENT_TYPE_SSE,
}

if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Check if we already have an active GET stream
if GET_STREAM_KEY in self._request_streams: # pragma: no cover
if GET_STREAM_KEY in self._request_streams:
response = self._create_error_response(
"Conflict: Only one SSE stream is allowed per session",
HTTPStatus.CONFLICT,
Expand Down Expand Up @@ -714,7 +714,7 @@ async def standalone_sse_writer():
# Send the message via SSE
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception: # pragma: no cover
except Exception:
logger.exception("Error in standalone SSE writer")
finally:
logger.debug("Closing standalone SSE writer")
Expand Down Expand Up @@ -791,13 +791,13 @@ async def terminate(self) -> None:
# During cleanup, we catch all exceptions since streams might be in various states
logger.debug(f"Error closing streams: {e}")

async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
"""Handle unsupported HTTP methods."""
headers = {
"Content-Type": CONTENT_TYPE_JSON,
"Allow": "GET, POST, DELETE",
}
if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

response = self._create_error_response(
Expand All @@ -824,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
request_session_id = self._get_session_id(request)

# If no session ID provided but required, return error
if not request_session_id: # pragma: no cover
if not request_session_id:
response = self._create_error_response(
"Bad Request: Missing session ID",
HTTPStatus.BAD_REQUEST,
Expand All @@ -849,11 +849,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

# If no protocol version provided, assume default version
if protocol_version is None: # pragma: no cover
if protocol_version is None:
protocol_version = DEFAULT_NEGOTIATED_VERSION

# Check if the protocol version is supported
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
response = self._create_error_response(
f"Bad Request: Unsupported protocol version: {protocol_version}. "
Expand All @@ -865,13 +865,13 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool

return True

async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
"""Replays events that would have been sent after the specified event ID.

Only used when resumability is enabled.
"""
event_store = self._event_store
if not event_store:
if not event_store: # pragma: no cover
return

try:
Expand All @@ -881,7 +881,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
"Content-Type": CONTENT_TYPE_SSE,
}

if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Get protocol version from header (already validated in _validate_protocol_version)
Expand All @@ -902,7 +902,7 @@ async def send_event(event_message: EventMessage) -> None:
stream_id = await event_store.replay_events_after(last_event_id, send_event)

# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams:
if stream_id and stream_id not in self._request_streams: # pragma: no branch
# Register SSE writer so close_sse_stream() can close it
self._sse_stream_writers[stream_id] = sse_stream_writer

Expand All @@ -921,9 +921,9 @@ async def send_event(event_message: EventMessage) -> None:
await sse_stream_writer.send(event_data)
except anyio.ClosedResourceError:
# Expected when close_sse_stream() is called
logger.debug("Replay SSE stream closed by close_sse_stream()")
logger.debug("Replay SSE stream closed by close_sse_stream()") # pragma: no cover
except Exception:
logger.exception("Error in replay sender")
logger.exception("Error in replay sender") # pragma: no cover

# Create and start EventSourceResponse
response = EventSourceResponse(
Expand All @@ -934,13 +934,13 @@ async def send_event(event_message: EventMessage) -> None:

try:
await response(request.scope, request.receive, send)
except Exception:
except Exception: # pragma: no cover
logger.exception("Error in replay response")
finally:
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()

except Exception:
except Exception: # pragma: no cover
logger.exception("Error replaying events")
response = self._create_error_response(
"Error replaying events",
Expand Down Expand Up @@ -991,7 +991,7 @@ async def message_router():
if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None:
target_request_id = str(message.id)
# Extract related_request_id from meta if it exists
elif ( # pragma: no cover
elif (
session_message.metadata is not None
and isinstance(
session_message.metadata,
Expand All @@ -1015,10 +1015,10 @@ async def message_router():
try:
# Send both the message and the event ID
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
# Stream might be closed, remove from registry
self._request_streams.pop(request_stream_id, None)
else: # pragma: no cover
else:
logger.debug(
f"""Request stream {request_stream_id} not found
for message. Still processing message as the client
Expand Down
24 changes: 12 additions & 12 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default for backwards compatibility
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)

def _validate_host(self, host: str | None) -> bool: # pragma: no cover
def _validate_host(self, host: str | None) -> bool:
"""Validate the Host header against allowed values."""
if not host:
if not host: # pragma: no cover
logger.warning("Missing Host header in request")
return False

Expand All @@ -62,19 +62,19 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
logger.warning(f"Invalid Host header: {host}")
return False

def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
def _validate_origin(self, origin: str | None) -> bool:
"""Validate the Origin header against allowed values."""
# Origin can be absent for same-origin requests
if not origin:
return True

# Check exact match first
if origin in self.settings.allowed_origins:
if origin in self.settings.allowed_origins: # pragma: no cover
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_origins:
if allowed.endswith(":*"):
if allowed.endswith(":*"): # pragma: no branch
# Extract base origin from pattern
base_origin = allowed[:-2]
# Check if the actual origin starts with base origin and has a port
Expand Down Expand Up @@ -104,13 +104,13 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
return None

# Validate Host header # pragma: no cover
host = request.headers.get("host") # pragma: no cover
if not self._validate_host(host): # pragma: no cover
return Response("Invalid Host header", status_code=421) # pragma: no cover
host = request.headers.get("host")
if not self._validate_host(host):
return Response("Invalid Host header", status_code=421)

# Validate Origin header # pragma: no cover
origin = request.headers.get("origin") # pragma: no cover
if not self._validate_origin(origin): # pragma: no cover
return Response("Invalid Origin header", status_code=403) # pragma: no cover
origin = request.headers.get("origin")
if not self._validate_origin(origin):
return Response("Invalid Origin header", status_code=403)

return None # pragma: no cover
return None
Loading
Loading