diff --git a/docs/migration.md b/docs/migration.md index 3b47f9aad..0b9ea43b8 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -879,6 +879,21 @@ app = server.streamable_http_app( The lowlevel `Server` also now exposes a `session_manager` property to access the `StreamableHTTPSessionManager` after calling `streamable_http_app()`. +### Multi-tenancy support + +The SDK now supports multi-tenant deployments where a single server instance serves multiple isolated tenants. Tenant identity flows from authentication tokens through sessions, request context, and into all handler invocations. + +Key additions: + +- `AccessToken.tenant_id` — carries tenant identity in OAuth tokens +- `Context.tenant_id` — available in tool, resource, and prompt handlers +- `server.add_tool(fn, tenant_id="...")`, `server.add_resource(r, tenant_id="...")`, `server.add_prompt(p, tenant_id="...")` — register tenant-scoped tools, resources, and prompts +- `StreamableHTTPSessionManager` — validates tenant identity on every request and prevents cross-tenant session access + +All APIs default to `tenant_id=None`, preserving backward compatibility for single-tenant servers. + +See the [Multi-Tenancy Guide](multi-tenancy.md) for details. + ## Need Help? If you encounter issues during migration: diff --git a/docs/multi-tenancy.md b/docs/multi-tenancy.md new file mode 100644 index 000000000..f3a8e6932 --- /dev/null +++ b/docs/multi-tenancy.md @@ -0,0 +1,255 @@ +# Multi-Tenancy Guide + +This guide explains how to build MCP servers that safely isolate multiple tenants sharing a single server instance. Multi-tenancy ensures that tools, resources, prompts, and sessions belonging to one tenant are invisible and inaccessible to others. + +> For a complete working example, see [`examples/servers/simple-multi-tenant/`](../examples/servers/simple-multi-tenant/). + +## Overview + +In a multi-tenant deployment, a single MCP server process serves requests from multiple organizations, teams, or users (tenants). Without proper isolation, Tenant A could list or invoke Tenant B's tools, read their resources, or hijack their sessions. + +The MCP Python SDK provides built-in tenant isolation across all layers: + +- **Authentication tokens** carry a `tenant_id` field +- **Sessions** are bound to a single tenant on first authenticated request +- **Request context** propagates `tenant_id` to every handler +- **Managers** (tools, resources, prompts) use tenant-scoped storage +- **Session manager** validates tenant identity on every request + +## How It Works + +### Tenant Identification Flow + +```mermaid +flowchart TD + A["HTTP Request"] --> B["AuthContextMiddleware"] + B -->|"extracts tenant_id from AccessToken
sets tenant_id_var (contextvar)"| C["StreamableHTTPSessionManager"] + C -->|"binds new sessions to the current tenant
rejects cross-tenant session access (HTTP 404)"| D["Low-level Server
(_handle_request / _handle_notification)"] + D -->|"reads tenant_id_var
sets session.tenant_id (set-once)
populates ServerRequestContext.tenant_id"| E["MCPServer handlers
(_handle_list_tools, _handle_call_tool, etc.)"] + E -->|"passes ctx.tenant_id to managers"| F["ToolManager / ResourceManager / PromptManager"] + F -->|"looks up (tenant_id, name) in nested dict
returns only the requesting tenant's entries"| G["Response"] +``` + +### Key Components + +| Component | File | Role | +|---|---|---| +| `AccessToken.tenant_id` | `server/auth/provider.py` | Carries tenant identity in OAuth tokens | +| `tenant_id_var` | `shared/_context.py` | Transport-agnostic contextvar for tenant propagation | +| `AuthContextMiddleware` | `server/auth/middleware/auth_context.py` | Extracts tenant from auth and sets contextvar | +| `ServerSession.tenant_id` | `server/session.py` | Binds session to tenant (set-once semantics) | +| `ServerRequestContext.tenant_id` | `shared/_context.py` | Per-request tenant context for handlers | +| `Context.tenant_id` | `server/mcpserver/context.py` | High-level property for tool/resource/prompt handlers | +| `ToolManager` | `server/mcpserver/tools/tool_manager.py` | Tenant-scoped tool storage | +| `ResourceManager` | `server/mcpserver/resources/resource_manager.py` | Tenant-scoped resource storage | +| `PromptManager` | `server/mcpserver/prompts/manager.py` | Tenant-scoped prompt storage | +| `StreamableHTTPSessionManager` | `server/streamable_http_manager.py` | Validates tenant on session access | + +## Usage + +### Simple Registering Tenant-Scoped Tools, Resources, and Prompts + +Use the `tenant_id` parameter when adding tools, resources, or prompts: + +```python +from mcp.server.mcpserver import MCPServer + +server = MCPServer("my-server") + +# Register a tool for a specific tenant +def analyze_data(query: str) -> str: + return f"Results for: {query}" + +server.add_tool(analyze_data, tenant_id="acme-corp") + +# Register a resource for a specific tenant +from mcp.server.mcpserver.resources.types import FunctionResource + +server.add_resource( + FunctionResource(uri="data://config", name="config", fn=lambda: "tenant config"), + tenant_id="acme-corp", +) + +# Register a prompt for a specific tenant +from mcp.server.mcpserver.prompts.base import Prompt + +async def onboarding_prompt() -> str: + return "Welcome to Acme Corp!" + +server.add_prompt( + Prompt.from_function(onboarding_prompt, name="onboarding"), + tenant_id="acme-corp", +) +``` + +The same name can be registered under different tenants without conflict: + +```python +server.add_tool(acme_tool, name="analyze", tenant_id="acme-corp") +server.add_tool(globex_tool, name="analyze", tenant_id="globex-inc") +``` + +### Dynamic Tenant Provisioning + +Multi-tenancy enables MCP servers to operate as SaaS platforms where tenants are +provisioned and deprovisioned at runtime. Tools, resources, and prompts can be +added or removed dynamically — for example, when a tenant signs up, changes +their subscription tier, or installs a plugin. + +#### Tenant Onboarding and Offboarding + +Register a tenant's capabilities when they sign up and remove them when they leave: + +```python +def onboard_tenant(server: MCPServer, tenant_id: str, plan: str) -> None: + """Provision tools for a new tenant based on their plan.""" + + # Base tools available to all tenants + server.add_tool(search_docs, tenant_id=tenant_id) + server.add_tool(get_status, tenant_id=tenant_id) + + # Premium tools gated by plan + if plan in ("pro", "enterprise"): + server.add_tool(run_analytics, tenant_id=tenant_id) + server.add_tool(export_data, tenant_id=tenant_id) + + +def offboard_tenant(server: MCPServer, tenant_id: str) -> None: + """Remove all tools when a tenant is deprovisioned.""" + server.remove_tool("search_docs", tenant_id=tenant_id) + server.remove_tool("get_status", tenant_id=tenant_id) + server.remove_tool("run_analytics", tenant_id=tenant_id) + server.remove_tool("export_data", tenant_id=tenant_id) +``` + +#### Plugin Systems + +Let tenants install or uninstall integrations that map to MCP tools: + +```python +def install_plugin(server: MCPServer, tenant_id: str, plugin: str) -> None: + """Install a plugin's tools for a specific tenant.""" + plugin_tools = load_plugin_tools(plugin) # Your plugin registry + for tool_fn in plugin_tools: + server.add_tool(tool_fn, tenant_id=tenant_id) + + +def uninstall_plugin(server: MCPServer, tenant_id: str, plugin: str) -> None: + """Remove a plugin's tools for a specific tenant.""" + plugin_tool_names = get_plugin_tool_names(plugin) + for name in plugin_tool_names: + server.remove_tool(name, tenant_id=tenant_id) +``` + +All dynamic changes take effect immediately — the next `list_tools` request from that tenant will reflect the updated set. Other tenants are unaffected. + +### Accessing Tenant ID in Handlers + +Inside tool, resource, or prompt handlers, access the current tenant through `Context.tenant_id`: + +```python +from mcp.server.mcpserver.context import Context + +@server.tool() +async def get_data(ctx: Context) -> str: + tenant = ctx.tenant_id # e.g., "acme-corp" or None + return f"Data for tenant: {tenant}" +``` + +### Setting Up Authentication with Tenant ID + +The `tenant_id` field on `AccessToken` is populated by your token verifier or OAuth provider. The `AuthContextMiddleware` automatically extracts `tenant_id` from the authenticated user's access token and sets the `tenant_id_var` contextvar for downstream use. + +Implement the `TokenVerifier` protocol to bridge your external identity provider with the MCP auth stack. Your `verify_token` method decodes or introspects the bearer token and returns an `AccessToken` with `tenant_id` populated. + +**Configuring your identity provider to include tenant identity in tokens:** + +Most identity providers allow you to add custom claims to access tokens. The claim name varies by provider, but common conventions include `org_id`, `tenant_id`, or a namespaced claim like `https://myapp.com/tenant_id`. Here are some examples: + +- **Duo Security**: Define a [custom user attribute](https://duo.com/docs/user-attributes) (e.g., `tenant_id`) and assign it to users via Duo Directory sync or the Admin Panel. Include this attribute as a claim in the access token issued by Duo as your IdP. +- **Auth0**: Use [Organizations](https://auth0.com/docs/manage-users/organizations) to model tenants. When a user authenticates through an organization, Auth0 automatically includes an `org_id` claim in the access token. Alternatively, use an [Action](https://auth0.com/docs/customize/actions) on the "Machine to Machine" or "Login" flow to add a custom claim based on app metadata or connection context. +- **Okta**: Add a [custom claim](https://developer.okta.com/docs/guides/customize-tokens-returned-from-okta/) to your authorization server. Map the claim value from the user's profile (e.g., `user.profile.orgId`) or from a group membership. +- **Microsoft Entra ID (Azure AD)**: Use the `tid` (tenant ID) claim that is included by default in tokens, or configure [optional claims](https://learn.microsoft.com/en-us/entra/identity-platform/optional-claims) to add organization-specific attributes. +- **Custom JWT issuer**: Include a `tenant_id` (or equivalent) claim in the JWT payload when minting tokens. For example: `{"sub": "user-123", "tenant_id": "acme-corp", "scope": "read write"}`. + +Once your provider includes the tenant claim, extract it in your `TokenVerifier`: + +```python +from mcp.server.auth.provider import AccessToken, TokenVerifier + + +class JWTTokenVerifier(TokenVerifier): + """Verify JWTs and extract tenant_id from claims.""" + + async def verify_token(self, token: str) -> AccessToken | None: + # Decode and validate the JWT (e.g., using PyJWT or authlib) + claims = decode_and_verify_jwt(token) + if claims is None: + return None + + return AccessToken( + token=token, + client_id=claims["sub"], + scopes=claims.get("scope", "").split(), + expires_at=claims.get("exp"), + tenant_id=claims["org_id"], # Extract tenant from your JWT claims + ) +``` + +Then pass the verifier when creating your `MCPServer`: + +```python +from mcp.server.auth.settings import AuthSettings +from mcp.server.mcpserver.server import MCPServer + +server = MCPServer( + "my-server", + token_verifier=JWTTokenVerifier(), + auth=AuthSettings( + issuer_url="https://auth.example.com", + resource_server_url="https://mcp.example.com", + required_scopes=["read"], + ), +) +``` + +Once the `AccessToken` reaches the middleware stack, the flow is automatic: `BearerAuthBackend` validates the token → `AuthContextMiddleware` extracts `tenant_id` → `tenant_id_var` contextvar is set → all downstream handlers and managers receive the correct tenant scope. + +### Session Isolation + +Sessions are automatically bound to their tenant on first authenticated request (set-once semantics). The `StreamableHTTPSessionManager` enforces this: + +- New sessions record the creating tenant's ID +- Subsequent requests must come from the same tenant +- Cross-tenant session access returns HTTP 404 +- Session tenant binding cannot be changed after initial assignment + +## Backward Compatibility + +All tenant-scoped APIs default to `tenant_id=None`, preserving single-tenant behavior: + +```python +# These all work exactly as before — no tenant scoping +server.add_tool(my_tool) +server.add_resource(my_resource) +await server.list_tools() # Returns tools in global (None) scope +``` + +Tools registered without a `tenant_id` live in the global scope and are only visible when no tenant context is active (i.e., `tenant_id_var` is not set or is `None`). + +## Architecture Notes + +### Storage Model + +Managers use a nested dictionary `{tenant_id: {name: item}}` for O(1) lookups per tenant. When the last item in a tenant scope is removed, the scope dictionary is cleaned up automatically. + +### Set-Once Session Binding + +`ServerSession.tenant_id` uses set-once semantics: once a session is bound to a tenant (on the first request with a non-None tenant_id), it cannot be changed. This prevents session fixation attacks where a session created by one tenant could be reused by another. + +### Security Considerations + +- **Cross-tenant tool invocation**: A tenant can only call tools registered under their own tenant_id. Attempting to call a tool from another tenant's scope raises a `ToolError`. +- **Resource access**: Resources are tenant-scoped. Reading a resource registered under a different tenant raises a `ResourceError`. +- **Session hijacking**: The session manager validates the requesting tenant against the session's bound tenant on every request. Mismatches return HTTP 404 with an opaque "Session not found" error (no tenant information is leaked). +- **Log levels**: Tenant mismatch events are logged at WARNING level (session ID only). Sensitive tenant identifiers are logged at DEBUG level only. diff --git a/examples/servers/simple-multi-tenant/README.md b/examples/servers/simple-multi-tenant/README.md new file mode 100644 index 000000000..b8da91f18 --- /dev/null +++ b/examples/servers/simple-multi-tenant/README.md @@ -0,0 +1,99 @@ +# Multi-Tenant MCP Server Example + +Demonstrates tenant-scoped tools, resources, and prompts using the MCP Python SDK's multi-tenancy support. + +## What it shows + +- **Acme** (analytics company) has `run_query` and `generate_report` tools, a `database-schema` resource, and an `analyst` prompt +- **Globex** (content company) has `publish_article` and `check_seo` tools, a `style-guide` resource, and an `editor` prompt +- Each tenant sees only their own tools, resources, and prompts — Acme cannot see Globex's tools and vice versa +- A `whoami` tool is registered under both tenants and reports the current tenant identity from `Context.tenant_id` + +## Running + +Start the server on the default or custom port: + +```bash +uv run mcp-simple-multi-tenant --port 3000 +``` + +The server starts a StreamableHTTP endpoint at `http://127.0.0.1:3000/mcp`. + +## What each tenant sees + +**Acme** (analytics): +- Tools: `run_query`, `generate_report`, `whoami` +- Resources: `data://schema` (database schema) +- Prompts: `analyst` (data analyst system prompt) + +**Globex** (content): +- Tools: `publish_article`, `check_seo`, `whoami` +- Resources: `content://style-guide` (editorial style guide) +- Prompts: `editor` (content editor system prompt) + +**No tenant** (unauthenticated): sees nothing — all items are tenant-scoped. + +## Example: programmatic client + +You can verify tenant isolation using the MCP client with in-memory transport: + +```python +import asyncio + +from mcp.client.session import ClientSession +from mcp.shared._context import tenant_id_var +from mcp.shared.memory import create_client_server_memory_streams + +from mcp_simple_multi_tenant.server import create_server + + +async def main(): + server = create_server() + actual = server._lowlevel_server + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + import anyio + + async with anyio.create_task_group() as tg: + # Set tenant context for the server side + async def run_server(): + token = tenant_id_var.set("acme") + try: + await actual.run( + server_read, + server_write, + actual.create_initialization_options(), + ) + finally: + tenant_id_var.reset(token) + + tg.start_soon(run_server) + + async with ClientSession(client_read, client_write) as session: + await session.initialize() + + # Acme sees only analytics tools + tools = await session.list_tools() + print(f"Tools: {[t.name for t in tools.tools]}") + # → ['run_query', 'generate_report', 'whoami'] + + result = await session.call_tool( + "run_query", {"sql": "SELECT * FROM users"} + ) + print(f"Result: {result.content[0].text}") + # → Query result for: SELECT * FROM users (3 rows returned) + + tg.cancel_scope.cancel() + + +asyncio.run(main()) +``` + +## How tenant identity works + +In a production deployment, `tenant_id` is extracted from the OAuth `AccessToken` by the `AuthContextMiddleware` and propagated through the request context automatically — no manual `tenant_id_var.set()` is needed. The in-memory example above sets it manually to simulate what the middleware does. + +See the [Multi-Tenancy Guide](../../../docs/multi-tenancy.md) for the full architecture. diff --git a/examples/servers/simple-multi-tenant/mcp_simple_multi_tenant/__init__.py b/examples/servers/simple-multi-tenant/mcp_simple_multi_tenant/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/servers/simple-multi-tenant/mcp_simple_multi_tenant/__main__.py b/examples/servers/simple-multi-tenant/mcp_simple_multi_tenant/__main__.py new file mode 100644 index 000000000..e7ef16530 --- /dev/null +++ b/examples/servers/simple-multi-tenant/mcp_simple_multi_tenant/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-multi-tenant/mcp_simple_multi_tenant/server.py b/examples/servers/simple-multi-tenant/mcp_simple_multi_tenant/server.py new file mode 100644 index 000000000..b0bbc9090 --- /dev/null +++ b/examples/servers/simple-multi-tenant/mcp_simple_multi_tenant/server.py @@ -0,0 +1,130 @@ +"""Multi-tenant MCP server example. + +Demonstrates how to register tenant-scoped tools, resources, and prompts +so that each tenant sees only their own items. Tenant identity is +determined by the ``tenant_id`` field on the OAuth ``AccessToken`` and +propagated automatically through the request context. + +In this example, "acme" is an analytics company with data tools, while +"globex" is a content company with publishing tools. Each tenant has +completely different capabilities — they share nothing. + +NOTE: This example uses a simple in-memory token verifier for +demonstration purposes. In production, integrate with your OAuth +provider to populate ``AccessToken.tenant_id`` from your auth system. +""" + +import logging + +import click +from mcp.server.mcpserver.context import Context +from mcp.server.mcpserver.prompts.base import Prompt +from mcp.server.mcpserver.resources.types import FunctionResource +from mcp.server.mcpserver.server import MCPServer + +logger = logging.getLogger(__name__) + + +def create_server() -> MCPServer: + """Create an MCPServer with tenant-scoped tools, resources, and prompts. + + Each tenant has completely different tools, resources, and prompts. + Acme is an analytics company; Globex is a content company. + """ + + server = MCPServer("multi-tenant-demo") + + # -- Tenant "acme" (analytics company) --------------------------------- + + def run_query(sql: str) -> str: + """Execute an analytics query.""" + return f"Query result for: {sql} (3 rows returned)" + + def generate_report(metric: str, period: str) -> str: + """Generate an analytics report.""" + return f"Report: {metric} over {period} — trend is up 12%" + + server.add_tool(run_query, tenant_id="acme") + server.add_tool(generate_report, tenant_id="acme") + + server.add_resource( + FunctionResource( + uri="data://schema", + name="database-schema", + fn=lambda: "tables: users, events, metrics", + ), + tenant_id="acme", + ) + + async def acme_analyst_prompt() -> str: + return "You are a data analyst. Help the user write SQL queries and interpret results." + + server.add_prompt(Prompt.from_function(acme_analyst_prompt, name="analyst"), tenant_id="acme") + + # -- Tenant "globex" (content company) --------------------------------- + + def publish_article(title: str, body: str) -> str: + """Publish an article to the CMS.""" + return f"Published: '{title}' ({len(body)} chars)" + + def check_seo(url: str) -> str: + """Check SEO score for a URL.""" + return f"SEO score for {url}: 87/100 — missing meta description" + + server.add_tool(publish_article, tenant_id="globex") + server.add_tool(check_seo, tenant_id="globex") + + server.add_resource( + FunctionResource( + uri="content://style-guide", + name="style-guide", + fn=lambda: "Tone: professional but approachable. Max paragraph length: 3 sentences.", + ), + tenant_id="globex", + ) + + async def globex_editor_prompt() -> str: + return "You are a content editor. Help the user write and publish articles." + + server.add_prompt(Prompt.from_function(globex_editor_prompt, name="editor"), tenant_id="globex") + + # -- Shared "whoami" tool (registered per tenant) ---------------------- + # There is no global scope fallback — tools must be registered under + # each tenant that needs them. + + def whoami(ctx: Context) -> str: + """Return the current tenant identity.""" + return f"tenant: {ctx.tenant_id or 'anonymous'}" + + server.add_tool(whoami, name="whoami", tenant_id="acme") + server.add_tool(whoami, name="whoami", tenant_id="globex") + + return server + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +def main(port: int, log_level: str) -> int: + """Run the multi-tenant MCP demo server. + + Acme (analytics) and Globex (content) each have completely different + tools, resources, and prompts. Neither tenant can see the other's items. + """ + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + server = create_server() + logger.info(f"Starting multi-tenant MCP server on port {port}") + server.run(transport="streamable-http", host="127.0.0.1", port=port) + return 0 + + +if __name__ == "__main__": + main() # type: ignore[call-arg] diff --git a/examples/servers/simple-multi-tenant/pyproject.toml b/examples/servers/simple-multi-tenant/pyproject.toml new file mode 100644 index 000000000..62db20b2e --- /dev/null +++ b/examples/servers/simple-multi-tenant/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-multi-tenant" +version = "0.1.0" +description = "A simple MCP server demonstrating multi-tenant isolation" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] +keywords = ["mcp", "llm", "multi-tenant"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["click>=8.2.0", "mcp"] + +[project.scripts] +mcp-simple-multi-tenant = "mcp_simple_multi_tenant.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_multi_tenant"] + +[tool.pyright] +include = ["mcp_simple_multi_tenant"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334..2a9a5c0fd 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -408,7 +408,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None return result - async def send_roots_list_changed(self) -> None: # pragma: no cover + async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.RootsListChangedNotification()) diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1d34a5546..2cee836e1 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -4,6 +4,7 @@ from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser from mcp.server.auth.provider import AccessToken +from mcp.shared._context import tenant_id_var # Create a contextvar to store the authenticated user # The default is None, indicating no authenticated user is present @@ -20,6 +21,16 @@ def get_access_token() -> AccessToken | None: return auth_user.access_token if auth_user else None +def get_tenant_id() -> str | None: + """Get the tenant_id from the current authentication context. + + Returns: + The tenant_id if an authenticated user with a tenant is available, None otherwise. + """ + access_token = get_access_token() + return access_token.tenant_id if access_token else None + + class AuthContextMiddleware: """Middleware that extracts the authenticated user from the request and sets it in a contextvar for easy access throughout the request lifecycle. @@ -36,11 +47,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): user = scope.get("user") if isinstance(user, AuthenticatedUser): # Set the authenticated user in the contextvar - token = auth_context_var.set(user) + auth_token = auth_context_var.set(user) + # Propagate tenant_id to the transport-agnostic contextvar + tenant_id = user.access_token.tenant_id if user.access_token else None + tenant_token = tenant_id_var.set(tenant_id) try: await self.app(scope, receive, send) finally: - auth_context_var.reset(token) + tenant_id_var.reset(tenant_token) + auth_context_var.reset(auth_token) else: # No authenticated user, just process the request await self.app(scope, receive, send) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6825c00b9..2eafdc793 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -95,7 +95,7 @@ async def _send_auth_error(self, send: Send, status_code: int, error: str, descr """Send an authentication error response with WWW-Authenticate header.""" # Build WWW-Authenticate header value www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] - if self.resource_metadata_url: # pragma: no cover + if self.resource_metadata_url: www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') www_authenticate = f"Bearer {', '.join(www_auth_parts)}" diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 957082a85..2b0a4ad53 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -25,6 +25,7 @@ class AuthorizationCode(BaseModel): redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool resource: str | None = None # RFC 8707 resource indicator + tenant_id: str | None = None # Tenant this code belongs to class RefreshToken(BaseModel): @@ -32,6 +33,7 @@ class RefreshToken(BaseModel): client_id: str scopes: list[str] expires_at: int | None = None + tenant_id: str | None = None # Tenant this token belongs to class AccessToken(BaseModel): @@ -40,6 +42,7 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + tenant_id: str | None = None # Tenant this token belongs to RegistrationErrorCode = Literal[ diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 167f34b8b..9316daad8 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -65,6 +65,7 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._context import tenant_id_var from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -451,11 +452,15 @@ async def _handle_request( task_metadata = None if hasattr(req, "params") and req.params is not None: task_metadata = getattr(req.params, "task", None) + tenant_id = tenant_id_var.get() + if tenant_id is not None and session.tenant_id is None: + session.tenant_id = tenant_id ctx = ServerRequestContext( request_id=message.request_id, meta=message.request_meta, session=session, lifespan_context=lifespan_context, + tenant_id=tenant_id, experimental=Experimental( task_metadata=task_metadata, _client_capabilities=client_capabilities, @@ -495,9 +500,13 @@ async def _handle_notification( try: client_capabilities = session.client_params.capabilities if session.client_params else None task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + tenant_id = tenant_id_var.get() + if tenant_id is not None and session.tenant_id is None: + session.tenant_id = tenant_id ctx = ServerRequestContext( session=session, lifespan_context=lifespan_context, + tenant_id=tenant_id, experimental=Experimental( task_metadata=None, _client_capabilities=client_capabilities, @@ -553,7 +562,7 @@ def streamable_http_app( required_scopes: list[str] = [] # Set up auth if configured - if auth: # pragma: no cover + if auth: required_scopes = auth.required_scopes or [] # Add auth middleware if token verifier is available @@ -579,7 +588,7 @@ def streamable_http_app( ) # Set up routes with or without auth - if token_verifier: # pragma: no cover + if token_verifier: # Determine resource metadata URL resource_metadata_url = None if auth and auth.resource_server_url: @@ -602,7 +611,7 @@ def streamable_http_app( ) # Add protected resource metadata endpoint if configured as RS - if auth and auth.resource_server_url: # pragma: no cover + if auth and auth.resource_server_url: routes.extend( create_protected_resource_routes( resource_url=auth.resource_server_url, diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 1538adc7c..0fc3b724f 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -69,6 +69,13 @@ def __init__( self._request_context = request_context self._mcp_server = mcp_server + @property + def tenant_id(self) -> str | None: + """Get the tenant_id for this request, if available.""" + if self._request_context is not None: + return self._request_context.tenant_id + return None + @property def mcp_server(self) -> MCPServer: """Access to the MCPServer instance.""" @@ -114,7 +121,7 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent The resource content as either text or bytes """ assert self._mcp_server is not None, "Context is not available outside of a request" - return await self._mcp_server.read_resource(uri, self) + return await self._mcp_server.read_resource(uri, self, tenant_id=self.tenant_id) async def elicit( self, diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 28a7a6e98..7c5144082 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -15,44 +15,68 @@ class PromptManager: - """Manages MCPServer prompts.""" + """Manages MCPServer prompts with optional tenant-scoped storage. + + Prompts are stored in a nested dict: ``{tenant_id: {prompt_name: Prompt}}``. + This allows the same prompt name to exist independently under different + tenants with O(1) lookups per tenant. When ``tenant_id`` is ``None`` + (the default), prompts live in a global scope, preserving backward + compatibility with single-tenant usage. + + Note: This class is not thread-safe. It is designed to run within a + single-threaded async event loop, where all synchronous mutations + execute atomically. Do not share instances across OS threads without + external synchronization. + """ def __init__(self, warn_on_duplicate_prompts: bool = True): - self._prompts: dict[str, Prompt] = {} + self._prompts: dict[str | None, dict[str, Prompt]] = {} self.warn_on_duplicate_prompts = warn_on_duplicate_prompts - def get_prompt(self, name: str) -> Prompt | None: - """Get prompt by name.""" - return self._prompts.get(name) + def get_prompt(self, name: str, *, tenant_id: str | None = None) -> Prompt | None: + """Get prompt by name, optionally scoped to a tenant.""" + return self._prompts.get(tenant_id, {}).get(name) - def list_prompts(self) -> list[Prompt]: - """List all registered prompts.""" - return list(self._prompts.values()) + def list_prompts(self, *, tenant_id: str | None = None) -> list[Prompt]: + """List all registered prompts for a given tenant scope.""" + return list(self._prompts.get(tenant_id, {}).values()) def add_prompt( self, prompt: Prompt, + *, + tenant_id: str | None = None, ) -> Prompt: - """Add a prompt to the manager.""" - - # Check for duplicates - existing = self._prompts.get(prompt.name) + """Add a prompt to the manager, optionally scoped to a tenant.""" + scope = self._prompts.setdefault(tenant_id, {}) + existing = scope.get(prompt.name) if existing: if self.warn_on_duplicate_prompts: logger.warning(f"Prompt already exists: {prompt.name}") return existing - self._prompts[prompt.name] = prompt + scope[prompt.name] = prompt return prompt + def remove_prompt(self, name: str, *, tenant_id: str | None = None) -> None: + """Remove a prompt by name, optionally scoped to a tenant.""" + scope = self._prompts.get(tenant_id, {}) + if name not in scope: + raise ValueError(f"Unknown prompt: {name}") + del scope[name] + if not scope and tenant_id in self._prompts: + del self._prompts[tenant_id] + async def render_prompt( self, name: str, arguments: dict[str, Any] | None, context: Context[LifespanContextT, RequestT], + *, + tenant_id: str | None = None, ) -> list[Message]: - """Render a prompt by name with arguments.""" - prompt = self.get_prompt(name) + """Render a prompt by name with arguments, optionally scoped to a tenant.""" + prompt = self.get_prompt(name, tenant_id=tenant_id) if not prompt: raise ValueError(f"Unknown prompt: {name}") diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py b/src/mcp/server/mcpserver/resources/resource_manager.py index 6bf17376d..6572dff40 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -20,18 +20,33 @@ class ResourceManager: - """Manages MCPServer resources.""" + """Manages MCPServer resources with optional tenant-scoped storage. + + Resources and templates are stored in nested dicts: + ``{tenant_id: {uri_string: Resource}}`` and + ``{tenant_id: {uri_template: ResourceTemplate}}`` respectively. + This allows the same URI to exist independently under different tenants + with O(1) lookups per tenant. When ``tenant_id`` is ``None`` (the default), + entries live in a global scope, preserving backward compatibility + with single-tenant usage. + + Note: This class is not thread-safe. It is designed to run within a + single-threaded async event loop, where all synchronous mutations + execute atomically. Do not share instances across OS threads without + external synchronization. + """ def __init__(self, warn_on_duplicate_resources: bool = True): - self._resources: dict[str, Resource] = {} - self._templates: dict[str, ResourceTemplate] = {} + self._resources: dict[str | None, dict[str, Resource]] = {} + self._templates: dict[str | None, dict[str, ResourceTemplate]] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources - def add_resource(self, resource: Resource) -> Resource: - """Add a resource to the manager. + def add_resource(self, resource: Resource, *, tenant_id: str | None = None) -> Resource: + """Add a resource to the manager, optionally scoped to a tenant. Args: resource: A Resource instance to add + tenant_id: Optional tenant scope for the resource Returns: The added resource. If a resource with the same URI already exists, @@ -45,12 +60,14 @@ def add_resource(self, resource: Resource) -> Resource: "resource_name": resource.name, }, ) - existing = self._resources.get(str(resource.uri)) + scope = self._resources.setdefault(tenant_id, {}) + uri_str = str(resource.uri) + existing = scope.get(uri_str) if existing: if self.warn_on_duplicate_resources: logger.warning(f"Resource already exists: {resource.uri}") return existing - self._resources[str(resource.uri)] = resource + scope[uri_str] = resource return resource def add_template( @@ -64,8 +81,15 @@ def add_template( icons: list[Icon] | None = None, annotations: Annotations | None = None, meta: dict[str, Any] | None = None, + *, + tenant_id: str | None = None, ) -> ResourceTemplate: - """Add a template from a function.""" + """Add a template from a function, optionally scoped to a tenant. + + Returns: + The added template. If a template with the same URI template already + exists, returns the existing template. + """ template = ResourceTemplate.from_function( fn, uri_template=uri_template, @@ -77,20 +101,43 @@ def add_template( annotations=annotations, meta=meta, ) - self._templates[template.uri_template] = template + scope = self._templates.setdefault(tenant_id, {}) + existing = scope.get(template.uri_template) + if existing: + if self.warn_on_duplicate_resources: + logger.warning(f"Resource template already exists: {template.uri_template}") + return existing + scope[template.uri_template] = template return template - async def get_resource(self, uri: AnyUrl | str, context: Context[LifespanContextT, RequestT]) -> Resource: + def remove_resource(self, uri: AnyUrl | str, *, tenant_id: str | None = None) -> None: + """Remove a resource by URI, optionally scoped to a tenant.""" + uri_str = str(uri) + scope = self._resources.get(tenant_id, {}) + if uri_str not in scope: + raise ValueError(f"Unknown resource: {uri}") + del scope[uri_str] + if not scope and tenant_id in self._resources: + del self._resources[tenant_id] + + async def get_resource( + self, + uri: AnyUrl | str, + context: Context[LifespanContextT, RequestT], + *, + tenant_id: str | None = None, + ) -> Resource: """Get resource by URI, checking concrete resources first, then templates.""" uri_str = str(uri) logger.debug("Getting resource", extra={"uri": uri_str}) # First check concrete resources - if resource := self._resources.get(uri_str): + resource = self._resources.get(tenant_id, {}).get(uri_str) + if resource: return resource - # Then check templates - for template in self._templates.values(): + # Then check templates for this tenant scope + for template in self._templates.get(tenant_id, {}).values(): if params := template.matches(uri_str): try: return await template.create_resource(uri_str, params, context=context) @@ -99,12 +146,14 @@ async def get_resource(self, uri: AnyUrl | str, context: Context[LifespanContext raise ValueError(f"Unknown resource: {uri}") - def list_resources(self) -> list[Resource]: - """List all registered resources.""" - logger.debug("Listing resources", extra={"count": len(self._resources)}) - return list(self._resources.values()) - - def list_templates(self) -> list[ResourceTemplate]: - """List all registered templates.""" - logger.debug("Listing templates", extra={"count": len(self._templates)}) - return list(self._templates.values()) + def list_resources(self, *, tenant_id: str | None = None) -> list[Resource]: + """List all registered resources for a given tenant scope.""" + resources = list(self._resources.get(tenant_id, {}).values()) + logger.debug("Listing resources", extra={"count": len(resources)}) + return resources + + def list_templates(self, *, tenant_id: str | None = None) -> list[ResourceTemplate]: + """List all registered templates for a given tenant scope.""" + templates = list(self._templates.get(tenant_id, {}).values()) + logger.debug("Listing templates", extra={"count": len(templates)}) + return templates diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2a7a58117..18dfbb1cb 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -293,14 +293,14 @@ def run( async def _handle_list_tools( self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None ) -> ListToolsResult: - return ListToolsResult(tools=await self.list_tools()) + return ListToolsResult(tools=await self.list_tools(tenant_id=ctx.tenant_id)) async def _handle_call_tool( self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams ) -> CallToolResult: context = Context(request_context=ctx, mcp_server=self) try: - result = await self.call_tool(params.name, params.arguments or {}, context) + result = await self.call_tool(params.name, params.arguments or {}, context, tenant_id=ctx.tenant_id) except MCPError: raise except Exception as e: @@ -326,13 +326,13 @@ async def _handle_call_tool( async def _handle_list_resources( self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None ) -> ListResourcesResult: - return ListResourcesResult(resources=await self.list_resources()) + return ListResourcesResult(resources=await self.list_resources(tenant_id=ctx.tenant_id)) async def _handle_read_resource( self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams ) -> ReadResourceResult: context = Context(request_context=ctx, mcp_server=self) - results = await self.read_resource(params.uri, context) + results = await self.read_resource(params.uri, context, tenant_id=ctx.tenant_id) contents: list[TextResourceContents | BlobResourceContents] = [] for item in results: if isinstance(item.content, bytes): @@ -358,22 +358,24 @@ async def _handle_read_resource( async def _handle_list_resource_templates( self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None ) -> ListResourceTemplatesResult: - return ListResourceTemplatesResult(resource_templates=await self.list_resource_templates()) + return ListResourceTemplatesResult( + resource_templates=await self.list_resource_templates(tenant_id=ctx.tenant_id) + ) async def _handle_list_prompts( self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None ) -> ListPromptsResult: - return ListPromptsResult(prompts=await self.list_prompts()) + return ListPromptsResult(prompts=await self.list_prompts(tenant_id=ctx.tenant_id)) async def _handle_get_prompt( self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams ) -> GetPromptResult: context = Context(request_context=ctx, mcp_server=self) - return await self.get_prompt(params.name, params.arguments, context) + return await self.get_prompt(params.name, params.arguments, context, tenant_id=ctx.tenant_id) - async def list_tools(self) -> list[MCPTool]: + async def list_tools(self, *, tenant_id: str | None = None) -> list[MCPTool]: """List all available tools.""" - tools = self._tool_manager.list_tools() + tools = self._tool_manager.list_tools(tenant_id=tenant_id) return [ MCPTool( name=info.name, @@ -389,17 +391,22 @@ async def list_tools(self) -> list[MCPTool]: ] async def call_tool( - self, name: str, arguments: dict[str, Any], context: Context[LifespanResultT, Any] | None = None + self, + name: str, + arguments: dict[str, Any], + context: Context[LifespanResultT, Any] | None = None, + *, + tenant_id: str | None = None, ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" if context is None: context = Context(mcp_server=self) - return await self._tool_manager.call_tool(name, arguments, context, convert_result=True) + return await self._tool_manager.call_tool(name, arguments, context, convert_result=True, tenant_id=tenant_id) - async def list_resources(self) -> list[MCPResource]: + async def list_resources(self, *, tenant_id: str | None = None) -> list[MCPResource]: """List all available resources.""" - resources = self._resource_manager.list_resources() + resources = self._resource_manager.list_resources(tenant_id=tenant_id) return [ MCPResource( uri=resource.uri, @@ -414,8 +421,8 @@ async def list_resources(self) -> list[MCPResource]: for resource in resources ] - async def list_resource_templates(self) -> list[MCPResourceTemplate]: - templates = self._resource_manager.list_templates() + async def list_resource_templates(self, *, tenant_id: str | None = None) -> list[MCPResourceTemplate]: + templates = self._resource_manager.list_templates(tenant_id=tenant_id) return [ MCPResourceTemplate( uri_template=template.uri_template, @@ -431,13 +438,17 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: ] async def read_resource( - self, uri: AnyUrl | str, context: Context[LifespanResultT, Any] | None = None + self, + uri: AnyUrl | str, + context: Context[LifespanResultT, Any] | None = None, + *, + tenant_id: str | None = None, ) -> Iterable[ReadResourceContents]: """Read a resource by URI.""" if context is None: context = Context(mcp_server=self) try: - resource = await self._resource_manager.get_resource(uri, context) + resource = await self._resource_manager.get_resource(uri, context, tenant_id=tenant_id) except ValueError: raise ResourceError(f"Unknown resource: {uri}") @@ -459,6 +470,8 @@ def add_tool( icons: list[Icon] | None = None, meta: dict[str, Any] | None = None, structured_output: bool | None = None, + *, + tenant_id: str | None = None, ) -> None: """Add a tool to the server. @@ -477,6 +490,7 @@ def add_tool( - If None, auto-detects based on the function's return type annotation - If True, creates a structured tool (return type annotation permitting) - If False, unconditionally creates an unstructured tool + tenant_id: Optional tenant scope for the tool """ self._tool_manager.add_tool( fn, @@ -487,18 +501,20 @@ def add_tool( icons=icons, meta=meta, structured_output=structured_output, + tenant_id=tenant_id, ) - def remove_tool(self, name: str) -> None: + def remove_tool(self, name: str, *, tenant_id: str | None = None) -> None: """Remove a tool from the server by name. Args: name: The name of the tool to remove + tenant_id: Optional tenant scope for the tool Raises: ToolError: If the tool does not exist """ - self._tool_manager.remove_tool(name) + self._tool_manager.remove_tool(name, tenant_id=tenant_id) def tool( self, @@ -607,13 +623,14 @@ async def handler( return decorator - def add_resource(self, resource: Resource) -> None: + def add_resource(self, resource: Resource, *, tenant_id: str | None = None) -> None: """Add a resource to the server. Args: resource: A Resource instance to add + tenant_id: Optional tenant scope for the resource """ - self._resource_manager.add_resource(resource) + self._resource_manager.add_resource(resource, tenant_id=tenant_id) def resource( self, @@ -727,13 +744,14 @@ def decorator(fn: _CallableT) -> _CallableT: return decorator - def add_prompt(self, prompt: Prompt) -> None: + def add_prompt(self, prompt: Prompt, *, tenant_id: str | None = None) -> None: """Add a prompt to the server. Args: prompt: A Prompt instance to add + tenant_id: Optional tenant scope for the prompt """ - self._prompt_manager.add_prompt(prompt) + self._prompt_manager.add_prompt(prompt, tenant_id=tenant_id) def prompt( self, @@ -1060,9 +1078,9 @@ def streamable_http_app( debug=self.settings.debug, ) - async def list_prompts(self) -> list[MCPPrompt]: + async def list_prompts(self, *, tenant_id: str | None = None) -> list[MCPPrompt]: """List all available prompts.""" - prompts = self._prompt_manager.list_prompts() + prompts = self._prompt_manager.list_prompts(tenant_id=tenant_id) return [ MCPPrompt( name=prompt.name, @@ -1082,13 +1100,18 @@ async def list_prompts(self) -> list[MCPPrompt]: ] async def get_prompt( - self, name: str, arguments: dict[str, Any] | None = None, context: Context[LifespanResultT, Any] | None = None + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[LifespanResultT, Any] | None = None, + *, + tenant_id: str | None = None, ) -> GetPromptResult: """Get a prompt by name with arguments.""" if context is None: context = Context(mcp_server=self) try: - prompt = self._prompt_manager.get_prompt(name) + prompt = self._prompt_manager.get_prompt(name, tenant_id=tenant_id) if not prompt: raise ValueError(f"Unknown prompt: {name}") diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index 32ed54797..c1252696e 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -16,7 +16,19 @@ class ToolManager: - """Manages MCPServer tools.""" + """Manages MCPServer tools with optional tenant-scoped storage. + + Tools are stored in a nested dict: ``{tenant_id: {tool_name: Tool}}``. + This allows the same tool name to exist independently under different + tenants with O(1) lookups per tenant. When ``tenant_id`` is ``None`` + (the default), tools live in a global scope, preserving backward + compatibility with single-tenant usage. + + Note: This class is not thread-safe. It is designed to run within a + single-threaded async event loop, where all synchronous mutations + execute atomically. Do not share instances across OS threads without + external synchronization. + """ def __init__( self, @@ -24,22 +36,23 @@ def __init__( *, tools: list[Tool] | None = None, ): - self._tools: dict[str, Tool] = {} + self._tools: dict[str | None, dict[str, Tool]] = {} if tools is not None: + scope = self._tools.setdefault(None, {}) for tool in tools: - if warn_on_duplicate_tools and tool.name in self._tools: + if warn_on_duplicate_tools and tool.name in scope: logger.warning(f"Tool already exists: {tool.name}") - self._tools[tool.name] = tool + scope[tool.name] = tool self.warn_on_duplicate_tools = warn_on_duplicate_tools - def get_tool(self, name: str) -> Tool | None: - """Get tool by name.""" - return self._tools.get(name) + def get_tool(self, name: str, *, tenant_id: str | None = None) -> Tool | None: + """Get tool by name, optionally scoped to a tenant.""" + return self._tools.get(tenant_id, {}).get(name) - def list_tools(self) -> list[Tool]: - """List all registered tools.""" - return list(self._tools.values()) + def list_tools(self, *, tenant_id: str | None = None) -> list[Tool]: + """List all registered tools for a given tenant scope.""" + return list(self._tools.get(tenant_id, {}).values()) def add_tool( self, @@ -51,8 +64,10 @@ def add_tool( icons: list[Icon] | None = None, meta: dict[str, Any] | None = None, structured_output: bool | None = None, + *, + tenant_id: str | None = None, ) -> Tool: - """Add a tool to the server.""" + """Add a tool to the server, optionally scoped to a tenant.""" tool = Tool.from_function( fn, name=name, @@ -63,19 +78,23 @@ def add_tool( meta=meta, structured_output=structured_output, ) - existing = self._tools.get(tool.name) + scope = self._tools.setdefault(tenant_id, {}) + existing = scope.get(tool.name) if existing: if self.warn_on_duplicate_tools: logger.warning(f"Tool already exists: {tool.name}") return existing - self._tools[tool.name] = tool + scope[tool.name] = tool return tool - def remove_tool(self, name: str) -> None: - """Remove a tool by name.""" - if name not in self._tools: + def remove_tool(self, name: str, *, tenant_id: str | None = None) -> None: + """Remove a tool by name, optionally scoped to a tenant.""" + scope = self._tools.get(tenant_id, {}) + if name not in scope: raise ToolError(f"Unknown tool: {name}") - del self._tools[name] + del scope[name] + if not scope and tenant_id in self._tools: + del self._tools[tenant_id] async def call_tool( self, @@ -83,9 +102,11 @@ async def call_tool( arguments: dict[str, Any], context: Context[LifespanContextT, RequestT], convert_result: bool = False, + *, + tenant_id: str | None = None, ) -> Any: - """Call a tool by name with arguments.""" - tool = self.get_tool(name) + """Call a tool by name with arguments, optionally scoped to a tenant.""" + tool = self.get_tool(name, tenant_id=tenant_id) if not tool: raise ToolError(f"Unknown tool: {name}") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..7297f5255 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -76,6 +76,7 @@ class ServerSession( _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None _experimental_features: ExperimentalServerSessionFeatures | None = None + _tenant_id: str | None = None def __init__( self, @@ -108,6 +109,27 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification] def client_params(self) -> types.InitializeRequestParams | None: return self._client_params + @property + def tenant_id(self) -> str | None: + """Get the tenant_id for this session.""" + return self._tenant_id + + @tenant_id.setter + def tenant_id(self, value: str | None) -> None: + """Set the tenant_id for this session (set-once). + + Once a session is bound to a tenant, the tenant_id cannot be changed. + This prevents accidental tenant reassignment which could be a security issue. + + Raises: + ValueError: If tenant_id is already set to a different value. + """ + if self._tenant_id is not None and value != self._tenant_id: + raise ValueError( + f"Cannot change tenant_id from '{self._tenant_id}' to '{value}': session is already bound to a tenant" + ) + self._tenant_id = value + @property def experimental(self) -> ExperimentalServerSessionFeatures: """Experimental APIs for server→client task operations. diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab..cd2063e1b 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -21,6 +21,7 @@ StreamableHTTPServerTransport, ) from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._context import tenant_id_var from mcp.types import INVALID_REQUEST, ErrorData, JSONRPCError if TYPE_CHECKING: @@ -89,6 +90,7 @@ def __init__( # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + self._session_tenants: dict[str, str | None] = {} # The task group will be set during lifespan self._task_group = None @@ -135,6 +137,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._session_tenants.clear() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. @@ -194,6 +197,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S # Existing session case if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: + # Validate that the requesting tenant matches the session's tenant + session_tenant = self._session_tenants.get(request_mcp_session_id) + request_tenant = tenant_id_var.get() + if session_tenant is not None and request_tenant != session_tenant: + logger.warning("Tenant mismatch for session %s", request_mcp_session_id[:64]) + logger.debug( + "Tenant mismatch detail: session bound to '%s', request from '%s'", + session_tenant, + request_tenant, + ) + error_response = JSONRPCError( + jsonrpc="2.0", + id=None, + error=ErrorData(code=INVALID_REQUEST, message="Session not found"), + ) + response = Response( + content=error_response.model_dump_json(by_alias=True, exclude_unset=True), + status_code=HTTPStatus.NOT_FOUND, + media_type="application/json", + ) + await response(scope, receive, send) + return + transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") # Push back idle deadline on activity @@ -217,6 +243,7 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S assert http_transport.mcp_session_id is not None self._server_instances[http_transport.mcp_session_id] = http_transport + self._session_tenants[http_transport.mcp_session_id] = tenant_id_var.get() logger.info(f"Created new transport with session ID: {new_session_id}") # Define the server runner @@ -246,6 +273,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE assert http_transport.mcp_session_id is not None logger.info(f"Session {http_transport.mcp_session_id} idle timeout") self._server_instances.pop(http_transport.mcp_session_id, None) + self._session_tenants.pop(http_transport.mcp_session_id, None) await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") @@ -260,6 +288,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE f"{http_transport.mcp_session_id} from active instances." ) del self._server_instances[http_transport.mcp_session_id] + self._session_tenants.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py index bbcee2d02..3b4f5967a 100644 --- a/src/mcp/shared/_context.py +++ b/src/mcp/shared/_context.py @@ -1,5 +1,6 @@ """Request context for MCP handlers.""" +import contextvars from dataclasses import dataclass from typing import Any, Generic @@ -8,6 +9,11 @@ from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParamsMeta +# Transport-agnostic contextvar for tenant identification. +# Set by the transport layer (e.g., AuthContextMiddleware for HTTP+OAuth). +# Read by the core server to populate RequestContext.tenant_id. +tenant_id_var = contextvars.ContextVar[str | None]("tenant_id", default=None) + SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) @@ -17,8 +23,13 @@ class RequestContext(Generic[SessionT]): For request handlers, request_id is always populated. For notification handlers, request_id is None. + + The tenant_id field is used in multi-tenant server deployments to identify + which tenant the request belongs to. It is populated from session context + and enables tenant-specific request handling and isolation. """ session: SessionT request_id: RequestId | None = None meta: RequestParamsMeta | None = None + tenant_id: str | None = None diff --git a/tests/client/test_config.py b/tests/client/test_config.py index d1a0576ff..94cea4f22 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -44,7 +44,7 @@ def test_command_execution(mock_config_path: Path): test_args = [command] + args + ["--help"] - result = subprocess.run(test_args, capture_output=True, text=True, timeout=20, check=False) + result = subprocess.run(test_args, capture_output=True, text=True, timeout=60, check=False) assert result.returncode == 0 assert "usage" in result.stdout.lower() diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index c8fc41fd5..e68e6e0f2 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -1,5 +1,7 @@ """Tests for InMemoryTransport.""" +import sys + import pytest from mcp import Client, types @@ -81,6 +83,8 @@ async def test_list_tools(mcpserver_server: MCPServer): assert "greet" in tool_names +@pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning" if sys.platform == "win32" else "default") async def test_call_tool(mcpserver_server: MCPServer): """Test calling a tool through the transport.""" async with Client(mcpserver_server) as client: diff --git a/tests/server/auth/middleware/test_auth_context.py b/tests/server/auth/middleware/test_auth_context.py index 66481bcf7..74736ef5b 100644 --- a/tests/server/auth/middleware/test_auth_context.py +++ b/tests/server/auth/middleware/test_auth_context.py @@ -9,9 +9,11 @@ AuthContextMiddleware, auth_context_var, get_access_token, + get_tenant_id, ) from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser from mcp.server.auth.provider import AccessToken +from mcp.shared._context import tenant_id_var class MockApp: @@ -117,3 +119,129 @@ async def send(message: Message) -> None: # pragma: no cover # Verify context is still empty after middleware assert auth_context_var.get() is None assert get_access_token() is None + + +@pytest.fixture +def access_token_with_tenant() -> AccessToken: + """Create an access token with a tenant_id.""" + return AccessToken( + token="tenant_token", + client_id="test_client", + scopes=["read", "write"], + expires_at=int(time.time()) + 3600, + tenant_id="tenant-abc", + ) + + +def test_get_tenant_id_without_auth_context(): + """Test get_tenant_id returns None when no auth context exists.""" + assert auth_context_var.get() is None + assert get_tenant_id() is None + + +@pytest.mark.anyio +async def test_get_tenant_id_with_tenant(access_token_with_tenant: AccessToken): + """Test get_tenant_id returns tenant_id when auth context has a tenant.""" + user = AuthenticatedUser(access_token_with_tenant) + scope: Scope = {"type": "http", "user": user} + + tenant_id_during_call: str | None = None + + class TenantCheckApp: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + nonlocal tenant_id_during_call + tenant_id_during_call = get_tenant_id() + + middleware = AuthContextMiddleware(TenantCheckApp()) + + async def receive() -> Message: # pragma: no cover + return {"type": "http.request"} + + async def send(message: Message) -> None: # pragma: no cover + pass + + await middleware(scope, receive, send) + + assert tenant_id_during_call == "tenant-abc" + # Verify context is reset after middleware + assert get_tenant_id() is None + + +@pytest.mark.anyio +async def test_middleware_sets_tenant_id_var(access_token_with_tenant: AccessToken): + """Test AuthContextMiddleware populates the transport-agnostic tenant_id_var.""" + user = AuthenticatedUser(access_token_with_tenant) + scope: Scope = {"type": "http", "user": user} + + observed_tenant_id: str | None = None + + class CheckApp: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + nonlocal observed_tenant_id + observed_tenant_id = tenant_id_var.get() + + middleware = AuthContextMiddleware(CheckApp()) + + async def receive() -> Message: # pragma: no cover + return {"type": "http.request"} + + async def send(message: Message) -> None: # pragma: no cover + pass + + await middleware(scope, receive, send) + + assert observed_tenant_id == "tenant-abc" + # Verify contextvar is reset after middleware + assert tenant_id_var.get() is None + + +@pytest.mark.anyio +async def test_middleware_sets_tenant_id_var_none_without_tenant(valid_access_token: AccessToken): + """Test AuthContextMiddleware sets tenant_id_var to None when token has no tenant.""" + user = AuthenticatedUser(valid_access_token) + scope: Scope = {"type": "http", "user": user} + + observed_tenant_id: str | None = "sentinel" + + class CheckApp: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + nonlocal observed_tenant_id + observed_tenant_id = tenant_id_var.get() + + middleware = AuthContextMiddleware(CheckApp()) + + async def receive() -> Message: # pragma: no cover + return {"type": "http.request"} + + async def send(message: Message) -> None: # pragma: no cover + pass + + await middleware(scope, receive, send) + + assert observed_tenant_id is None + + +@pytest.mark.anyio +async def test_get_tenant_id_without_tenant(valid_access_token: AccessToken): + """Test get_tenant_id returns None when auth context has no tenant.""" + tenant_id_during_call: str | None = "not-none" + + class TenantCheckApp: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + nonlocal tenant_id_during_call + tenant_id_during_call = get_tenant_id() + + middleware = AuthContextMiddleware(TenantCheckApp()) + + user = AuthenticatedUser(valid_access_token) + scope: Scope = {"type": "http", "user": user} + + async def receive() -> Message: # pragma: no cover + return {"type": "http.request"} + + async def send(message: Message) -> None: # pragma: no cover + pass + + await middleware(scope, receive, send) + + assert tenant_id_during_call is None diff --git a/tests/server/auth/test_multi_tenancy_tokens.py b/tests/server/auth/test_multi_tenancy_tokens.py new file mode 100644 index 000000000..a3764f316 --- /dev/null +++ b/tests/server/auth/test_multi_tenancy_tokens.py @@ -0,0 +1,170 @@ +"""Tests for multi-tenancy support in authentication token models.""" + +import pytest +from pydantic import AnyUrl + +from mcp.server.auth.provider import AccessToken, AuthorizationCode, RefreshToken + + +def test_authorization_code_with_tenant_id(): + """Test AuthorizationCode creation with tenant_id.""" + code = AuthorizationCode( + code="test_code", + scopes=["read", "write"], + expires_at=1234567890.0, + client_id="test_client", + code_challenge="challenge123", + redirect_uri=AnyUrl("http://localhost:8000/callback"), + redirect_uri_provided_explicitly=True, + tenant_id="tenant-abc", + ) + assert code.tenant_id == "tenant-abc" + assert code.code == "test_code" + assert code.scopes == ["read", "write"] + + +def test_authorization_code_without_tenant_id(): + """Test AuthorizationCode backward compatibility without tenant_id.""" + code = AuthorizationCode( + code="test_code", + scopes=["read"], + expires_at=1234567890.0, + client_id="test_client", + code_challenge="challenge123", + redirect_uri=AnyUrl("http://localhost:8000/callback"), + redirect_uri_provided_explicitly=False, + ) + assert code.tenant_id is None + + +def test_authorization_code_serialization_with_tenant_id(): + """Test AuthorizationCode serialization includes tenant_id.""" + code = AuthorizationCode( + code="test_code", + scopes=["read"], + expires_at=1234567890.0, + client_id="test_client", + code_challenge="challenge123", + redirect_uri=AnyUrl("http://localhost:8000/callback"), + redirect_uri_provided_explicitly=True, + tenant_id="tenant-xyz", + ) + data = code.model_dump() + assert data["tenant_id"] == "tenant-xyz" + + # Verify deserialization + restored = AuthorizationCode.model_validate(data) + assert restored.tenant_id == "tenant-xyz" + + +def test_refresh_token_with_tenant_id(): + """Test RefreshToken creation with tenant_id.""" + token = RefreshToken( + token="refresh_token_123", + client_id="test_client", + scopes=["read", "write"], + tenant_id="tenant-abc", + ) + assert token.tenant_id == "tenant-abc" + assert token.token == "refresh_token_123" + + +def test_refresh_token_without_tenant_id(): + """Test RefreshToken backward compatibility without tenant_id.""" + token = RefreshToken( + token="refresh_token_123", + client_id="test_client", + scopes=["read"], + ) + assert token.tenant_id is None + + +def test_refresh_token_serialization_with_tenant_id(): + """Test RefreshToken serialization includes tenant_id.""" + token = RefreshToken( + token="refresh_token_123", + client_id="test_client", + scopes=["read"], + expires_at=1234567890, + tenant_id="tenant-xyz", + ) + data = token.model_dump() + assert data["tenant_id"] == "tenant-xyz" + + # Verify deserialization + restored = RefreshToken.model_validate(data) + assert restored.tenant_id == "tenant-xyz" + + +def test_access_token_with_tenant_id(): + """Test AccessToken creation with tenant_id.""" + token = AccessToken( + token="access_token_123", + client_id="test_client", + scopes=["read", "write"], + tenant_id="tenant-abc", + ) + assert token.tenant_id == "tenant-abc" + assert token.token == "access_token_123" + + +def test_access_token_without_tenant_id(): + """Test AccessToken backward compatibility without tenant_id.""" + token = AccessToken( + token="access_token_123", + client_id="test_client", + scopes=["read"], + ) + assert token.tenant_id is None + + +def test_access_token_serialization_with_tenant_id(): + """Test AccessToken serialization includes tenant_id.""" + token = AccessToken( + token="access_token_123", + client_id="test_client", + scopes=["read"], + expires_at=1234567890, + resource="https://api.example.com", + tenant_id="tenant-xyz", + ) + data = token.model_dump() + assert data["tenant_id"] == "tenant-xyz" + + # Verify deserialization + restored = AccessToken.model_validate(data) + assert restored.tenant_id == "tenant-xyz" + + +def test_access_token_with_resource_and_tenant_id(): + """Test AccessToken with both resource (RFC 8707) and tenant_id.""" + token = AccessToken( + token="access_token_123", + client_id="test_client", + scopes=["read"], + resource="https://api.example.com", + tenant_id="tenant-abc", + ) + assert token.resource == "https://api.example.com" + assert token.tenant_id == "tenant-abc" + + +@pytest.mark.parametrize( + "tenant_id", + [ + "tenant-123", + "org_abc_def", + "a" * 100, # Long tenant ID + "tenant-with-dashes", + "tenant.with.dots", + ], +) +def test_access_token_various_tenant_id_formats(tenant_id: str): + """Test AccessToken accepts various tenant_id formats.""" + token = AccessToken( + token="access_token_123", + client_id="test_client", + scopes=["read"], + tenant_id=tenant_id, + ) + assert token.tenant_id == tenant_id diff --git a/tests/server/mcpserver/conftest.py b/tests/server/mcpserver/conftest.py new file mode 100644 index 000000000..993f91250 --- /dev/null +++ b/tests/server/mcpserver/conftest.py @@ -0,0 +1,22 @@ +from collections.abc import Callable +from typing import Any + +import pytest + +from mcp.server.mcpserver.context import Context + +MakeContext = Callable[..., Context[Any, Any]] + + +@pytest.fixture +def make_context() -> MakeContext: + """Factory fixture for creating Context instances in tests. + + Centralizes Context construction so that tests don't break if the + Context.__init__ signature changes in later iterations. + """ + + def _make(**kwargs: Any) -> Context[Any, Any]: + return Context(**kwargs) + + return _make diff --git a/tests/server/mcpserver/resources/test_resource_manager.py b/tests/server/mcpserver/resources/test_resource_manager.py index 724b57997..f663f9c94 100644 --- a/tests/server/mcpserver/resources/test_resource_manager.py +++ b/tests/server/mcpserver/resources/test_resource_manager.py @@ -4,8 +4,8 @@ import pytest from pydantic import AnyUrl -from mcp.server.mcpserver import Context -from mcp.server.mcpserver.resources import FileResource, FunctionResource, ResourceManager, ResourceTemplate +from mcp.server.mcpserver.resources import FileResource, FunctionResource, ResourceManager +from tests.server.mcpserver.conftest import MakeContext @pytest.fixture @@ -78,7 +78,7 @@ def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pyte assert "Resource already exists" not in caplog.text @pytest.mark.anyio - async def test_get_resource(self, temp_file: Path): + async def test_get_resource(self, temp_file: Path, make_context: MakeContext): """Test getting a resource by URI.""" manager = ResourceManager() resource = FileResource( @@ -87,35 +87,30 @@ async def test_get_resource(self, temp_file: Path): path=temp_file, ) manager.add_resource(resource) - retrieved = await manager.get_resource(resource.uri, Context()) + retrieved = await manager.get_resource(resource.uri, make_context()) assert retrieved == resource @pytest.mark.anyio - async def test_get_resource_from_template(self): + async def test_get_resource_from_template(self, make_context: MakeContext): """Test getting a resource through a template.""" manager = ResourceManager() def greet(name: str) -> str: return f"Hello, {name}!" - template = ResourceTemplate.from_function( - fn=greet, - uri_template="greet://{name}", - name="greeter", - ) - manager._templates[template.uri_template] = template + manager.add_template(fn=greet, uri_template="greet://{name}", name="greeter") - resource = await manager.get_resource(AnyUrl("greet://world"), Context()) + resource = await manager.get_resource(AnyUrl("greet://world"), make_context()) assert isinstance(resource, FunctionResource) content = await resource.read() assert content == "Hello, world!" @pytest.mark.anyio - async def test_get_unknown_resource(self): + async def test_get_unknown_resource(self, make_context: MakeContext): """Test getting a non-existent resource.""" manager = ResourceManager() with pytest.raises(ValueError, match="Unknown resource"): - await manager.get_resource(AnyUrl("unknown://test"), Context()) + await manager.get_resource(AnyUrl("unknown://test"), make_context()) def test_list_resources(self, temp_file: Path): """Test listing all resources.""" @@ -175,3 +170,37 @@ def get_item(id: str) -> str: # pragma: no cover ) assert template.meta is None + + def test_add_duplicate_template(self): + """Test adding the same template twice returns the existing one.""" + manager = ResourceManager() + + def get_item(id: str) -> str: # pragma: no cover + return f"Item {id}" + + first = manager.add_template(fn=get_item, uri_template="resource://items/{id}") + second = manager.add_template(fn=get_item, uri_template="resource://items/{id}") + assert first is second + assert len(manager.list_templates()) == 1 + + def test_warn_on_duplicate_template(self, caplog: pytest.LogCaptureFixture): + """Test warning on duplicate template.""" + manager = ResourceManager() + + def get_item(id: str) -> str: # pragma: no cover + return f"Item {id}" + + manager.add_template(fn=get_item, uri_template="resource://items/{id}") + manager.add_template(fn=get_item, uri_template="resource://items/{id}") + assert "Resource template already exists" in caplog.text + + def test_disable_warn_on_duplicate_template(self, caplog: pytest.LogCaptureFixture): + """Test disabling warning on duplicate template.""" + manager = ResourceManager(warn_on_duplicate_resources=False) + + def get_item(id: str) -> str: # pragma: no cover + return f"Item {id}" + + manager.add_template(fn=get_item, uri_template="resource://items/{id}") + manager.add_template(fn=get_item, uri_template="resource://items/{id}") + assert "Resource template already exists" not in caplog.text diff --git a/tests/server/mcpserver/test_multi_tenancy_e2e.py b/tests/server/mcpserver/test_multi_tenancy_e2e.py new file mode 100644 index 000000000..2d7795b4d --- /dev/null +++ b/tests/server/mcpserver/test_multi_tenancy_e2e.py @@ -0,0 +1,336 @@ +"""End-to-end tests for multi-tenant isolation. + +These tests exercise the full tenant isolation stack using the in-memory +transport and the high-level ``Client`` class. They verify that: + +1. Tools, resources, and prompts registered under one tenant are invisible + to other tenants and to the global (None) scope. +2. ``Context.tenant_id`` is correctly populated inside tool handlers. +3. Backward compatibility is preserved — everything works without tenant_id. +""" + +from __future__ import annotations + +import anyio +import pytest + +from mcp import Client +from mcp.client.session import ClientSession +from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver.context import Context +from mcp.server.mcpserver.prompts.base import Prompt +from mcp.server.mcpserver.resources.types import FunctionResource +from mcp.shared._context import tenant_id_var +from mcp.shared.memory import create_client_server_memory_streams +from mcp.types import TextContent + +pytestmark = pytest.mark.anyio + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_multi_tenant_server() -> MCPServer: + """Build an MCPServer with tenant-scoped tools, resources, and prompts.""" + server = MCPServer("multi-tenant-test") + + # Tenant-A tools / resources / prompts + def tool_a(x: int) -> str: + return f"tenant-a:{x}" + + server.add_tool(tool_a, name="compute", tenant_id="tenant-a") + server.add_resource( + FunctionResource(uri="data://info", name="info-a", fn=lambda: "secret-a"), + tenant_id="tenant-a", + ) + + async def prompt_a() -> str: + return "Hello from tenant-a" + + server.add_prompt(Prompt.from_function(prompt_a, name="greet"), tenant_id="tenant-a") + + # Tenant-B tools / resources / prompts (same names, different data) + def tool_b(x: int) -> str: + return f"tenant-b:{x}" + + server.add_tool(tool_b, name="compute", tenant_id="tenant-b") + server.add_resource( + FunctionResource(uri="data://info", name="info-b", fn=lambda: "secret-b"), + tenant_id="tenant-b", + ) + + async def prompt_b() -> str: + return "Hello from tenant-b" + + server.add_prompt(Prompt.from_function(prompt_b, name="greet"), tenant_id="tenant-b") + + return server + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +async def test_tenant_a_sees_only_own_tools(): + """Tenant-A's client lists only tenant-A's tools.""" + server = _build_multi_tenant_server() + actual = server._lowlevel_server # type: ignore[reportPrivateUsage] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + + async def run_server() -> None: + token = tenant_id_var.set("tenant-a") + try: + await actual.run( + server_read, server_write, actual.create_initialization_options(), raise_exceptions=True + ) + finally: + tenant_id_var.reset(token) + + tg.start_soon(run_server) + + async with ClientSession(client_read, client_write) as session: + await session.initialize() + tools = await session.list_tools() + assert len(tools.tools) == 1 + assert tools.tools[0].name == "compute" + + tg.cancel_scope.cancel() + + +async def test_tenant_b_sees_only_own_tools(): + """Tenant-B's client lists only tenant-B's tools.""" + server = _build_multi_tenant_server() + actual = server._lowlevel_server # type: ignore[reportPrivateUsage] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + + async def run_server() -> None: + token = tenant_id_var.set("tenant-b") + try: + await actual.run( + server_read, server_write, actual.create_initialization_options(), raise_exceptions=True + ) + finally: + tenant_id_var.reset(token) + + tg.start_soon(run_server) + + async with ClientSession(client_read, client_write) as session: + await session.initialize() + tools = await session.list_tools() + assert len(tools.tools) == 1 + assert tools.tools[0].name == "compute" + + tg.cancel_scope.cancel() + + +async def test_global_scope_sees_nothing_when_all_tenant_scoped(): + """With no tenant context, no tools/resources/prompts are visible.""" + server = _build_multi_tenant_server() + actual = server._lowlevel_server # type: ignore[reportPrivateUsage] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: actual.run( + server_read, server_write, actual.create_initialization_options(), raise_exceptions=True + ) + ) + + async with ClientSession(client_read, client_write) as session: + await session.initialize() + + tools = await session.list_tools() + assert len(tools.tools) == 0 + + resources = await session.list_resources() + assert len(resources.resources) == 0 + + prompts = await session.list_prompts() + assert len(prompts.prompts) == 0 + + tg.cancel_scope.cancel() + + +async def test_tenant_tool_returns_correct_result(): + """Calling a tenant-scoped tool returns the correct tenant's result.""" + server = _build_multi_tenant_server() + actual = server._lowlevel_server # type: ignore[reportPrivateUsage] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + + async def run_server() -> None: + token = tenant_id_var.set("tenant-a") + try: + await actual.run( + server_read, server_write, actual.create_initialization_options(), raise_exceptions=True + ) + finally: + tenant_id_var.reset(token) + + tg.start_soon(run_server) + + async with ClientSession(client_read, client_write) as session: + await session.initialize() + result = await session.call_tool("compute", {"x": 42}) + texts = [c.text for c in result.content if isinstance(c, TextContent)] + assert any("tenant-a:42" in t for t in texts) + + tg.cancel_scope.cancel() + + +async def test_tenant_resource_isolation(): + """Tenant-A can read its resource; tenant-B reads a different value.""" + server = _build_multi_tenant_server() + actual = server._lowlevel_server # type: ignore[reportPrivateUsage] + + for tenant, expected_name in [("tenant-a", "info-a"), ("tenant-b", "info-b")]: + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + + async def run_server(tid: str = tenant) -> None: + token = tenant_id_var.set(tid) + try: + await actual.run( + server_read, server_write, actual.create_initialization_options(), raise_exceptions=True + ) + finally: + tenant_id_var.reset(token) + + tg.start_soon(run_server) + + async with ClientSession(client_read, client_write) as session: + await session.initialize() + resources = await session.list_resources() + assert len(resources.resources) == 1 + assert resources.resources[0].name == expected_name + + tg.cancel_scope.cancel() + + +async def test_tenant_prompt_isolation(): + """Each tenant sees only its own prompts.""" + server = _build_multi_tenant_server() + actual = server._lowlevel_server # type: ignore[reportPrivateUsage] + + for tenant in ["tenant-a", "tenant-b"]: + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + + async def run_server(tid: str = tenant) -> None: + token = tenant_id_var.set(tid) + try: + await actual.run( + server_read, server_write, actual.create_initialization_options(), raise_exceptions=True + ) + finally: + tenant_id_var.reset(token) + + tg.start_soon(run_server) + + async with ClientSession(client_read, client_write) as session: + await session.initialize() + prompts = await session.list_prompts() + assert len(prompts.prompts) == 1 + assert prompts.prompts[0].name == "greet" + + result = await session.get_prompt("greet") + text = result.messages[0].content.text # type: ignore[union-attr] + assert tenant in text + + tg.cancel_scope.cancel() + + +async def test_context_tenant_id_available_in_tool(): + """The ``Context.tenant_id`` property is populated inside a tool handler.""" + captured_tenant: list[str | None] = [] + + server = MCPServer("ctx-test") + + def check_tenant(ctx: Context) -> str: + captured_tenant.append(ctx.tenant_id) + return "ok" + + # Register under the tenant scope that will be active during the test + server.add_tool(check_tenant, name="check_tenant", tenant_id="my-tenant") + actual = server._lowlevel_server # type: ignore[reportPrivateUsage] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + + async def run_server() -> None: + token = tenant_id_var.set("my-tenant") + try: + await actual.run( + server_read, server_write, actual.create_initialization_options(), raise_exceptions=True + ) + finally: + tenant_id_var.reset(token) + + tg.start_soon(run_server) + + async with ClientSession(client_read, client_write) as session: + await session.initialize() + await session.call_tool("check_tenant", {}) + + tg.cancel_scope.cancel() + + assert captured_tenant == ["my-tenant"] + + +async def test_backward_compat_no_tenant(): + """Without tenant_id set, tools/resources/prompts in global scope work normally.""" + server = MCPServer("compat-test") + + @server.tool() + def hello(name: str) -> str: + return f"Hi {name}" + + @server.resource("test://data") + def data() -> str: + return "some data" + + @server.prompt() + def ask() -> str: + return "Please answer" + + async with Client(server) as client: + tools = await client.list_tools() + assert len(tools.tools) == 1 + + result = await client.call_tool("hello", {"name": "World"}) + assert any("Hi World" in c.text for c in result.content if isinstance(c, TextContent)) + + resources = await client.list_resources() + assert len(resources.resources) == 1 + + prompts = await client.list_prompts() + assert len(prompts.prompts) == 1 diff --git a/tests/server/mcpserver/test_multi_tenancy_managers.py b/tests/server/mcpserver/test_multi_tenancy_managers.py new file mode 100644 index 000000000..16e8ffe2e --- /dev/null +++ b/tests/server/mcpserver/test_multi_tenancy_managers.py @@ -0,0 +1,373 @@ +"""Tests for tenant-scoped storage in ToolManager, ResourceManager, and PromptManager.""" + +import pytest + +from mcp.server.mcpserver.exceptions import ToolError +from mcp.server.mcpserver.prompts.base import Prompt +from mcp.server.mcpserver.prompts.manager import PromptManager +from mcp.server.mcpserver.resources.resource_manager import ResourceManager +from mcp.server.mcpserver.resources.types import FunctionResource +from mcp.server.mcpserver.tools import ToolManager +from tests.server.mcpserver.conftest import MakeContext + +# --- ToolManager --- + + +def test_add_tool_with_tenant_id(): + """Tools added under different tenants are isolated.""" + manager = ToolManager() + + def tool_a() -> str: # pragma: no cover + return "a" + + def tool_b() -> str: # pragma: no cover + return "b" + + manager.add_tool(tool_a, name="shared_name", tenant_id="tenant-a") + manager.add_tool(tool_b, name="shared_name", tenant_id="tenant-b") + + assert manager.get_tool("shared_name", tenant_id="tenant-a") is not None + assert manager.get_tool("shared_name", tenant_id="tenant-b") is not None + # Different tool objects despite same name + assert manager.get_tool("shared_name", tenant_id="tenant-a") is not manager.get_tool( + "shared_name", tenant_id="tenant-b" + ) + + +def test_list_tools_filtered_by_tenant(): + """list_tools only returns tools for the requested tenant.""" + manager = ToolManager() + + def fa() -> str: # pragma: no cover + return "a" + + def fb() -> str: # pragma: no cover + return "b" + + def fc() -> str: # pragma: no cover + return "c" + + manager.add_tool(fa, tenant_id="tenant-a") + manager.add_tool(fb, tenant_id="tenant-b") + manager.add_tool(fc) # global (None tenant) + + assert len(manager.list_tools(tenant_id="tenant-a")) == 1 + assert len(manager.list_tools(tenant_id="tenant-b")) == 1 + assert len(manager.list_tools()) == 1 # global only + + +def test_get_tool_wrong_tenant_returns_none(): + """A tool registered under tenant-a is not visible to tenant-b.""" + manager = ToolManager() + + def my_tool() -> str: # pragma: no cover + return "x" + + manager.add_tool(my_tool, tenant_id="tenant-a") + + assert manager.get_tool("my_tool", tenant_id="tenant-a") is not None + assert manager.get_tool("my_tool", tenant_id="tenant-b") is None + assert manager.get_tool("my_tool") is None # global scope + + +def test_remove_tool_with_tenant(): + """remove_tool respects tenant scope.""" + manager = ToolManager() + + def my_tool() -> str: # pragma: no cover + return "x" + + manager.add_tool(my_tool, tenant_id="tenant-a") + manager.add_tool(my_tool, name="my_tool", tenant_id="tenant-b") + + manager.remove_tool("my_tool", tenant_id="tenant-a") + + assert manager.get_tool("my_tool", tenant_id="tenant-a") is None + assert manager.get_tool("my_tool", tenant_id="tenant-b") is not None + # Empty tenant scope is cleaned up + assert "tenant-a" not in manager._tools + + +def test_remove_tool_wrong_tenant_raises(): + """Removing a tool under the wrong tenant raises ToolError.""" + manager = ToolManager() + + def my_tool() -> str: # pragma: no cover + return "x" + + manager.add_tool(my_tool, tenant_id="tenant-a") + + with pytest.raises(ToolError): + manager.remove_tool("my_tool", tenant_id="tenant-b") + + +@pytest.mark.anyio +async def test_call_tool_with_tenant(make_context: MakeContext): + """call_tool respects tenant scope.""" + manager = ToolManager() + + def tool_a() -> str: + return "result-a" + + def tool_b() -> str: + return "result-b" + + manager.add_tool(tool_a, name="do_work", tenant_id="tenant-a") + manager.add_tool(tool_b, name="do_work", tenant_id="tenant-b") + + result_a = await manager.call_tool("do_work", {}, make_context(), tenant_id="tenant-a") + result_b = await manager.call_tool("do_work", {}, make_context(), tenant_id="tenant-b") + + assert result_a == "result-a" + assert result_b == "result-b" + + +@pytest.mark.anyio +async def test_call_tool_wrong_tenant_raises(make_context: MakeContext): + """Calling a tool under the wrong tenant raises ToolError.""" + manager = ToolManager() + + def my_tool() -> str: # pragma: no cover + return "x" + + manager.add_tool(my_tool, tenant_id="tenant-a") + + with pytest.raises(ToolError): + await manager.call_tool("my_tool", {}, make_context(), tenant_id="tenant-b") + + +# --- ResourceManager --- + + +def _make_resource(uri: str, name: str) -> FunctionResource: + """Helper to create a concrete resource.""" + return FunctionResource(uri=uri, name=name, fn=lambda: name) + + +def test_add_resource_with_tenant(): + """Resources added under different tenants are isolated.""" + manager = ResourceManager() + + resource_a = _make_resource("file:///data", "data-a") + resource_b = _make_resource("file:///data", "data-b") + + added_a = manager.add_resource(resource_a, tenant_id="tenant-a") + added_b = manager.add_resource(resource_b, tenant_id="tenant-b") + + assert added_a.name == "data-a" + assert added_b.name == "data-b" + + +def test_list_resources_filtered_by_tenant(): + """list_resources only returns resources for the requested tenant.""" + manager = ResourceManager() + + manager.add_resource(_make_resource("file:///a", "a"), tenant_id="tenant-a") + manager.add_resource(_make_resource("file:///b", "b"), tenant_id="tenant-b") + manager.add_resource(_make_resource("file:///g", "global")) + + assert len(manager.list_resources(tenant_id="tenant-a")) == 1 + assert len(manager.list_resources(tenant_id="tenant-b")) == 1 + assert len(manager.list_resources()) == 1 + + +def test_add_template_with_tenant(): + """Templates added under different tenants are isolated.""" + manager = ResourceManager() + + def greet_a(name: str) -> str: # pragma: no cover + return f"Hello from A, {name}!" + + def greet_b(name: str) -> str: # pragma: no cover + return f"Hello from B, {name}!" + + manager.add_template(greet_a, uri_template="greet://{name}", tenant_id="tenant-a") + manager.add_template(greet_b, uri_template="greet://{name}", tenant_id="tenant-b") + + assert len(manager.list_templates(tenant_id="tenant-a")) == 1 + assert len(manager.list_templates(tenant_id="tenant-b")) == 1 + assert len(manager.list_templates()) == 0 # no global templates + + +@pytest.mark.anyio +async def test_get_resource_respects_tenant(make_context: MakeContext): + """get_resource only finds resources in the correct tenant scope.""" + manager = ResourceManager() + + resource = _make_resource("file:///secret", "secret") + manager.add_resource(resource, tenant_id="tenant-a") + + # Tenant A can access + found = await manager.get_resource("file:///secret", make_context(), tenant_id="tenant-a") + assert found.name == "secret" + + # Tenant B cannot + with pytest.raises(ValueError, match="Unknown resource"): + await manager.get_resource("file:///secret", make_context(), tenant_id="tenant-b") + + # Global scope cannot + with pytest.raises(ValueError, match="Unknown resource"): + await manager.get_resource("file:///secret", make_context()) + + +@pytest.mark.anyio +async def test_get_resource_from_template_respects_tenant(make_context: MakeContext): + """Template-based resource creation respects tenant scope.""" + manager = ResourceManager() + + def greet(name: str) -> str: + return f"Hello, {name}!" + + manager.add_template(greet, uri_template="greet://{name}", tenant_id="tenant-a") + + # Tenant A can resolve + resource = await manager.get_resource("greet://world", make_context(), tenant_id="tenant-a") + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello, world!" + + # Tenant B cannot + with pytest.raises(ValueError, match="Unknown resource"): + await manager.get_resource("greet://world", make_context(), tenant_id="tenant-b") + + +def test_remove_resource_with_tenant(): + """remove_resource respects tenant scope.""" + manager = ResourceManager() + + manager.add_resource(_make_resource("file:///data", "data"), tenant_id="tenant-a") + manager.add_resource(_make_resource("file:///data", "data"), tenant_id="tenant-b") + + manager.remove_resource("file:///data", tenant_id="tenant-a") + + assert len(manager.list_resources(tenant_id="tenant-a")) == 0 + assert len(manager.list_resources(tenant_id="tenant-b")) == 1 + # Empty tenant scope is cleaned up + assert "tenant-a" not in manager._resources + + +def test_remove_resource_partial_tenant_scope(): + """Removing one resource leaves the tenant scope intact when others remain.""" + manager = ResourceManager() + + manager.add_resource(_make_resource("file:///a", "a"), tenant_id="tenant-a") + manager.add_resource(_make_resource("file:///b", "b"), tenant_id="tenant-a") + + manager.remove_resource("file:///a", tenant_id="tenant-a") + + assert len(manager.list_resources(tenant_id="tenant-a")) == 1 + assert "tenant-a" in manager._resources + + +def test_remove_resource_wrong_tenant_raises(): + """Removing a resource under the wrong tenant raises ValueError.""" + manager = ResourceManager() + manager.add_resource(_make_resource("file:///data", "data"), tenant_id="tenant-a") + + with pytest.raises(ValueError, match="Unknown resource"): + manager.remove_resource("file:///data", tenant_id="tenant-b") + + +# --- PromptManager --- + + +def _make_prompt(name: str, text: str) -> Prompt: + """Helper to create a simple prompt.""" + + async def fn() -> str: # pragma: no cover + return text + + return Prompt.from_function(fn, name=name) + + +def test_add_prompt_with_tenant(): + """Prompts added under different tenants are isolated.""" + manager = PromptManager() + + prompt_a = _make_prompt("greet", "Hello from A") + prompt_b = _make_prompt("greet", "Hello from B") + + manager.add_prompt(prompt_a, tenant_id="tenant-a") + manager.add_prompt(prompt_b, tenant_id="tenant-b") + + assert manager.get_prompt("greet", tenant_id="tenant-a") is prompt_a + assert manager.get_prompt("greet", tenant_id="tenant-b") is prompt_b + + +def test_list_prompts_filtered_by_tenant(): + """list_prompts only returns prompts for the requested tenant.""" + manager = PromptManager() + + manager.add_prompt(_make_prompt("a", "A"), tenant_id="tenant-a") + manager.add_prompt(_make_prompt("b", "B"), tenant_id="tenant-b") + manager.add_prompt(_make_prompt("g", "Global")) + + assert len(manager.list_prompts(tenant_id="tenant-a")) == 1 + assert len(manager.list_prompts(tenant_id="tenant-b")) == 1 + assert len(manager.list_prompts()) == 1 + + +def test_get_prompt_wrong_tenant_returns_none(): + """A prompt registered under tenant-a is not visible to tenant-b.""" + manager = PromptManager() + manager.add_prompt(_make_prompt("secret", "x"), tenant_id="tenant-a") + + assert manager.get_prompt("secret", tenant_id="tenant-a") is not None + assert manager.get_prompt("secret", tenant_id="tenant-b") is None + assert manager.get_prompt("secret") is None + + +@pytest.mark.anyio +async def test_render_prompt_respects_tenant(make_context: MakeContext): + """render_prompt only finds prompts in the correct tenant scope.""" + manager = PromptManager() + + async def greet() -> str: + return "Hello from tenant-a" + + manager.add_prompt(Prompt.from_function(greet, name="greet"), tenant_id="tenant-a") + + # Tenant A can render + messages = await manager.render_prompt("greet", None, make_context(), tenant_id="tenant-a") + assert len(messages) > 0 + + # Tenant B cannot + with pytest.raises(ValueError, match="Unknown prompt"): + await manager.render_prompt("greet", None, make_context(), tenant_id="tenant-b") + + +def test_remove_prompt_with_tenant(): + """remove_prompt respects tenant scope.""" + manager = PromptManager() + + manager.add_prompt(_make_prompt("greet", "A"), tenant_id="tenant-a") + manager.add_prompt(_make_prompt("greet", "B"), tenant_id="tenant-b") + + manager.remove_prompt("greet", tenant_id="tenant-a") + + assert manager.get_prompt("greet", tenant_id="tenant-a") is None + assert manager.get_prompt("greet", tenant_id="tenant-b") is not None + # Empty tenant scope is cleaned up + assert "tenant-a" not in manager._prompts + + +def test_remove_prompt_partial_tenant_scope(): + """Removing one prompt leaves the tenant scope intact when others remain.""" + manager = PromptManager() + + manager.add_prompt(_make_prompt("greet", "A"), tenant_id="tenant-a") + manager.add_prompt(_make_prompt("farewell", "B"), tenant_id="tenant-a") + + manager.remove_prompt("greet", tenant_id="tenant-a") + + assert len(manager.list_prompts(tenant_id="tenant-a")) == 1 + assert "tenant-a" in manager._prompts + + +def test_remove_prompt_wrong_tenant_raises(): + """Removing a prompt under the wrong tenant raises ValueError.""" + manager = PromptManager() + manager.add_prompt(_make_prompt("greet", "A"), tenant_id="tenant-a") + + with pytest.raises(ValueError, match="Unknown prompt"): + manager.remove_prompt("greet", tenant_id="tenant-b") diff --git a/tests/server/mcpserver/test_multi_tenancy_oauth_e2e.py b/tests/server/mcpserver/test_multi_tenancy_oauth_e2e.py new file mode 100644 index 000000000..08dac1369 --- /dev/null +++ b/tests/server/mcpserver/test_multi_tenancy_oauth_e2e.py @@ -0,0 +1,422 @@ +"""End-to-end test for multi-tenant isolation through the OAuth + HTTP stack. + +This test exercises the full production path that a real deployment would use: + + 1. Client sends HTTP request with ``Authorization: Bearer `` + 2. ``AuthContextMiddleware`` validates the token via ``TokenVerifier``, + extracts ``AccessToken.tenant_id``, and sets the ``tenant_id_var`` + contextvar for the duration of the request. + 3. ``StreamableHTTPSessionManager`` binds new sessions to the current + tenant and rejects cross-tenant session access. + 4. The low-level ``Server._handle_request`` reads ``tenant_id_var`` and + populates ``ServerRequestContext.tenant_id``. + 5. ``MCPServer`` handlers (e.g. ``_handle_list_tools``) pass + ``ctx.tenant_id`` to the appropriate manager, which returns only + the items registered under that tenant. + 6. The client sees only its own tenant's tools/resources/prompts. + +Unlike the in-memory E2E tests in ``test_multi_tenancy_e2e.py`` that set +``tenant_id_var`` manually, this test uses a real Starlette app with auth +middleware and HTTP transport to verify the full integration — proving that +tenant_id flows correctly from the OAuth token all the way through to the +handler response. + +Key complexity notes: + - We use ``StubTokenVerifier`` instead of a full OAuth provider because + the MCP auth stack allows plugging in a custom ``TokenVerifier``. This + lets us skip the OAuth authorization code flow while still exercising + the real ``AuthContextMiddleware`` → ``tenant_id_var`` path. + - ``httpx.ASGITransport`` does NOT send ASGI lifespan events, so + Starlette's lifespan (which starts ``StreamableHTTPSessionManager.run()``) + never fires. We work around this with ``_start_lifespan()``, which + manually sends the lifespan startup/shutdown events to the ASGI app. +""" + +from __future__ import annotations + +import time +from collections.abc import AsyncIterator, MutableMapping +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import httpx +import pytest +from pydantic import AnyHttpUrl +from starlette.applications import Starlette + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.server.auth.settings import AuthSettings +from mcp.server.mcpserver.context import Context +from mcp.server.mcpserver.server import MCPServer +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import TextContent + +pytestmark = pytest.mark.anyio + + +# --------------------------------------------------------------------------- +# Stub token verifier — maps bearer tokens to AccessTokens with tenant_id +# --------------------------------------------------------------------------- + + +class StubTokenVerifier(TokenVerifier): + """Token verifier that recognises hard-coded bearer tokens. + + In production, ``TokenVerifier.verify_token()`` would call an OAuth + introspection endpoint or decode a JWT. Here we simply look up the + token in a pre-built dict, returning the corresponding ``AccessToken`` + (which includes ``tenant_id``). This is the minimal surface needed to + exercise the real auth middleware without a full OAuth server. + """ + + def __init__(self, token_map: dict[str, AccessToken]) -> None: + self._tokens = token_map + + async def verify_token(self, token: str) -> AccessToken | None: + # Returns None for unknown tokens, which the auth middleware + # treats as an authentication failure (HTTP 401). + return self._tokens.get(token) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@asynccontextmanager +async def _start_lifespan(app: httpx._transports.asgi._ASGIApp) -> AsyncIterator[None]: + """Manually trigger ASGI lifespan startup/shutdown on a Starlette app. + + Why this is needed: + ``httpx.ASGITransport`` sends only HTTP request events — it does NOT + send ASGI lifespan events. However, the Starlette app returned by + ``MCPServer.streamable_http_app()`` has a lifespan handler that starts + ``StreamableHTTPSessionManager.run()``. Without lifespan startup, the + session manager's internal task group is never initialised, and any + HTTP request that tries to create a session will fail with: + RuntimeError: Task group is not initialized. Make sure to use run(). + + How it works: + We call the ASGI app directly with a ``lifespan`` scope and provide + custom ``receive``/``send`` callables that simulate the ASGI server's + lifespan protocol: + 1. Send ``lifespan.startup`` → app initialises (starts session manager) + 2. Wait for ``lifespan.startup.complete`` from the app + 3. Yield control to the test + 4. On cleanup, send ``lifespan.shutdown`` → app tears down + 5. Wait for ``lifespan.shutdown.complete``, then cancel the task group + """ + # Events to coordinate the lifespan protocol handshake + started = anyio.Event() + shutdown = anyio.Event() + startup_complete = anyio.Event() + shutdown_complete = anyio.Event() + + # ASGI lifespan scope — tells the app this is a lifespan connection + scope = {"type": "lifespan", "asgi": {"version": "3.0"}} + + async def receive() -> dict[str, str]: + """Feed lifespan events to the ASGI app. + + Called twice: once for startup (immediately), once for shutdown + (blocks until the test is done and ``shutdown`` is set). + """ + if not started.is_set(): + started.set() + return {"type": "lifespan.startup"} + # Block here until the test finishes and triggers shutdown + await shutdown.wait() + return {"type": "lifespan.shutdown"} + + async def send(message: MutableMapping[str, Any]) -> None: + """Receive acknowledgements from the ASGI app.""" + if message["type"] == "lifespan.startup.complete": + startup_complete.set() + elif message["type"] == "lifespan.shutdown.complete": + shutdown_complete.set() + + async with anyio.create_task_group() as tg: + # Run the ASGI app's lifespan handler in the background + tg.start_soon(app, scope, receive, send) + # Wait until the app signals that startup is complete + await startup_complete.wait() + try: + yield + finally: + # Signal the app to shut down and wait for confirmation + shutdown.set() + await shutdown_complete.wait() + tg.cancel_scope.cancel() + + +def _build_tenant_server(verifier: StubTokenVerifier) -> MCPServer: + """Create an MCPServer with auth enabled and tenant-scoped tools. + + The server is configured with: + - ``token_verifier``: Our stub that maps bearer tokens to AccessTokens + - ``auth``: AuthSettings that enable the auth middleware stack + (issuer_url and resource_server_url are fake since we bypass OAuth) + + Tools registered: + - "query" under tenant "alpha" — simulates an analytics tool + - "publish" under tenant "beta" — simulates a publishing tool + - "whoami" under both tenants — reads ctx.tenant_id to prove + the tenant context is correctly propagated to handlers + """ + server = MCPServer( + "tenant-oauth-test", + token_verifier=verifier, + auth=AuthSettings( + issuer_url=AnyHttpUrl("https://auth.example.com"), + resource_server_url=AnyHttpUrl("https://mcp.example.com"), + required_scopes=["read"], + ), + ) + + # Tenant "alpha" tools — only visible to requests with tenant_id="alpha" + def alpha_query(sql: str) -> str: + return f"alpha: {sql}" + + server.add_tool(alpha_query, name="query", tenant_id="alpha") + + # Tenant "beta" tools — only visible to requests with tenant_id="beta" + def beta_publish(title: str) -> str: + return f"beta: {title}" + + server.add_tool(beta_publish, name="publish", tenant_id="beta") + + # "whoami" is registered under BOTH tenants (same function, different + # tenant scopes). This lets us verify that ctx.tenant_id is correctly + # set for each tenant's request independently. + def whoami(ctx: Context) -> str: + return f"tenant={ctx.tenant_id}" + + server.add_tool(whoami, name="whoami", tenant_id="alpha") + server.add_tool(whoami, name="whoami", tenant_id="beta") + + return server + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def token_map() -> dict[str, AccessToken]: + """Map bearer token strings to AccessToken objects with tenant_id. + + This simulates what a real OAuth token introspection would return: + each bearer token resolves to an AccessToken containing the tenant_id + that identifies which tenant the caller belongs to. + """ + now = int(time.time()) + return { + # "token-alpha" authenticates as tenant "alpha" with scope "read" + "token-alpha": AccessToken( + token="token-alpha", + client_id="client-1", + scopes=["read"], + expires_at=now + 3600, + tenant_id="alpha", + ), + # "token-beta" authenticates as tenant "beta" with scope "read" + "token-beta": AccessToken( + token="token-beta", + client_id="client-2", + scopes=["read"], + expires_at=now + 3600, + tenant_id="beta", + ), + } + + +@pytest.fixture +def verifier(token_map: dict[str, AccessToken]) -> StubTokenVerifier: + return StubTokenVerifier(token_map) + + +@pytest.fixture +def tenant_app(verifier: StubTokenVerifier) -> MCPServer: + return _build_tenant_server(verifier) + + +@pytest.fixture +def starlette_app(tenant_app: MCPServer) -> Starlette: + """Build the Starlette ASGI app with DNS rebinding protection disabled. + + Starlette is the ASGI web framework that MCPServer uses under the hood + for HTTP transport. ``MCPServer.streamable_http_app()`` returns a + Starlette ``Application`` wired with: + - Auth middleware (``AuthenticationMiddleware`` + ``AuthContextMiddleware``) + that validates bearer tokens and sets ``tenant_id_var`` + - A ``StreamableHTTPASGIApp`` route that handles MCP JSON-RPC over HTTP + - A lifespan handler that starts/stops ``StreamableHTTPSessionManager`` + - Transport security middleware for DNS rebinding protection + + In tests we use ``httpx.ASGITransport`` to send requests directly to + this ASGI app in-process (no real network). However, ASGITransport + sends the Host header as just "localhost" without a port, while the + default DNS rebinding protection expects "localhost:". We disable + DNS rebinding protection here since it's not relevant to tenant isolation. + """ + return tenant_app.streamable_http_app( + transport_security=TransportSecuritySettings( + enable_dns_rebinding_protection=False, + ), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +async def test_alpha_sees_only_own_tools(starlette_app: Starlette): + """A client authenticating as tenant 'alpha' sees only alpha's tools. + + Verifies the full path: Bearer token-alpha → AuthContextMiddleware + extracts tenant_id="alpha" → ToolManager filters to alpha's tools + → client receives ["query", "whoami"] (not beta's "publish"). + """ + # Start ASGI lifespan to initialise the StreamableHTTPSessionManager + async with _start_lifespan(starlette_app): + # Create an HTTP client that sends requests through ASGITransport + # directly to the Starlette app (no real network involved). + # The Authorization header is included on every request. + http_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), + headers={"Authorization": "Bearer token-alpha"}, + ) + async with http_client: + # Use the MCP streamable HTTP client to establish a session + async with streamable_http_client( + url="http://localhost/mcp", + http_client=http_client, + ) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + # List tools — should only see alpha's tools + tools = await session.list_tools() + tool_names = sorted(t.name for t in tools.tools) + assert tool_names == ["query", "whoami"] + + +async def test_beta_sees_only_own_tools(starlette_app: Starlette): + """A client authenticating as tenant 'beta' sees only beta's tools. + + Same structure as the alpha test, but with token-beta. Verifies that + beta sees ["publish", "whoami"] and NOT alpha's "query" tool. + """ + async with _start_lifespan(starlette_app): + http_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), + headers={"Authorization": "Bearer token-beta"}, + ) + async with http_client: + async with streamable_http_client( + url="http://localhost/mcp", + http_client=http_client, + ) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + tools = await session.list_tools() + tool_names = sorted(t.name for t in tools.tools) + assert tool_names == ["publish", "whoami"] + + +async def test_alpha_can_call_own_tool(starlette_app: Starlette): + """Tenant alpha can call its own tool and get the correct result. + + Goes beyond list_tools — actually invokes the "query" tool to verify + that the tool execution path also respects tenant scoping. The tool + function returns "alpha: " to confirm the right tool ran. + """ + async with _start_lifespan(starlette_app): + http_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), + headers={"Authorization": "Bearer token-alpha"}, + ) + async with http_client: + async with streamable_http_client( + url="http://localhost/mcp", + http_client=http_client, + ) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await session.call_tool("query", {"sql": "SELECT 1"}) + texts = [c.text for c in result.content if isinstance(c, TextContent)] + assert any("alpha: SELECT 1" in t for t in texts) + + +async def test_whoami_returns_correct_tenant(starlette_app: Starlette): + """The whoami tool reports the authenticated tenant identity. + + This is the strongest proof that tenant_id propagates end-to-end: + the tool reads ``ctx.tenant_id`` (set by the low-level server from + ``tenant_id_var``, which was set by ``AuthContextMiddleware`` from + ``AccessToken.tenant_id``). Each tenant gets a different value. + + We test both tenants in a single test to verify isolation within + the same Starlette app instance (shared session manager). + """ + async with _start_lifespan(starlette_app): + # Test both tenants against the same running app + for token, expected_tenant in [("token-alpha", "alpha"), ("token-beta", "beta")]: + http_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), + headers={"Authorization": f"Bearer {token}"}, + ) + async with http_client: + async with streamable_http_client( + url="http://localhost/mcp", + http_client=http_client, + ) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await session.call_tool("whoami", {}) + texts = [c.text for c in result.content if isinstance(c, TextContent)] + assert any(f"tenant={expected_tenant}" in t for t in texts) + + +async def test_unauthenticated_request_is_rejected(starlette_app: Starlette): + """A request without a bearer token is rejected by auth middleware. + + Verifies that the auth middleware (enabled by ``AuthSettings`` and + ``TokenVerifier``) returns HTTP 401 when no Authorization header is + present. This is a basic security check — without valid credentials, + no MCP session can be established. + + Unlike the other tests, this one sends a raw HTTP POST instead of + using the MCP client, since the client would fail to initialise + (which is the expected behaviour). + """ + async with _start_lifespan(starlette_app): + # No Authorization header — should be rejected + http_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), + ) + async with http_client: + # Send a raw JSON-RPC initialize request without auth + response = await http_client.post( + "http://localhost/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "0.1"}, + }, + }, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + # Auth middleware should reject with 401 Unauthorized + assert response.status_code == 401 diff --git a/tests/server/mcpserver/test_multi_tenancy_server.py b/tests/server/mcpserver/test_multi_tenancy_server.py new file mode 100644 index 000000000..d2798ce03 --- /dev/null +++ b/tests/server/mcpserver/test_multi_tenancy_server.py @@ -0,0 +1,291 @@ +"""Tests for tenant-scoped MCPServer server integration. + +Validates that tenant_id flows from MCPServer public methods down to the +underlying managers, and that Context exposes tenant_id correctly. +""" + +import pytest + +from mcp.server.experimental.request_context import Experimental +from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver.context import Context +from mcp.server.mcpserver.prompts.base import Prompt +from mcp.server.mcpserver.resources.types import FunctionResource + +pytestmark = pytest.mark.anyio + + +# --- Context.tenant_id property --- + + +def test_context_tenant_id_without_request_context(): + """Context.tenant_id returns None when no request context is set.""" + ctx = Context() + assert ctx.tenant_id is None + + +def test_context_tenant_id_with_request_context(): + """Context.tenant_id returns the tenant_id from the request context.""" + from mcp.server.context import ServerRequestContext + + # Create a minimal ServerRequestContext with tenant_id + # We need real streams for ServerSession but won't use them + + rc = ServerRequestContext( + session=None, # type: ignore[arg-type] + lifespan_context=None, + experimental=Experimental(), + tenant_id="tenant-x", + ) + ctx = Context(request_context=rc) + assert ctx.tenant_id == "tenant-x" + + +def test_context_tenant_id_none_in_request_context(): + """Context.tenant_id returns None when request context has no tenant_id.""" + from mcp.server.context import ServerRequestContext + + rc = ServerRequestContext( + session=None, # type: ignore[arg-type] + lifespan_context=None, + experimental=Experimental(), + ) + ctx = Context(request_context=rc) + assert ctx.tenant_id is None + + +# --- MCPServer public methods with tenant_id --- + + +async def test_list_tools_with_tenant_id(): + """list_tools filters by tenant_id.""" + server = MCPServer("test") + + def tool_a() -> str: # pragma: no cover + return "a" + + def tool_b() -> str: # pragma: no cover + return "b" + + server.add_tool(tool_a, name="shared", tenant_id="tenant-a") + server.add_tool(tool_b, name="shared", tenant_id="tenant-b") + + tools_a = await server.list_tools(tenant_id="tenant-a") + tools_b = await server.list_tools(tenant_id="tenant-b") + tools_global = await server.list_tools() + + assert len(tools_a) == 1 + assert tools_a[0].name == "shared" + assert len(tools_b) == 1 + assert tools_b[0].name == "shared" + assert len(tools_global) == 0 + + +async def test_call_tool_with_tenant_id(): + """call_tool respects tenant scope.""" + server = MCPServer("test") + + def tool_a() -> str: + return "result-a" + + def tool_b() -> str: + return "result-b" + + server.add_tool(tool_a, name="do_work", tenant_id="tenant-a") + server.add_tool(tool_b, name="do_work", tenant_id="tenant-b") + + result_a = await server.call_tool("do_work", {}, tenant_id="tenant-a") + result_b = await server.call_tool("do_work", {}, tenant_id="tenant-b") + + # Structured output returns (unstructured_content, structured_content) + assert isinstance(result_a, tuple) + assert isinstance(result_b, tuple) + assert result_a[0][0].text == "result-a" # type: ignore[union-attr] + assert result_a[1] == {"result": "result-a"} + assert result_b[0][0].text == "result-b" # type: ignore[union-attr] + assert result_b[1] == {"result": "result-b"} + + +async def test_call_tool_wrong_tenant_raises(): + """Calling a tool under the wrong tenant raises an error.""" + from mcp.server.mcpserver.exceptions import ToolError + + server = MCPServer("test") + + def my_tool() -> str: # pragma: no cover + return "x" + + server.add_tool(my_tool, tenant_id="tenant-a") + + with pytest.raises(ToolError): + await server.call_tool("my_tool", {}, tenant_id="tenant-b") + + +async def test_list_resources_with_tenant_id(): + """list_resources filters by tenant_id.""" + server = MCPServer("test") + + resource_a = FunctionResource(uri="file:///data", name="data-a", fn=lambda: "a") + resource_b = FunctionResource(uri="file:///data", name="data-b", fn=lambda: "b") + + server.add_resource(resource_a, tenant_id="tenant-a") + server.add_resource(resource_b, tenant_id="tenant-b") + + resources_a = await server.list_resources(tenant_id="tenant-a") + resources_b = await server.list_resources(tenant_id="tenant-b") + resources_global = await server.list_resources() + + assert len(resources_a) == 1 + assert resources_a[0].name == "data-a" + assert len(resources_b) == 1 + assert resources_b[0].name == "data-b" + assert len(resources_global) == 0 + + +async def test_list_resource_templates_with_tenant_id(): + """list_resource_templates filters by tenant_id.""" + server = MCPServer("test") + + def greet_a(name: str) -> str: # pragma: no cover + return f"Hello A, {name}!" + + def greet_b(name: str) -> str: # pragma: no cover + return f"Hello B, {name}!" + + server._resource_manager.add_template( + fn=greet_a, + uri_template="greet://{name}", + tenant_id="tenant-a", + ) + server._resource_manager.add_template( + fn=greet_b, + uri_template="greet://{name}", + tenant_id="tenant-b", + ) + + templates_a = await server.list_resource_templates(tenant_id="tenant-a") + templates_b = await server.list_resource_templates(tenant_id="tenant-b") + templates_global = await server.list_resource_templates() + + assert len(templates_a) == 1 + assert len(templates_b) == 1 + assert len(templates_global) == 0 + + +async def test_read_resource_with_tenant_id(): + """read_resource respects tenant scope.""" + server = MCPServer("test") + + resource = FunctionResource(uri="file:///secret", name="secret", fn=lambda: "secret-data") + server.add_resource(resource, tenant_id="tenant-a") + + # Tenant A can read + results = await server.read_resource("file:///secret", tenant_id="tenant-a") + contents = list(results) + assert len(contents) == 1 + assert contents[0].content == "secret-data" + + # Tenant B cannot + from mcp.server.mcpserver.exceptions import ResourceError + + with pytest.raises(ResourceError, match="Unknown resource"): + await server.read_resource("file:///secret", tenant_id="tenant-b") + + +async def test_list_prompts_with_tenant_id(): + """list_prompts filters by tenant_id.""" + server = MCPServer("test") + + async def prompt_a() -> str: # pragma: no cover + return "Hello from A" + + async def prompt_b() -> str: # pragma: no cover + return "Hello from B" + + server.add_prompt(Prompt.from_function(prompt_a, name="greet"), tenant_id="tenant-a") + server.add_prompt(Prompt.from_function(prompt_b, name="greet"), tenant_id="tenant-b") + + prompts_a = await server.list_prompts(tenant_id="tenant-a") + prompts_b = await server.list_prompts(tenant_id="tenant-b") + prompts_global = await server.list_prompts() + + assert len(prompts_a) == 1 + assert len(prompts_b) == 1 + assert len(prompts_global) == 0 + + +async def test_get_prompt_with_tenant_id(): + """get_prompt respects tenant scope.""" + server = MCPServer("test") + + async def greet_a() -> str: + return "Hello from tenant-a" + + server.add_prompt(Prompt.from_function(greet_a, name="greet"), tenant_id="tenant-a") + + # Tenant A can get the prompt + result = await server.get_prompt("greet", tenant_id="tenant-a") + assert result.messages is not None + assert len(result.messages) > 0 + + # Tenant B cannot + with pytest.raises(ValueError, match="Unknown prompt"): + await server.get_prompt("greet", tenant_id="tenant-b") + + +async def test_remove_tool_with_tenant_id(): + """remove_tool respects tenant scope.""" + server = MCPServer("test") + + def my_tool() -> str: # pragma: no cover + return "x" + + server.add_tool(my_tool, name="my_tool", tenant_id="tenant-a") + server.add_tool(my_tool, name="my_tool", tenant_id="tenant-b") + + server.remove_tool("my_tool", tenant_id="tenant-a") + + tools_a = await server.list_tools(tenant_id="tenant-a") + tools_b = await server.list_tools(tenant_id="tenant-b") + + assert len(tools_a) == 0 + assert len(tools_b) == 1 + + +# --- Backward compatibility --- + + +async def test_backward_compat_no_tenant_id(): + """All public methods work without tenant_id (backward compatible).""" + server = MCPServer("test") + + @server.tool() + def greet(name: str) -> str: + return f"Hello, {name}!" + + @server.resource("test://data") + def test_resource() -> str: + return "data" + + @server.prompt() + def test_prompt() -> str: + return "prompt text" + + # All operations work without tenant_id + tools = await server.list_tools() + assert len(tools) == 1 + + result = await server.call_tool("greet", {"name": "World"}) + assert len(list(result)) > 0 + + resources = await server.list_resources() + assert len(resources) == 1 + + read_result = await server.read_resource("test://data") + assert len(list(read_result)) == 1 + + prompts = await server.list_prompts() + assert len(prompts) == 1 + + prompt_result = await server.get_prompt("test_prompt") + assert prompt_result.messages is not None diff --git a/tests/server/test_multi_tenancy_session.py b/tests/server/test_multi_tenancy_session.py new file mode 100644 index 000000000..2f94fa06d --- /dev/null +++ b/tests/server/test_multi_tenancy_session.py @@ -0,0 +1,424 @@ +"""Tests for multi-tenancy support in session and request context.""" + +import time + +import anyio +import pytest +from anyio.lowlevel import checkpoint + +from mcp import Client +from mcp.server import Server +from mcp.server.auth.middleware.auth_context import auth_context_var, get_tenant_id +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken +from mcp.server.context import ServerRequestContext +from mcp.server.experimental.request_context import Experimental +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared._context import RequestContext, tenant_id_var +from mcp.shared.message import SessionMessage +from mcp.shared.session import BaseSession +from mcp.types import ListToolsResult, NotificationParams, PaginatedRequestParams, ServerCapabilities + + +def _simulate_tenant_binding(session: ServerSession, tenant_id_value: str) -> None: + """Simulate the set-once tenant binding logic from lowlevel/server.py. + + Sets both auth_context_var (as AuthContextMiddleware does) and tenant_id_var + (the transport-agnostic contextvar that the server reads). + """ + access_token = AccessToken( + token=f"token-{tenant_id_value}", + client_id="client", + scopes=["read"], + expires_at=int(time.time()) + 3600, + tenant_id=tenant_id_value, + ) + user = AuthenticatedUser(access_token) + auth_token = auth_context_var.set(user) + tenant_token = tenant_id_var.set(tenant_id_value) + try: + tenant_id = tenant_id_var.get() + if tenant_id is not None and session.tenant_id is None: + session.tenant_id = tenant_id + finally: + tenant_id_var.reset(tenant_token) + auth_context_var.reset(auth_token) + + +@pytest.fixture +def init_options() -> InitializationOptions: + """Create initialization options for testing.""" + return InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ) + + +def test_request_context_with_tenant_id(): + """Test RequestContext can hold tenant_id.""" + # Use type: ignore since we're testing the dataclass field, not session behavior + ctx: RequestContext[BaseSession] = RequestContext( # type: ignore[type-arg] + session=None, # type: ignore[arg-type] + request_id="test-1", + tenant_id="tenant-xyz", + ) + assert ctx.tenant_id == "tenant-xyz" + + +def test_request_context_without_tenant_id(): + """Test RequestContext defaults tenant_id to None.""" + ctx: RequestContext[BaseSession] = RequestContext( # type: ignore[type-arg] + session=None, # type: ignore[arg-type] + request_id="test-1", + ) + assert ctx.tenant_id is None + + +def test_server_request_context_with_tenant_id(): + """Test ServerRequestContext can hold tenant_id.""" + ctx = ServerRequestContext( + session=None, # type: ignore[arg-type] + lifespan_context={}, + experimental=Experimental( + task_metadata=None, + _client_capabilities=None, + _session=None, # type: ignore[arg-type] + _task_support=None, + ), + tenant_id="tenant-abc", + ) + assert ctx.tenant_id == "tenant-abc" + + +def test_server_request_context_inherits_tenant_id_from_base(): + """Test ServerRequestContext inherits tenant_id behavior from RequestContext.""" + # Without tenant_id + ctx_no_tenant = ServerRequestContext( + session=None, # type: ignore[arg-type] + lifespan_context={}, + experimental=Experimental( + task_metadata=None, + _client_capabilities=None, + _session=None, # type: ignore[arg-type] + _task_support=None, + ), + ) + assert ctx_no_tenant.tenant_id is None + + # With tenant_id + ctx_with_tenant = ServerRequestContext( + session=None, # type: ignore[arg-type] + lifespan_context={}, + experimental=Experimental( + task_metadata=None, + _client_capabilities=None, + _session=None, # type: ignore[arg-type] + _task_support=None, + ), + tenant_id="my-tenant", + ) + assert ctx_with_tenant.tenant_id == "my-tenant" + + +@pytest.mark.anyio +async def test_server_session_tenant_id_property(init_options: InitializationOptions): + """Test ServerSession tenant_id property with set-once semantics.""" + server_to_client_send, server_to_client_recv = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + async with server_to_client_send, server_to_client_recv, client_to_server_send, client_to_server_recv: + async with ServerSession( + client_to_server_recv, + server_to_client_send, + init_options, + ) as session: + # Default tenant_id is None + assert session.tenant_id is None + + # Can set tenant_id + session.tenant_id = "tenant-123" + assert session.tenant_id == "tenant-123" + + # Setting to the same value is allowed + session.tenant_id = "tenant-123" + assert session.tenant_id == "tenant-123" + + # Cannot change to a different value + with pytest.raises(ValueError, match="Cannot change tenant_id"): + session.tenant_id = "tenant-456" + + # Cannot reset to None once set + with pytest.raises(ValueError, match="Cannot change tenant_id"): + session.tenant_id = None + + # Original value is preserved + assert session.tenant_id == "tenant-123" + + +def test_get_tenant_id_from_auth_context(): + """Test get_tenant_id extracts tenant_id from auth context.""" + # No auth context + assert get_tenant_id() is None + + # With auth context but no tenant + access_token_no_tenant = AccessToken( + token="token1", + client_id="client1", + scopes=["read"], + expires_at=int(time.time()) + 3600, + ) + user_no_tenant = AuthenticatedUser(access_token_no_tenant) + token = auth_context_var.set(user_no_tenant) + try: + assert get_tenant_id() is None + finally: + auth_context_var.reset(token) + + # With auth context and tenant + access_token_with_tenant = AccessToken( + token="token2", + client_id="client2", + scopes=["read"], + expires_at=int(time.time()) + 3600, + tenant_id="tenant-xyz", + ) + user_with_tenant = AuthenticatedUser(access_token_with_tenant) + token = auth_context_var.set(user_with_tenant) + try: + assert get_tenant_id() == "tenant-xyz" + finally: + auth_context_var.reset(token) + + +@pytest.mark.anyio +async def test_session_tenant_id_set_from_auth_context_on_first_request(init_options: InitializationOptions): + """Verify session.tenant_id is populated from auth context on the first request. + + The lowlevel server sets session.tenant_id from get_tenant_id() on the + first request that has a tenant. This test simulates that behavior directly. + """ + server_to_client_send, server_to_client_recv = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + async with server_to_client_send, server_to_client_recv, client_to_server_send, client_to_server_recv: + async with ServerSession( + client_to_server_recv, + server_to_client_send, + init_options, + ) as session: + assert session.tenant_id is None + + # Simulate what lowlevel/server.py does: set session.tenant_id + # from auth context on first request + _simulate_tenant_binding(session, "tenant-first") + assert session.tenant_id == "tenant-first" + + # Simulate a second request with a different tenant — + # session.tenant_id should NOT change (set-once on first request) + _simulate_tenant_binding(session, "tenant-second") + + # Still the first tenant — not overwritten + assert session.tenant_id == "tenant-first" + + +@pytest.mark.anyio +async def test_tenant_context_isolation_between_concurrent_requests(): + """Verify tenant_id doesn't leak between concurrent async contexts. + + This test validates a critical security property: when multiple requests + from different tenants are processed concurrently, each request must only + see its own tenant_id, never another tenant's. + + How it works: + 1. We simulate two concurrent requests, each with a different tenant_id + ("tenant-A" and "tenant-B"). + + 2. Each simulated request: + - Creates an AccessToken with its tenant_id + - Sets it in the auth_context_var (the contextvar used for auth state) + - Yields control via checkpoint() to allow the other task to run + - Reads back the tenant_id via get_tenant_id() + - Stores the result for verification + + 3. The anyio.lowlevel.checkpoint() forces a context switch, creating + an opportunity for tenant context to "leak" if the isolation is + broken. Without proper contextvar isolation, task2 might see + task1's tenant_id (or vice versa) after the context switch. + + 4. We use anyio.create_task_group() to run both tasks truly concurrently, + not sequentially. This is essential for testing isolation. + + 5. Finally, we verify each request saw only its own tenant_id. + + If this test fails, it indicates a serious security issue where tenant + data could leak between concurrent requests. + """ + # Store results from each simulated request + results: dict[str, str | None] = {} + + async def simulate_request(tenant_id: str, request_key: str) -> None: + """Simulate a request with a specific tenant context. + + Args: + tenant_id: The tenant_id to set in the auth context + request_key: A key to identify this request's result + """ + # Create an access token with the tenant_id, simulating what + # the auth middleware does when a request comes in + access_token = AccessToken( + token=f"token-{request_key}", + client_id="test-client", + scopes=["read"], + expires_at=int(time.time()) + 3600, + tenant_id=tenant_id, + ) + user = AuthenticatedUser(access_token) + + # Set both contextvars - this is what AuthContextMiddleware does + auth_token = auth_context_var.set(user) + tenant_token = tenant_id_var.set(tenant_id) + try: + # Yield control to allow other tasks to run. This is the critical + # point where context leakage could occur if isolation is broken. + await checkpoint() + + # Read back the tenant_id - should still be our tenant, not the other + results[request_key] = tenant_id_var.get() + finally: + # Always reset the context (mirrors middleware behavior) + tenant_id_var.reset(tenant_token) + auth_context_var.reset(auth_token) + + # Run both requests concurrently using a task group + async with anyio.create_task_group() as tg: + tg.start_soon(simulate_request, "tenant-A", "request1") + tg.start_soon(simulate_request, "tenant-B", "request2") + + # Verify isolation: each request should see only its own tenant_id + assert results["request1"] == "tenant-A", "Request 1 saw wrong tenant_id" + assert results["request2"] == "tenant-B", "Request 2 saw wrong tenant_id" + + +@pytest.mark.anyio +async def test_server_session_isolation_between_instances(init_options: InitializationOptions): + """Verify tenant_id is isolated between separate ServerSession instances. + + This test ensures that setting tenant_id on one ServerSession does not + affect another ServerSession instance. Each session should maintain its + own independent tenant context. + + This is important for scenarios where a server handles multiple sessions + concurrently - each session belongs to a specific tenant and must not + see or affect other tenants' sessions. + """ + # Create streams for two independent sessions + send1, recv1 = anyio.create_memory_object_stream[SessionMessage](1) + send2, recv2 = anyio.create_memory_object_stream[SessionMessage | Exception](1) + send3, recv3 = anyio.create_memory_object_stream[SessionMessage](1) + send4, recv4 = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + async with send1, recv1, send2, recv2, send3, recv3, send4, recv4: + # Create two separate server sessions + async with ( + ServerSession(recv2, send1, init_options) as session1, + ServerSession(recv4, send3, init_options) as session2, + ): + # Set different tenant_ids on each session + session1.tenant_id = "tenant-alpha" + session2.tenant_id = "tenant-beta" + + # Verify each session maintains its own tenant_id + assert session1.tenant_id == "tenant-alpha" + assert session2.tenant_id == "tenant-beta" + + # Attempting to change one session's tenant_id raises + with pytest.raises(ValueError, match="Cannot change tenant_id"): + session1.tenant_id = "tenant-gamma" + + # Both sessions retain their original values + assert session1.tenant_id == "tenant-alpha" + assert session2.tenant_id == "tenant-beta" + + +@pytest.mark.anyio +async def test_handle_request_populates_session_tenant_id(): + """E2E: session.tenant_id is set from auth context during request handling. + + This exercises the set-once tenant binding in lowlevel/server.py + _handle_request, covering the branch where get_tenant_id() returns + a non-None value. + """ + captured_ctx_tenant: str | None = None + captured_session_tenant: str | None = None + + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + nonlocal captured_ctx_tenant, captured_session_tenant + captured_ctx_tenant = ctx.tenant_id + captured_session_tenant = ctx.session.tenant_id + return ListToolsResult(tools=[]) + + server = Server("test", on_list_tools=handle_list_tools) + + # Set auth context with tenant before entering the Client — + # contextvars are inherited by child tasks, so the server will see it + access_token = AccessToken( + token="test-token", + client_id="test-client", + scopes=["read"], + expires_at=int(time.time()) + 3600, + tenant_id="tenant-e2e", + ) + user = AuthenticatedUser(access_token) + auth_token = auth_context_var.set(user) + tenant_token = tenant_id_var.set("tenant-e2e") + try: + async with Client(server) as client: + await client.list_tools() + finally: + tenant_id_var.reset(tenant_token) + auth_context_var.reset(auth_token) + + assert captured_ctx_tenant == "tenant-e2e" + assert captured_session_tenant == "tenant-e2e" + + +@pytest.mark.anyio +async def test_handle_notification_populates_session_tenant_id(): + """E2E: session.tenant_id is set from auth context during notification handling. + + This exercises the set-once tenant binding in lowlevel/server.py + _handle_notification, covering the branch where get_tenant_id() returns + a non-None value. + """ + notification_tenant: str | None = None + notification_received = anyio.Event() + + async def handle_roots_list_changed(ctx: ServerRequestContext, params: NotificationParams | None) -> None: + nonlocal notification_tenant + notification_tenant = ctx.tenant_id + notification_received.set() + + server = Server("test", on_roots_list_changed=handle_roots_list_changed) + + access_token = AccessToken( + token="test-token", + client_id="test-client", + scopes=["read"], + expires_at=int(time.time()) + 3600, + tenant_id="tenant-notify", + ) + user = AuthenticatedUser(access_token) + auth_token = auth_context_var.set(user) + tenant_token = tenant_id_var.set("tenant-notify") + try: + async with Client(server) as client: + await client.session.send_roots_list_changed() + with anyio.fail_after(5): + await notification_received.wait() + finally: + tenant_id_var.reset(tenant_token) + auth_context_var.reset(auth_token) + + assert notification_tenant == "tenant-notify" diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 47cfbf14a..35caa1229 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -413,3 +413,216 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +# --- Multi-tenancy: session-level tenant isolation --- + + +def _extract_session_id(messages: list[Message]) -> str | None: + """Extract the MCP session ID from ASGI response messages.""" + for msg in messages: + if msg["type"] == "http.response.start": + for header_name, header_value in msg.get("headers", []): + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + return header_value.decode() + return None + + +def _extract_status(messages: list[Message]) -> int | None: + """Extract the HTTP status code from ASGI response messages.""" + for msg in messages: + if msg["type"] == "http.response.start": + return msg["status"] + return None + + +def test_extract_session_id_skips_non_start_messages(): + """_extract_session_id skips non-start messages and returns None when no ID found.""" + body_msg: Message = {"type": "http.response.body", "body": b"data"} + start_no_header: Message = {"type": "http.response.start", "status": 200, "headers": []} + + # Only body messages → None + assert _extract_session_id([body_msg]) is None + # Start message without session header → None + assert _extract_session_id([body_msg, start_no_header]) is None + + +def test_extract_status_skips_non_start_messages(): + """_extract_status skips non-start messages and returns None when empty.""" + body_msg: Message = {"type": "http.response.body", "body": b"data"} + start_msg: Message = {"type": "http.response.start", "status": 200, "headers": []} + + # Only body messages → None + assert _extract_status([body_msg]) is None + # Body then start → returns status from start + assert _extract_status([body_msg, start_msg]) == 200 + # Empty list → None + assert _extract_status([]) is None + + +def _make_scope(session_id: str | None = None) -> dict[str, Any]: + """Build a minimal ASGI scope for testing, optionally with a session ID.""" + headers: list[tuple[bytes, bytes]] = [(b"content-type", b"application/json")] + if session_id is not None: + headers.append((b"mcp-session-id", session_id.encode())) + return {"type": "http", "method": "POST", "path": "/mcp", "headers": headers} + + +async def _mock_send(messages: list[Message], message: Message) -> None: + """Async send that collects messages.""" + messages.append(message) + + +async def _mock_receive() -> dict[str, Any]: # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + +def _set_tenant(tenant: str | None) -> Any: + """Set tenant_id_var and return the reset token.""" + from mcp.shared._context import tenant_id_var + + return tenant_id_var.set(tenant) + + +def _reset_tenant(token: Any) -> None: + """Reset tenant_id_var to its previous value.""" + from mcp.shared._context import tenant_id_var + + tenant_id_var.reset(token) + + +async def _create_session_blocking( + manager: StreamableHTTPSessionManager, + app: Server[Any], + stop_event: anyio.Event, + tenant: str | None = None, +) -> str: + """Create a session whose server stays alive until stop_event is set.""" + + async def blocking_run(*args: Any, **kwargs: Any) -> None: + await stop_event.wait() + + app.run = AsyncMock(side_effect=blocking_run) + + messages: list[Message] = [] + token = _set_tenant(tenant) + try: + await manager.handle_request(_make_scope(), _mock_receive, lambda msg, _msgs=messages: _mock_send(_msgs, msg)) + finally: + _reset_tenant(token) + + session_id = _extract_session_id(messages) + assert session_id is not None + return session_id + + +async def _access_session( + manager: StreamableHTTPSessionManager, + session_id: str, + tenant: str | None = None, +) -> int | None: + """Access an existing session and return the HTTP status code.""" + messages: list[Message] = [] + token = _set_tenant(tenant) + try: + await manager.handle_request( + _make_scope(session_id), _mock_receive, lambda msg, _msgs=messages: _mock_send(_msgs, msg) + ) + finally: + _reset_tenant(token) + + return _extract_status(messages) + + +@pytest.mark.anyio +async def test_tenant_mismatch_returns_404(running_manager: tuple[StreamableHTTPSessionManager, Server]): + """A request from tenant-b cannot access a session created by tenant-a.""" + manager, app = running_manager + stop = anyio.Event() + try: + session_id = await _create_session_blocking(manager, app, stop, tenant="tenant-a") + assert await _access_session(manager, session_id, tenant="tenant-b") == 404 + finally: + stop.set() + + +@pytest.mark.anyio +async def test_two_tenants_cannot_access_each_others_sessions( + running_manager: tuple[StreamableHTTPSessionManager, Server], +): + """Two tenants each create a session; neither can access the other's.""" + manager, app = running_manager + stop = anyio.Event() + try: + session_a = await _create_session_blocking(manager, app, stop, tenant="tenant-a") + session_b = await _create_session_blocking(manager, app, stop, tenant="tenant-b") + assert session_a != session_b + + # Tenant-a tries to access tenant-b's session → 404 + assert await _access_session(manager, session_b, tenant="tenant-a") == 404 + # Tenant-b tries to access tenant-a's session → 404 + assert await _access_session(manager, session_a, tenant="tenant-b") == 404 + finally: + stop.set() + + +@pytest.mark.anyio +async def test_same_tenant_can_reuse_session(running_manager: tuple[StreamableHTTPSessionManager, Server]): + """A request from the same tenant can access its own session.""" + manager, app = running_manager + stop = anyio.Event() + try: + session_id = await _create_session_blocking(manager, app, stop, tenant="tenant-a") + status = await _access_session(manager, session_id, tenant="tenant-a") + assert status != 404, "Same tenant should be able to reuse its own session" + finally: + stop.set() + + +@pytest.mark.anyio +async def test_no_tenant_session_allows_any_access(running_manager: tuple[StreamableHTTPSessionManager, Server]): + """Sessions created without a tenant (no auth) allow access from any request.""" + manager, app = running_manager + stop = anyio.Event() + try: + session_id = await _create_session_blocking(manager, app, stop, tenant=None) + status = await _access_session(manager, session_id, tenant="tenant-a") + assert status != 404, "Session without tenant binding should allow access from any tenant" + finally: + stop.set() + + +@pytest.mark.anyio +async def test_unauthenticated_request_cannot_access_tenant_session( + running_manager: tuple[StreamableHTTPSessionManager, Server], +): + """A request with no tenant cannot access a session bound to a tenant.""" + manager, app = running_manager + stop = anyio.Event() + try: + session_id = await _create_session_blocking(manager, app, stop, tenant="tenant-a") + assert await _access_session(manager, session_id, tenant=None) == 404 + finally: + stop.set() + + +@pytest.mark.anyio +async def test_session_tenant_cleanup_on_exit(running_manager: tuple[StreamableHTTPSessionManager, Server]): + """Tenant mapping is cleaned up when a session exits.""" + manager, app = running_manager + app.run = AsyncMock(return_value=None) + + messages: list[Message] = [] + token = _set_tenant("tenant-a") + try: + await manager.handle_request(_make_scope(), _mock_receive, lambda msg, _msgs=messages: _mock_send(_msgs, msg)) + finally: + _reset_tenant(token) + + session_id = _extract_session_id(messages) + assert session_id is not None + + # Wait for the mock server to complete and cleanup to run + await anyio.sleep(0.01) + + assert session_id not in manager._session_tenants, "Tenant mapping should be cleaned up after session exits" diff --git a/uv.lock b/uv.lock index 4af3532ea..b05c956a7 100644 --- a/uv.lock +++ b/uv.lock @@ -13,6 +13,7 @@ members = [ "mcp-simple-auth", "mcp-simple-auth-client", "mcp-simple-chatbot", + "mcp-simple-multi-tenant", "mcp-simple-pagination", "mcp-simple-prompt", "mcp-simple-resource", @@ -1035,6 +1036,35 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-multi-tenant" +version = "0.1.0" +source = { editable = "examples/servers/simple-multi-tenant" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.2.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-pagination" version = "0.1.0"