|
1 | 1 | """Manage authentication flow for FastAPI endpoints with JWK based JWT auth.""" |
2 | 2 |
|
3 | 3 | import json |
| 4 | +import time |
4 | 5 | from asyncio import Lock |
5 | 6 | from collections.abc import Callable |
6 | 7 | from typing import Any |
|
17 | 18 | from fastapi import HTTPException, Request |
18 | 19 |
|
19 | 20 | from authentication.interface import AuthInterface, AuthTuple |
20 | | -from authentication.utils import extract_user_token |
| 21 | +from authentication.utils import extract_user_token, record_auth_metrics |
21 | 22 | from constants import ( |
| 23 | + AUTH_MOD_JWK_TOKEN, |
22 | 24 | DEFAULT_VIRTUAL_PATH, |
23 | 25 | ) |
24 | 26 | from log import get_logger |
@@ -139,6 +141,88 @@ def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key: |
139 | 141 | return _internal |
140 | 142 |
|
141 | 143 |
|
| 144 | +async def _get_jwk_set_for_auth(config: JwkConfiguration, start_time: float) -> KeySet: |
| 145 | + """Load the configured JWK set and record bounded auth failures.""" |
| 146 | + try: |
| 147 | + return await get_jwk_set(str(config.url)) |
| 148 | + except aiohttp.ClientError as exc: |
| 149 | + logger.error("Failed to fetch JWK set: %s", exc) |
| 150 | + record_auth_metrics( |
| 151 | + AUTH_MOD_JWK_TOKEN, "failure", "jwk_fetch_error", start_time |
| 152 | + ) |
| 153 | + response = UnauthorizedResponse( |
| 154 | + cause="Unable to reach authentication key server" |
| 155 | + ) |
| 156 | + raise HTTPException(**response.model_dump()) from exc |
| 157 | + except json.JSONDecodeError as exc: |
| 158 | + logger.error("Invalid JSON in JWK set response: %s", exc) |
| 159 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_json", start_time) |
| 160 | + response = UnauthorizedResponse( |
| 161 | + cause="Authentication key server returned invalid data" |
| 162 | + ) |
| 163 | + raise HTTPException(**response.model_dump()) from exc |
| 164 | + except JoseError as exc: |
| 165 | + logger.error("Invalid JWK set format: %s", exc) |
| 166 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_jwk", start_time) |
| 167 | + response = UnauthorizedResponse(cause="Authentication keys are malformed") |
| 168 | + raise HTTPException(**response.model_dump()) from exc |
| 169 | + |
| 170 | + |
| 171 | +def _decode_jwk_claims(user_token: str, jwk_set: KeySet, start_time: float) -> Any: |
| 172 | + """Decode a JWT and record bounded auth failures.""" |
| 173 | + try: |
| 174 | + return jwt.decode(user_token, key=key_resolver_func(jwk_set)) |
| 175 | + except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: |
| 176 | + logger.warning("Token decode error: %s", exc) |
| 177 | + record_auth_metrics( |
| 178 | + AUTH_MOD_JWK_TOKEN, "failure", "token_decode_error", start_time |
| 179 | + ) |
| 180 | + if isinstance(exc, KeyNotFoundError): |
| 181 | + cause = "Token signed by unknown key" |
| 182 | + elif isinstance(exc, BadSignatureError): |
| 183 | + cause = "Invalid token signature" |
| 184 | + elif isinstance(exc, DecodeError): |
| 185 | + cause = "Token could not be decoded" |
| 186 | + else: |
| 187 | + cause = "Token format error" |
| 188 | + response = UnauthorizedResponse(cause=cause) |
| 189 | + raise HTTPException(**response.model_dump()) from exc |
| 190 | + |
| 191 | + |
| 192 | +def _validate_jwk_claims(claims: Any, start_time: float) -> None: |
| 193 | + """Validate decoded JWT claims and record bounded auth failures.""" |
| 194 | + try: |
| 195 | + claims.validate() |
| 196 | + except ExpiredTokenError as exc: |
| 197 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "token_expired", start_time) |
| 198 | + response = UnauthorizedResponse(cause="Token has expired") |
| 199 | + raise HTTPException(**response.model_dump()) from exc |
| 200 | + except JoseError as exc: |
| 201 | + record_auth_metrics( |
| 202 | + AUTH_MOD_JWK_TOKEN, "failure", "token_validation_error", start_time |
| 203 | + ) |
| 204 | + response = UnauthorizedResponse(cause="Token validation failed") |
| 205 | + raise HTTPException(**response.model_dump()) from exc |
| 206 | + |
| 207 | + |
| 208 | +def _get_required_claim(claims: Any, claim_name: str, start_time: float) -> str: |
| 209 | + """Return a required JWT claim and record bounded auth failures when missing.""" |
| 210 | + try: |
| 211 | + value = claims[claim_name] |
| 212 | + except KeyError as exc: |
| 213 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "missing_claim", start_time) |
| 214 | + response = UnauthorizedResponse(cause=f"Token missing claim: {claim_name}") |
| 215 | + raise HTTPException(**response.model_dump()) from exc |
| 216 | + if not isinstance(value, str) or not value: |
| 217 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_claim", start_time) |
| 218 | + response = UnauthorizedResponse(cause=f"Token has invalid claim: {claim_name}") |
| 219 | + invalid_claim_error = ValueError( |
| 220 | + f"Token claim {claim_name} must be a non-empty string" |
| 221 | + ) |
| 222 | + raise HTTPException(**response.model_dump()) from invalid_claim_error |
| 223 | + return value |
| 224 | + |
| 225 | + |
142 | 226 | class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods |
143 | 227 | """JWK AuthDependency class for JWK-based JWT authentication.""" |
144 | 228 |
|
@@ -187,73 +271,40 @@ async def __call__(self, request: Request) -> AuthTuple: |
187 | 271 | extracted from the validated JWT. Only returned on successful |
188 | 272 | authentication; all error paths raise HTTPException. |
189 | 273 | """ |
| 274 | + start_time = time.monotonic() |
| 275 | + |
190 | 276 | if not request.headers.get("Authorization"): |
| 277 | + record_auth_metrics( |
| 278 | + AUTH_MOD_JWK_TOKEN, "failure", "missing_header", start_time |
| 279 | + ) |
191 | 280 | response = UnauthorizedResponse(cause="No Authorization header found") |
192 | 281 | raise HTTPException(**response.model_dump()) |
193 | 282 |
|
194 | | - user_token = extract_user_token(request.headers) |
195 | | - |
196 | | - try: |
197 | | - jwk_set = await get_jwk_set(str(self.config.url)) |
198 | | - except aiohttp.ClientError as exc: |
199 | | - logger.error("Failed to fetch JWK set: %s", exc) |
200 | | - response = UnauthorizedResponse( |
201 | | - cause="Unable to reach authentication key server" |
202 | | - ) |
203 | | - raise HTTPException(**response.model_dump()) from exc |
204 | | - except json.JSONDecodeError as exc: |
205 | | - logger.error("Invalid JSON in JWK set response: %s", exc) |
206 | | - response = UnauthorizedResponse( |
207 | | - cause="Authentication key server returned invalid data" |
208 | | - ) |
209 | | - raise HTTPException(**response.model_dump()) from exc |
210 | | - except JoseError as exc: |
211 | | - logger.error("Invalid JWK set format: %s", exc) |
212 | | - response = UnauthorizedResponse(cause="Authentication keys are malformed") |
213 | | - raise HTTPException(**response.model_dump()) from exc |
214 | | - |
215 | | - try: |
216 | | - claims = jwt.decode(user_token, key=key_resolver_func(jwk_set)) |
217 | | - except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: |
218 | | - logger.warning("Token decode error: %s", exc) |
219 | | - cause_map = { |
220 | | - KeyNotFoundError: "Token signed by unknown key", |
221 | | - BadSignatureError: "Invalid token signature", |
222 | | - DecodeError: "Token could not be decoded", |
223 | | - JoseError: "Token format error", |
224 | | - } |
225 | | - response = UnauthorizedResponse( |
226 | | - cause=cause_map.get(type(exc), "Unknown token error") |
227 | | - ) |
228 | | - raise HTTPException(**response.model_dump()) from exc |
229 | | - |
230 | 283 | try: |
231 | | - claims.validate() |
232 | | - except ExpiredTokenError as exc: |
233 | | - response = UnauthorizedResponse(cause="Token has expired") |
234 | | - raise HTTPException(**response.model_dump()) from exc |
235 | | - except JoseError as exc: |
236 | | - response = UnauthorizedResponse(cause="Token validation failed") |
237 | | - raise HTTPException(**response.model_dump()) from exc |
238 | | - |
239 | | - try: |
240 | | - user_id: str = claims[self.config.jwt_configuration.user_id_claim] |
241 | | - except KeyError as exc: |
242 | | - missing_claim = self.config.jwt_configuration.user_id_claim |
243 | | - response = UnauthorizedResponse( |
244 | | - cause=f"Token missing claim: {missing_claim}" |
| 284 | + user_token = extract_user_token(request.headers) |
| 285 | + except HTTPException: |
| 286 | + record_auth_metrics( |
| 287 | + AUTH_MOD_JWK_TOKEN, "failure", "missing_token", start_time |
245 | 288 | ) |
246 | | - raise HTTPException(**response.model_dump()) from exc |
247 | | - |
248 | | - try: |
249 | | - username: str = claims[self.config.jwt_configuration.username_claim] |
250 | | - except KeyError as exc: |
251 | | - missing_claim = self.config.jwt_configuration.username_claim |
252 | | - response = UnauthorizedResponse( |
253 | | - cause=f"Token missing claim: {missing_claim}" |
| 289 | + raise |
| 290 | + except Exception: # pylint: disable=broad-exception-caught |
| 291 | + logger.exception("Unexpected error while extracting JWK bearer token") |
| 292 | + record_auth_metrics( |
| 293 | + AUTH_MOD_JWK_TOKEN, "failure", "unexpected_error", start_time |
254 | 294 | ) |
255 | | - raise HTTPException(**response.model_dump()) from exc |
| 295 | + raise |
| 296 | + |
| 297 | + jwk_set = await _get_jwk_set_for_auth(self.config, start_time) |
| 298 | + claims = _decode_jwk_claims(user_token, jwk_set, start_time) |
| 299 | + _validate_jwk_claims(claims, start_time) |
| 300 | + user_id = _get_required_claim( |
| 301 | + claims, self.config.jwt_configuration.user_id_claim, start_time |
| 302 | + ) |
| 303 | + username = _get_required_claim( |
| 304 | + claims, self.config.jwt_configuration.username_claim, start_time |
| 305 | + ) |
256 | 306 |
|
257 | 307 | logger.info("Successfully authenticated user %s (ID: %s)", username, user_id) |
258 | 308 |
|
| 309 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "success", "authenticated", start_time) |
259 | 310 | return user_id, username, self.skip_userid_check, user_token |
0 commit comments