Skip to content

Commit 046a1a5

Browse files
committed
feat: add auth monitoring metrics
Signed-off-by: Major Hayden <major@redhat.com>
1 parent ca125c4 commit 046a1a5

14 files changed

Lines changed: 771 additions & 140 deletions

File tree

src/authentication/api_key_token.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
"""
88

99
import secrets
10+
import time
1011

1112
from fastapi import HTTPException, Request, status
1213

1314
from authentication.interface import AuthInterface
14-
from authentication.utils import extract_user_token
15+
from authentication.utils import extract_user_token, record_auth_metrics
1516
from constants import (
17+
AUTH_MOD_APIKEY_TOKEN,
1618
DEFAULT_USER_NAME,
1719
DEFAULT_USER_UID,
1820
DEFAULT_VIRTUAL_PATH,
@@ -59,16 +61,28 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]:
5961
HTTPException: If the bearer token is missing or
6062
doesn't match the configured API key (HTTP 401).
6163
"""
64+
start_time = time.monotonic()
65+
6266
# try to extract user token from request
63-
user_token = extract_user_token(request.headers)
67+
try:
68+
user_token = extract_user_token(request.headers)
69+
except HTTPException:
70+
record_auth_metrics(
71+
AUTH_MOD_APIKEY_TOKEN, "failure", "missing_token", start_time
72+
)
73+
raise
6474

6575
# API Key validation. Use secrets.compare_digest for constant-time comparison
6676
if not secrets.compare_digest(
6777
user_token, self.config.api_key.get_secret_value()
6878
):
79+
record_auth_metrics(
80+
AUTH_MOD_APIKEY_TOKEN, "failure", "invalid_key", start_time
81+
)
6982
raise HTTPException(
7083
status_code=status.HTTP_401_UNAUTHORIZED,
7184
detail="Invalid API Key",
7285
)
7386

87+
record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "success", "valid_key", start_time)
7488
return DEFAULT_USER_UID, DEFAULT_USER_NAME, self.skip_userid_check, user_token

src/authentication/jwk_token.py

Lines changed: 111 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Manage authentication flow for FastAPI endpoints with JWK based JWT auth."""
22

33
import json
4+
import time
45
from asyncio import Lock
56
from collections.abc import Callable
67
from typing import Any
@@ -17,8 +18,9 @@
1718
from fastapi import HTTPException, Request
1819

1920
from authentication.interface import AuthInterface, AuthTuple
20-
from authentication.utils import extract_user_token
21+
from authentication.utils import extract_user_token, record_auth_metrics
2122
from constants import (
23+
AUTH_MOD_JWK_TOKEN,
2224
DEFAULT_VIRTUAL_PATH,
2325
)
2426
from log import get_logger
@@ -139,6 +141,88 @@ def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key:
139141
return _internal
140142

141143

144+
async def _get_jwk_set_for_auth(config: JwkConfiguration, start_time: float) -> KeySet:
145+
"""Load the configured JWK set and record bounded auth failures."""
146+
try:
147+
return await get_jwk_set(str(config.url))
148+
except aiohttp.ClientError as exc:
149+
logger.error("Failed to fetch JWK set: %s", exc)
150+
record_auth_metrics(
151+
AUTH_MOD_JWK_TOKEN, "failure", "jwk_fetch_error", start_time
152+
)
153+
response = UnauthorizedResponse(
154+
cause="Unable to reach authentication key server"
155+
)
156+
raise HTTPException(**response.model_dump()) from exc
157+
except json.JSONDecodeError as exc:
158+
logger.error("Invalid JSON in JWK set response: %s", exc)
159+
record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_json", start_time)
160+
response = UnauthorizedResponse(
161+
cause="Authentication key server returned invalid data"
162+
)
163+
raise HTTPException(**response.model_dump()) from exc
164+
except JoseError as exc:
165+
logger.error("Invalid JWK set format: %s", exc)
166+
record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_jwk", start_time)
167+
response = UnauthorizedResponse(cause="Authentication keys are malformed")
168+
raise HTTPException(**response.model_dump()) from exc
169+
170+
171+
def _decode_jwk_claims(user_token: str, jwk_set: KeySet, start_time: float) -> Any:
172+
"""Decode a JWT and record bounded auth failures."""
173+
try:
174+
return jwt.decode(user_token, key=key_resolver_func(jwk_set))
175+
except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc:
176+
logger.warning("Token decode error: %s", exc)
177+
record_auth_metrics(
178+
AUTH_MOD_JWK_TOKEN, "failure", "token_decode_error", start_time
179+
)
180+
if isinstance(exc, KeyNotFoundError):
181+
cause = "Token signed by unknown key"
182+
elif isinstance(exc, BadSignatureError):
183+
cause = "Invalid token signature"
184+
elif isinstance(exc, DecodeError):
185+
cause = "Token could not be decoded"
186+
else:
187+
cause = "Token format error"
188+
response = UnauthorizedResponse(cause=cause)
189+
raise HTTPException(**response.model_dump()) from exc
190+
191+
192+
def _validate_jwk_claims(claims: Any, start_time: float) -> None:
193+
"""Validate decoded JWT claims and record bounded auth failures."""
194+
try:
195+
claims.validate()
196+
except ExpiredTokenError as exc:
197+
record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "token_expired", start_time)
198+
response = UnauthorizedResponse(cause="Token has expired")
199+
raise HTTPException(**response.model_dump()) from exc
200+
except JoseError as exc:
201+
record_auth_metrics(
202+
AUTH_MOD_JWK_TOKEN, "failure", "token_validation_error", start_time
203+
)
204+
response = UnauthorizedResponse(cause="Token validation failed")
205+
raise HTTPException(**response.model_dump()) from exc
206+
207+
208+
def _get_required_claim(claims: Any, claim_name: str, start_time: float) -> str:
209+
"""Return a required JWT claim and record bounded auth failures when missing."""
210+
try:
211+
value = claims[claim_name]
212+
except KeyError as exc:
213+
record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "missing_claim", start_time)
214+
response = UnauthorizedResponse(cause=f"Token missing claim: {claim_name}")
215+
raise HTTPException(**response.model_dump()) from exc
216+
if not isinstance(value, str) or not value:
217+
record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_claim", start_time)
218+
response = UnauthorizedResponse(cause=f"Token has invalid claim: {claim_name}")
219+
invalid_claim_error = ValueError(
220+
f"Token claim {claim_name} must be a non-empty string"
221+
)
222+
raise HTTPException(**response.model_dump()) from invalid_claim_error
223+
return value
224+
225+
142226
class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods
143227
"""JWK AuthDependency class for JWK-based JWT authentication."""
144228

@@ -187,73 +271,40 @@ async def __call__(self, request: Request) -> AuthTuple:
187271
extracted from the validated JWT. Only returned on successful
188272
authentication; all error paths raise HTTPException.
189273
"""
274+
start_time = time.monotonic()
275+
190276
if not request.headers.get("Authorization"):
277+
record_auth_metrics(
278+
AUTH_MOD_JWK_TOKEN, "failure", "missing_header", start_time
279+
)
191280
response = UnauthorizedResponse(cause="No Authorization header found")
192281
raise HTTPException(**response.model_dump())
193282

