@@ -60,61 +60,33 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation
6060 Raises:
6161 AuthenticationError: If authentication fails
6262 """
63- form_data = await request .form ()
64- client_id = form_data .get ("client_id" )
65- if not client_id :
66- raise AuthenticationError ("Missing client_id" )
67-
68- client = await self .provider .get_client (str (client_id ))
63+ client_credentials = await self ._get_credentials (request )
64+ client = await self .provider .get_client (str (client_credentials .client_id ))
6965 if not client :
7066 raise AuthenticationError ("Invalid client_id" ) # pragma: no cover
7167
72- request_client_secret : str | None = None
73- auth_header = request .headers .get ("Authorization" , "" )
74-
75- if client .token_endpoint_auth_method == "client_secret_basic" :
76- if not auth_header .startswith ("Basic " ):
77- raise AuthenticationError ("Missing or invalid Basic authentication in Authorization header" )
78-
79- try :
80- encoded_credentials = auth_header [6 :] # Remove "Basic " prefix
81- decoded = base64 .b64decode (encoded_credentials ).decode ("utf-8" )
82- if ":" not in decoded :
83- raise ValueError ("Invalid Basic auth format" )
84- basic_client_id , request_client_secret = decoded .split (":" , 1 )
85-
86- # URL-decode both parts per RFC 6749 Section 2.3.1
87- basic_client_id = unquote (basic_client_id )
88- request_client_secret = unquote (request_client_secret )
89-
90- if basic_client_id != client_id :
91- raise AuthenticationError ("Client ID mismatch in Basic auth" )
92- except (ValueError , UnicodeDecodeError , binascii .Error ):
93- raise AuthenticationError ("Invalid Basic authentication header" )
94-
95- elif client .token_endpoint_auth_method == "client_secret_post" :
96- raw_form_data = form_data .get ("client_secret" )
97- # form_data.get() can return a UploadFile or None, so we need to check if it's a string
98- if isinstance (raw_form_data , str ):
99- request_client_secret = str (raw_form_data )
100-
101- elif client .token_endpoint_auth_method == "none" :
102- request_client_secret = None
103- else :
104- raise AuthenticationError ( # pragma: no cover
105- f"Unsupported auth method: { client .token_endpoint_auth_method } "
106- )
68+ match client .token_endpoint_auth_method :
69+ case "client_secret_basic" :
70+ if client_credentials .auth_method != "client_secret_basic" :
71+ raise AuthenticationError (f"Expected client_secret_basic authentication method, but got { client_credentials .auth_method } " )
72+ case "client_secret_post" :
73+ if client_credentials .auth_method != "client_secret_post" :
74+ raise AuthenticationError (f"Expected client_secret_post authentication method, but got { client_credentials .auth_method } " )
75+ case "none" :
76+ pass
77+ case _:
78+ raise AuthenticationError (f"Unsupported auth method: { client .token_endpoint_auth_method } " ) # pragma: no cover
10779
10880 # If client from the store expects a secret, validate that the request provides
10981 # that secret
11082 if client .client_secret : # pragma: no branch
111- if not request_client_secret :
83+ if not client_credentials . client_secret :
11284 raise AuthenticationError ("Client secret is required" ) # pragma: no cover
11385
11486 # hmac.compare_digest requires that both arguments are either bytes or a `str` containing
11587 # only ASCII characters. Since we do not control `request_client_secret`, we encode both
11688 # arguments to bytes.
117- if not hmac .compare_digest (client .client_secret .encode (), request_client_secret .encode ()):
89+ if not hmac .compare_digest (client .client_secret .encode (), client_credentials . client_secret .encode ()):
11890 raise AuthenticationError ("Invalid client_secret" ) # pragma: no cover
11991
12092 if client .client_secret_expires_at and client .client_secret_expires_at < int (time .time ()):
0 commit comments