Skip to content

Commit 85ce44b

Browse files
fix: retrieve client_id first then compare auth method
Retrieve client_id from auth header or the body, then retrieve client via that client_id. After that, compare auth method.
1 parent ea1f8c7 commit 85ce44b

1 file changed

Lines changed: 15 additions & 43 deletions

File tree

src/mcp/server/auth/middleware/client_auth.py

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -60,61 +60,33 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation
6060
Raises:
6161
AuthenticationError: If authentication fails
6262
"""
63-
form_data = await request.form()
64-
client_id = form_data.get("client_id")
65-
if not client_id:
66-
raise AuthenticationError("Missing client_id")
67-
68-
client = await self.provider.get_client(str(client_id))
63+
client_credentials = await self._get_credentials(request)
64+
client = await self.provider.get_client(str(client_credentials.client_id))
6965
if not client:
7066
raise AuthenticationError("Invalid client_id") # pragma: no cover
7167

72-
request_client_secret: str | None = None
73-
auth_header = request.headers.get("Authorization", "")
74-
75-
if client.token_endpoint_auth_method == "client_secret_basic":
76-
if not auth_header.startswith("Basic "):
77-
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")
78-
79-
try:
80-
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
81-
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
82-
if ":" not in decoded:
83-
raise ValueError("Invalid Basic auth format")
84-
basic_client_id, request_client_secret = decoded.split(":", 1)
85-
86-
# URL-decode both parts per RFC 6749 Section 2.3.1
87-
basic_client_id = unquote(basic_client_id)
88-
request_client_secret = unquote(request_client_secret)
89-
90-
if basic_client_id != client_id:
91-
raise AuthenticationError("Client ID mismatch in Basic auth")
92-
except (ValueError, UnicodeDecodeError, binascii.Error):
93-
raise AuthenticationError("Invalid Basic authentication header")
94-
95-
elif client.token_endpoint_auth_method == "client_secret_post":
96-
raw_form_data = form_data.get("client_secret")
97-
# form_data.get() can return a UploadFile or None, so we need to check if it's a string
98-
if isinstance(raw_form_data, str):
99-
request_client_secret = str(raw_form_data)
100-
101-
elif client.token_endpoint_auth_method == "none":
102-
request_client_secret = None
103-
else:
104-
raise AuthenticationError( # pragma: no cover
105-
f"Unsupported auth method: {client.token_endpoint_auth_method}"
106-
)
68+
match client.token_endpoint_auth_method:
69+
case "client_secret_basic":
70+
if client_credentials.auth_method != "client_secret_basic":
71+
raise AuthenticationError(f"Expected client_secret_basic authentication method, but got {client_credentials.auth_method}")
72+
case "client_secret_post":
73+
if client_credentials.auth_method != "client_secret_post":
74+
raise AuthenticationError(f"Expected client_secret_post authentication method, but got {client_credentials.auth_method}")
75+
case "none":
76+
pass
77+
case _:
78+
raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}") # pragma: no cover
10779

10880
# If client from the store expects a secret, validate that the request provides
10981
# that secret
11082
if client.client_secret: # pragma: no branch
111-
if not request_client_secret:
83+
if not client_credentials.client_secret:
11284
raise AuthenticationError("Client secret is required") # pragma: no cover
11385

11486
# hmac.compare_digest requires that both arguments are either bytes or a `str` containing
11587
# only ASCII characters. Since we do not control `request_client_secret`, we encode both
11688
# arguments to bytes.
117-
if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()):
89+
if not hmac.compare_digest(client.client_secret.encode(), client_credentials.client_secret.encode()):
11890
raise AuthenticationError("Invalid client_secret") # pragma: no cover
11991

12092
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):

0 commit comments

Comments
 (0)