194-
user_token = extract_user_token(request.headers)
195-
196-
try:
197-
jwk_set = await get_jwk_set(str(self.config.url))
198-
except aiohttp.ClientError as exc:
199-
logger.error("Failed to fetch JWK set: %s", exc)
200-
response = UnauthorizedResponse(
201-
cause="Unable to reach authentication key server"
202-
)
203-
raise HTTPException(**response.model_dump()) from exc
204-
except json.JSONDecodeError as exc:
205-
logger.error("Invalid JSON in JWK set response: %s", exc)
206-
response = UnauthorizedResponse(
207-
cause="Authentication key server returned invalid data"
208-
)
209-
raise HTTPException(**response.model_dump()) from exc
210-
except JoseError as exc:
211-
logger.error("Invalid JWK set format: %s", exc)
212-
response = UnauthorizedResponse(cause="Authentication keys are malformed")
213-
raise HTTPException(**response.model_dump()) from exc
214-
215-
try:
216-
claims = jwt.decode(user_token, key=key_resolver_func(jwk_set))
217-
except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc:
218-
logger.warning("Token decode error: %s", exc)
219-
cause_map = {
220-
KeyNotFoundError: "Token signed by unknown key",
221-
BadSignatureError: "Invalid token signature",
222-
DecodeError: "Token could not be decoded",
223-
JoseError: "Token format error",
224-
}
225-
response = UnauthorizedResponse(
226-
cause=cause_map.get(type(exc), "Unknown token error")
227-
)
228-
raise HTTPException(**response.model_dump()) from exc
229-
230283
try:
231-
claims.validate()
232-
except ExpiredTokenError as exc:
233-
response = UnauthorizedResponse(cause="Token has expired")
234-
raise HTTPException(**response.model_dump()) from exc
235-
except JoseError as exc:
236-
response = UnauthorizedResponse(cause="Token validation failed")
237-
raise HTTPException(**response.model_dump()) from exc
238-
239-
try:
240-
user_id: str = claims[self.config.jwt_configuration.user_id_claim]
241-
except KeyError as exc:
242-
missing_claim = self.config.jwt_configuration.user_id_claim
243-
response = UnauthorizedResponse(
244-
cause=f"Token missing claim: {missing_claim}"
284+
user_token = extract_user_token(request.headers)
285+
except HTTPException:
286+
record_auth_metrics(
287+
AUTH_MOD_JWK_TOKEN, "failure", "missing_token", start_time
245288
)
246-
raise HTTPException(**response.model_dump()) from exc
247-
248-
try:
249-
username: str = claims[self.config.jwt_configuration.username_claim]
250-
except KeyError as exc:
251-
missing_claim = self.config.jwt_configuration.username_claim
252-
response = UnauthorizedResponse(
253-
cause=f"Token missing claim: {missing_claim}"
289+
raise
290+
except Exception: # pylint: disable=broad-exception-caught
291+
logger.exception("Unexpected error while extracting JWK bearer token")
292+
record_auth_metrics(
293+
AUTH_MOD_JWK_TOKEN, "failure", "unexpected_error", start_time
254294
)
255-
raise HTTPException(**response.model_dump()) from exc
295+
raise
296+
297+
jwk_set = await _get_jwk_set_for_auth(self.config, start_time)
298+
claims = _decode_jwk_claims(user_token, jwk_set, start_time)
299+
_validate_jwk_claims(claims, start_time)
300+
user_id = _get_required_claim(
301+
claims, self.config.jwt_configuration.user_id_claim, start_time
302+
)
303+
username = _get_required_claim(
304+
claims, self.config.jwt_configuration.username_claim, start_time
305+
)
256306

257307
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
258308

309+
record_auth_metrics(AUTH_MOD_JWK_TOKEN, "success", "authenticated", start_time)
259310
return user_id, username, self.skip_userid_check, user_token

0 commit comments

Comments
 (0)