diff --git a/fastapi_azure_auth/auth.py b/fastapi_azure_auth/auth.py index 02a140e..45db203 100644 --- a/fastapi_azure_auth/auth.py +++ b/fastapi_azure_auth/auth.py @@ -80,60 +80,68 @@ async def __call__(self, request: Request) -> dict[str, Any]: Extends call to also validate the token """ access_token = await super().__call__(request=request) + try: + # Extract header information of the token. + header = json.loads(base64.b64decode(access_token.split('.')[0])) # header, claims, signature + except Exception as error: + log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True) + raise InvalidAuth(detail='Invalid token format') + # Load new config if old await provider_config.load_config() - for signing_key in provider_config.signing_keys: - header = json.loads(base64.b64decode(access_token.split('.')[0])) # header, claims, signature - if header.get('kid') == signing_key['kid']: - try: - # Set strict in case defaults change - options = { - 'verify_signature': True, - 'verify_aud': True, - 'verify_iat': True, - 'verify_exp': True, - 'verify_nbf': True, - 'verify_iss': True, - 'verify_sub': True, - 'verify_jti': True, - 'verify_at_hash': True, - 'require_aud': True, - 'require_iat': True, - 'require_exp': True, - 'require_nbf': True, - 'require_iss': True, - 'require_sub': True, - 'require_jti': False, - 'require_at_hash': False, - 'leeway': 0, - } - # Validate token and return claims - token = jwt.decode( - access_token, - key=signing_key['certificate'], - algorithms=['RS256'], - audience=f'api://{self.app_client_id}', - issuer=f'https://sts.windows.net/{provider_config.tenant_id}/', - options=options, - ) - if not self.allow_guest_users and token['tid'] != provider_config.tenant_id: - raise GuestUserException() - user: User = User(**token | {'claims': token}) - request.state.user = user - return token - except GuestUserException: - raise InvalidAuth('Guest users not allowed') - except JWTClaimsError as error: - log.info('Token contains invalid claims. %s', error) - raise InvalidAuth(detail='Token contains invalid claims') - except ExpiredSignatureError as error: - log.info('Token signature has expired. %s', error) - raise InvalidAuth(detail='Token signature has expired') - except JWTError as error: - log.warning('Invalid token. Error: %s', error, exc_info=True) - raise InvalidAuth(detail='Unable to validate token') - except Exception as error: - # Extra failsafe in case of a bug in a future version of the jwt library - log.exception('Unable to process jwt token. Uncaught error: %s', error) - raise InvalidAuth(detail='Unable to process token') + + # Use the `kid` from the header to find a matching signing key to use + if kid := provider_config.signing_keys.get(header.get('kid')): + try: + # We require and validate all fields in an Azure AD token + options = { + 'verify_signature': True, + 'verify_aud': True, + 'verify_iat': True, + 'verify_exp': True, + 'verify_nbf': True, + 'verify_iss': True, + 'verify_sub': True, + 'verify_jti': True, + 'verify_at_hash': True, + 'require_aud': True, + 'require_iat': True, + 'require_exp': True, + 'require_nbf': True, + 'require_iss': True, + 'require_sub': True, + 'require_jti': False, + 'require_at_hash': False, + 'leeway': 0, + } + # Validate token + token = jwt.decode( + access_token, + key=kid, + algorithms=['RS256'], + audience=f'api://{self.app_client_id}', + issuer=f'https://sts.windows.net/{provider_config.tenant_id}/', + options=options, + ) + if not self.allow_guest_users and token['tid'] != provider_config.tenant_id: + raise GuestUserException() + # Attach the user to the request. Can be accessed through `request.state.user` + user: User = User(**token | {'claims': token}) + request.state.user = user + return token + except GuestUserException: + raise InvalidAuth('Guest users not allowed') + except JWTClaimsError as error: + log.info('Token contains invalid claims. %s', error) + raise InvalidAuth(detail='Token contains invalid claims') + except ExpiredSignatureError as error: + log.info('Token signature has expired. %s', error) + raise InvalidAuth(detail='Token signature has expired') + except JWTError as error: + log.warning('Invalid token. Error: %s', error, exc_info=True) + raise InvalidAuth(detail='Unable to validate token') + except Exception as error: + # Extra failsafe in case of a bug in a future version of the jwt library + log.exception('Unable to process jwt token. Uncaught error: %s', error) + raise InvalidAuth(detail='Unable to process token') raise InvalidAuth(detail='Unable to verify token, no signing keys found') diff --git a/fastapi_azure_auth/provider_config.py b/fastapi_azure_auth/provider_config.py index c6bef1d..18b3675 100644 --- a/fastapi_azure_auth/provider_config.py +++ b/fastapi_azure_auth/provider_config.py @@ -23,7 +23,7 @@ def __init__(self) -> None: self._config_timestamp: Optional[datetime] = None self.authorization_endpoint: str - self.signing_keys: list[Key] + self.signing_keys: dict[str, KeyTypes] self.token_endpoint: str self.end_session_endpoint: str self.issuer: str @@ -84,15 +84,13 @@ def _load_keys(self, keys: list[dict]) -> None: """ Create certificates based on signing keys and store them """ - new_keys = [] + self.signing_keys = {} for key in keys: if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption log.debug('Loading public key from certificate: %s', key) cert_obj = load_der_x509_certificate(base64.b64decode(key['x5c'][0]), backend) - if key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it. - new_key: Key = {'kid': key['kid'], 'certificate': cert_obj.public_key()} - new_keys.append(new_key) - self.signing_keys = new_keys + if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it. + self.signing_keys[kid] = cert_obj.public_key() provider_config = ProviderConfig() diff --git a/tests/test_validate_token.py b/tests/test_validate_token.py index ccc1a36..9ebf86e 100644 --- a/tests/test_validate_token.py +++ b/tests/test_validate_token.py @@ -125,6 +125,29 @@ async def test_evil_token(mock_openid_and_keys): assert response.json() == {'detail': 'Unable to validate token'} +async def test_malformed_token(mock_openid_and_keys): + """A short token, that only has a broken header""" + async with AsyncClient( + app=app, base_url='http://test', headers={'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsInR5cI6IkpXVCJ9'} + ) as ac: + response = await ac.get('api/v1/hello') + assert response.json() == {'detail': 'Invalid token format'} + + +async def test_only_header(mock_openid_and_keys): + """Only header token, with a matching kid, so the rest of the logic will be called, but can't be validated""" + async with AsyncClient( + app=app, + base_url='http://test', + headers={ + 'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsImtpZCI6InJlYWwgdGh1bWJ' + 'wcmludCIsInR5cCI6IkpXVCIsIng1dCI6ImFub3RoZXIgdGh1bWJwcmludCJ9' + }, # {'kid': 'real thumbprint', 'x5t': 'another thumbprint'} + ) as ac: + response = await ac.get('api/v1/hello') + assert response.json() == {'detail': 'Unable to validate token'} + + async def test_exception_raised(mock_openid_and_keys, mocker): mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol')) async with AsyncClient(