Skip to content

Commit

Permalink
Rework signing key logic to use a mapping instead.
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasKs committed Aug 17, 2021
1 parent 97a4908 commit 03990a4
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 60 deletions.
116 changes: 62 additions & 54 deletions fastapi_azure_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,60 +80,68 @@ async def __call__(self, request: Request) -> dict[str, Any]:
Extends call to also validate the token
"""
access_token = await super().__call__(request=request)
try:
# Extract header information of the token.
header = json.loads(base64.b64decode(access_token.split('.')[0])) # header, claims, signature
except Exception as error:
log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True)
raise InvalidAuth(detail='Invalid token format')

# Load new config if old
await provider_config.load_config()
for signing_key in provider_config.signing_keys:
header = json.loads(base64.b64decode(access_token.split('.')[0])) # header, claims, signature
if header.get('kid') == signing_key['kid']:
try:
# Set strict in case defaults change
options = {
'verify_signature': True,
'verify_aud': True,
'verify_iat': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iss': True,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'require_aud': True,
'require_iat': True,
'require_exp': True,
'require_nbf': True,
'require_iss': True,
'require_sub': True,
'require_jti': False,
'require_at_hash': False,
'leeway': 0,
}
# Validate token and return claims
token = jwt.decode(
access_token,
key=signing_key['certificate'],
algorithms=['RS256'],
audience=f'api://{self.app_client_id}',
issuer=f'https://sts.windows.net/{provider_config.tenant_id}/',
options=options,
)
if not self.allow_guest_users and token['tid'] != provider_config.tenant_id:
raise GuestUserException()
user: User = User(**token | {'claims': token})
request.state.user = user
return token
except GuestUserException:
raise InvalidAuth('Guest users not allowed')
except JWTClaimsError as error:
log.info('Token contains invalid claims. %s', error)
raise InvalidAuth(detail='Token contains invalid claims')
except ExpiredSignatureError as error:
log.info('Token signature has expired. %s', error)
raise InvalidAuth(detail='Token signature has expired')
except JWTError as error:
log.warning('Invalid token. Error: %s', error, exc_info=True)
raise InvalidAuth(detail='Unable to validate token')
except Exception as error:
# Extra failsafe in case of a bug in a future version of the jwt library
log.exception('Unable to process jwt token. Uncaught error: %s', error)
raise InvalidAuth(detail='Unable to process token')

# Use the `kid` from the header to find a matching signing key to use
if kid := provider_config.signing_keys.get(header.get('kid')):
try:
# We require and validate all fields in an Azure AD token
options = {
'verify_signature': True,
'verify_aud': True,
'verify_iat': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iss': True,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'require_aud': True,
'require_iat': True,
'require_exp': True,
'require_nbf': True,
'require_iss': True,
'require_sub': True,
'require_jti': False,
'require_at_hash': False,
'leeway': 0,
}
# Validate token
token = jwt.decode(
access_token,
key=kid,
algorithms=['RS256'],
audience=f'api://{self.app_client_id}',
issuer=f'https://sts.windows.net/{provider_config.tenant_id}/',
options=options,
)
if not self.allow_guest_users and token['tid'] != provider_config.tenant_id:
raise GuestUserException()
# Attach the user to the request. Can be accessed through `request.state.user`
user: User = User(**token | {'claims': token})
request.state.user = user
return token
except GuestUserException:
raise InvalidAuth('Guest users not allowed')
except JWTClaimsError as error:
log.info('Token contains invalid claims. %s', error)
raise InvalidAuth(detail='Token contains invalid claims')
except ExpiredSignatureError as error:
log.info('Token signature has expired. %s', error)
raise InvalidAuth(detail='Token signature has expired')
except JWTError as error:
log.warning('Invalid token. Error: %s', error, exc_info=True)
raise InvalidAuth(detail='Unable to validate token')
except Exception as error:
# Extra failsafe in case of a bug in a future version of the jwt library
log.exception('Unable to process jwt token. Uncaught error: %s', error)
raise InvalidAuth(detail='Unable to process token')
raise InvalidAuth(detail='Unable to verify token, no signing keys found')
10 changes: 4 additions & 6 deletions fastapi_azure_auth/provider_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self) -> None:
self._config_timestamp: Optional[datetime] = None

self.authorization_endpoint: str
self.signing_keys: list[Key]
self.signing_keys: dict[str, KeyTypes]
self.token_endpoint: str
self.end_session_endpoint: str
self.issuer: str
Expand Down Expand Up @@ -84,15 +84,13 @@ def _load_keys(self, keys: list[dict]) -> None:
"""
Create certificates based on signing keys and store them
"""
new_keys = []
self.signing_keys = {}
for key in keys:
if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption
log.debug('Loading public key from certificate: %s', key)
cert_obj = load_der_x509_certificate(base64.b64decode(key['x5c'][0]), backend)
if key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it.
new_key: Key = {'kid': key['kid'], 'certificate': cert_obj.public_key()}
new_keys.append(new_key)
self.signing_keys = new_keys
if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it.
self.signing_keys[kid] = cert_obj.public_key()


provider_config = ProviderConfig()
23 changes: 23 additions & 0 deletions tests/test_validate_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,29 @@ async def test_evil_token(mock_openid_and_keys):
assert response.json() == {'detail': 'Unable to validate token'}


async def test_malformed_token(mock_openid_and_keys):
"""A short token, that only has a broken header"""
async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsInR5cI6IkpXVCJ9'}
) as ac:
response = await ac.get('api/v1/hello')
assert response.json() == {'detail': 'Invalid token format'}


async def test_only_header(mock_openid_and_keys):
"""Only header token, with a matching kid, so the rest of the logic will be called, but can't be validated"""
async with AsyncClient(
app=app,
base_url='http://test',
headers={
'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsImtpZCI6InJlYWwgdGh1bWJ'
'wcmludCIsInR5cCI6IkpXVCIsIng1dCI6ImFub3RoZXIgdGh1bWJwcmludCJ9'
}, # {'kid': 'real thumbprint', 'x5t': 'another thumbprint'}
) as ac:
response = await ac.get('api/v1/hello')
assert response.json() == {'detail': 'Unable to validate token'}


async def test_exception_raised(mock_openid_and_keys, mocker):
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
async with AsyncClient(
Expand Down

0 comments on commit 03990a4

Please sign in to comment.