Skip to content

Commit

Permalink
Replace pyjwkest with pyjwt (#32270)
Browse files Browse the repository at this point in the history
* chore: replace pyjwkest with pyjwt
  • Loading branch information
mumarkhan999 authored Oct 18, 2023
1 parent 7ef8650 commit 92731be
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 145 deletions.
48 changes: 29 additions & 19 deletions lms/envs/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,25 +517,35 @@

####################### Authentication Settings ##########################
JWT_AUTH.update({
'JWT_PUBLIC_SIGNING_JWK_SET': (
'{"keys": [{"kid": "BTZ9HA6K", "e": "AQAB", "kty": "RSA", "n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6'
'sprmYfWWokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc'
'4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEu'
'lLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ"}]}'
),
'JWT_PRIVATE_SIGNING_JWK': (
'{"e": "AQAB", "d": "HIiV7KNjcdhVbpn3KT-I9n3JPf5YbGXsCIedmPqDH1d4QhBofuAqZ9zebQuxkRUpmqtYMv0Zi6ECSUqH387GYQF_Xv'
'FUFcjQRPycISd8TH0DAKaDpGr-AYNshnKiEtQpINhcP44I1AYNPCwyoxXA1fGTtmkKChsuWea7o8kytwU5xSejvh5-jiqu2SF4GEl0BEXIAPZs'
'gbzoPIWNxgO4_RzNnWs6nJZeszcaDD0CyezVSuH9QcI6g5QFzAC_YuykSsaaFJhZ05DocBsLczShJ9Omf6PnK9xlm26I84xrEh_7x4fVmNBg3x'
'WTLh8qOnHqGko93A1diLRCrKHOvnpvgQ", "n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6sprmYfWWokSsrWig8u2y0H'
'ChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwO'
'n5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5'
'q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ", "q": "3T3DEtBUka7hLGdIsDlC96Uadx_q_E4Vb1cxx_4Ss_wGp1Lo'
'z3N3ZngGyInsKlmbBgLo1Ykd6T9TRvRNEWEtFSOcm2INIBoVoXk7W5RuPa8Cgq2tjQj9ziGQ08JMejrPlj3Q1wmALJr5VTfvSYBu0WkljhKNCy'
'1KB6fCby0C9WE", "p": "vUqzWPZnDG4IXyo-k5F0bHV0BNL_pVhQoLW7eyFHnw74IOEfSbdsMspNcPSFIrtgPsn7981qv3lN_staZ6JflKfH'
'ayjB_lvltHyZxfl0dvruShZOx1N6ykEo7YrAskC_qxUyrIvqmJ64zPW3jkuOYrFs7Ykj3zFx3Zq1H5568G0", "kid": "BTZ9HA6K", "kty"'
': "RSA"}'
),
'JWT_PUBLIC_SIGNING_JWK_SET': """
{
"keys":[
{
"kid":"BTZ9HA6K",
"e":"AQAB",
"kty":"RSA",
"n":"o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6sprmYfWWokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ"
}
]
}
""",
'JWT_PRIVATE_SIGNING_JWK': """
{
"kid": "BTZ9HA6K",
"kty": "RSA",
"key_ops": [
"sign"
],
"n": "o5cn3ljSRi6FaDEKTn0PS-oL9EFyv1pI7dRgffQLD1qf5D6sprmYfWWokSsrWig8u2y0HChSygR6Jn5KXBqQn6FpM0dDJLnWQDRXHLl3Ey1iPYgDSmOIsIGrV9ZyNCQwk03wAgWbfdBTig3QSDYD-sTNOs3pc4UD_PqAvU2nz_1SS2ZiOwOn5F6gulE1L0iE3KEUEvOIagfHNVhz0oxa_VRZILkzV-zr6R_TW1m97h4H8jXl_VJyQGyhMGGypuDrQ9_vaY_RLEulLCyY0INglHWQ7pckxBtI5q55-Vio2wgewe2_qYcGsnBGaDNbySAsvYcWRrqDiFyzrJYivodqTQ",
"e": "AQAB",
"d": "HIiV7KNjcdhVbpn3KT-I9n3JPf5YbGXsCIedmPqDH1d4QhBofuAqZ9zebQuxkRUpmqtYMv0Zi6ECSUqH387GYQF_XvFUFcjQRPycISd8TH0DAKaDpGr-AYNshnKiEtQpINhcP44I1AYNPCwyoxXA1fGTtmkKChsuWea7o8kytwU5xSejvh5-jiqu2SF4GEl0BEXIAPZsgbzoPIWNxgO4_RzNnWs6nJZeszcaDD0CyezVSuH9QcI6g5QFzAC_YuykSsaaFJhZ05DocBsLczShJ9Omf6PnK9xlm26I84xrEh_7x4fVmNBg3xWTLh8qOnHqGko93A1diLRCrKHOvnpvgQ",
"p": "3T3DEtBUka7hLGdIsDlC96Uadx_q_E4Vb1cxx_4Ss_wGp1Loz3N3ZngGyInsKlmbBgLo1Ykd6T9TRvRNEWEtFSOcm2INIBoVoXk7W5RuPa8Cgq2tjQj9ziGQ08JMejrPlj3Q1wmALJr5VTfvSYBu0WkljhKNCy1KB6fCby0C9WE",
"q": "vUqzWPZnDG4IXyo-k5F0bHV0BNL_pVhQoLW7eyFHnw74IOEfSbdsMspNcPSFIrtgPsn7981qv3lN_staZ6JflKfHayjB_lvltHyZxfl0dvruShZOx1N6ykEo7YrAskC_qxUyrIvqmJ64zPW3jkuOYrFs7Ykj3zFx3Zq1H5568G0",
"dp": "Azh08H8r2_sJuBXAzx_mQ6iZnAZQ619PnJFOXjTqnMgcaK8iSHLL2CgDIUQwteUcBphgP0uBrfWIBs5jmM8rUtVz4CcrPb5jdjhHjuu4NxmnFbPlhNoOp8OBUjPP3S-h-fPoaFjxDrUqz_zCdPVzp4S6UTkf6Hu-SiI9CFVFZ8E",
"dq": "WQ44_KTIbIej9qnYUPMA1DoaAF8ImVDIdiOp9c79dC7FvCpN3w-lnuugrYDM1j9Tk5bRrY7-JuE6OaKQgOtajoS1BIxjYHj5xAVPD15CVevOihqeq5Zx0ZAAYmmCKRrfUe0iLx2QnIcoKH1-Azs23OXeeo6nysznZjvv9NVJv60",
"qi": "KSWGH607H1kNG2okjYdmVdNgLxTUB-Wye9a9FNFE49UmQIOJeZYXtDzcjk8IiK3g-EU3CqBeDKVUgHvHFu4_Wj3IrIhKYizS4BeFmOcPDvylDQCmJcC9tXLQgHkxM_MEJ7iLn9FOLRshh7GPgZphXxMhezM26Cz-8r3_mACHu84"
}
""",
})
# pylint: enable=unicode-format-string # lint-amnesty, pylint: disable=bad-option-value
####################### Plugin Settings ##########################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ The code examples below show this in action.
Remove JWT_ISSUERS
~~~~~~~~~~~~~~~~~~

edx_rest_framework_extensions.settings_ supports having a list of **JWT_ISSUERS** instead of just a single
`edx_rest_framework_extensions.settings`_ supports having a list of **JWT_ISSUERS** instead of just a single
one. This support for configuring multiple issuers is present across many services. However, this does not
conform to the `JWT standard`_, where the `issuer`_ is intended to identify the entity that generates and
signs the JWT. In our case, that should be the single Auth service only.
Expand All @@ -81,70 +81,56 @@ issuer, but with (the potential of) multiple signing keys stored in a JWT Set.
.. _JSON Web Key Set (JWK Set): https://tools.ietf.org/html/draft-ietf-jose-json-web-key-36#section-5
.. _site configuration: https://github.com/openedx/edx-platform/blob/af841336c7e39d634c238cd8a11c5a3a661aa9e2/openedx/core/djangoapps/site_configuration/__init__.py

Example Code
------------
Features
--------

KeyPair Generation
~~~~~~~~~~~~~~~~~~

Here is code for generating a keypair::

from Cryptodome.PublicKey import RSA
from jwkest import jwk

rsa_key = RSA.generate(2048)
rsa_jwk = jwk.RSAKey(kid="your_key_id", key=rsa_key)

To serialize the **public key** in a `JSON Web Key Set (JWK Set)`_::

public_keys = jwk.KEYS()
public_keys.append(rsa_jwk)
serialized_public_keys_json = public_keys.dump_jwks()

and its sample output::

{
"keys": [
{
"kid": "your_key_id",
"e": "strawberry",
"kty": "RSA",
"n": "something"
}
]
}

To serialize the **keypair** as a JWK::

serialized_keypair = rsa_jwk.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)

and its sample output::

{
"e": "strawberry",
"d": "apple",
"n": "banana",
"q": "pear",
"p": "plum",
"kid": "your_key_id",
"kty": "RSA"
}
Please have a look at ``openedx/core/djangoapps/oauth_dispatch/management/commands/generate_jwt_signing_key.py``
to get better understanding how to generate keypair using ``PyJWT``.

The public and private keypair would be similar to the following::

## Public keyset
"""
{
"keys": [
{
"kty": "RSA",
"key_ops": ["verify"],
"n": "...",
"e": "...",
"kid": "your_key_id"
}
]
}
"""


## Private key
"""
{
"kty": "RSA",
"key_ops": ["sign"],
"n": "...",
"e": "...",
"d": "...",
"p": "...",
"q": "...",
"dp": "...",
"dq": "...",
"qi": "...",
"kid": "your_key_id"
}
"""

Signing
~~~~~~~

To deserialize the keypair from above::

private_keys = jwk.KEYS()
serialized_keypair = json.loads(serialized_keypair_json)
private_keys.add(serialized_keypair)
To create a signature you simply need a **payload**, **private key** and your hashing algorithm::

To create a signature::

from jwkest.jws import JWS
jws = JWS("JWT payload", alg="RS512")
signed_message = jws.sign_compact(keys=private_keys)
signed_message = jwt.encode("JWT payload in dict format", key=private_key, algorithm="RS512")

Note: we specify **RS512** above to identify *RSASSA-PKCS1-v1_5 using SHA-512* as
the signature algorithm value as described in the `JSON Web Algorithms (JWA)`_ spec.
Expand All @@ -154,24 +140,20 @@ the signature algorithm value as described in the `JSON Web Algorithms (JWA)`_ s
Verify Signature
~~~~~~~~~~~~~~~~

To verify the signature from above::
To verify the signature we'll be looping through the public keys and try to verify the signature with each of them.
For more details you can have a look at `verify_jwk_signature_using_keyset`_. To generate ``keyset`` required for verification you
can use `get_verification_jwk_key_set`_ method.

public_keys = jwk.KEYS()
public_keys.load_jwks(serialized_public_keys_json)
jws.verify_compact(signed_message, public_keys)
.. _verify_jwk_signature_using_keyset: https://github.com/openedx/edx-drf-extensions/blob/master/edx_rest_framework_extensions/auth/jwt/decoder.py#L270
.. _get_verification_jwk_key_set : https://github.com/openedx/edx-drf-extensions/blob/master/edx_rest_framework_extensions/auth/jwt/decoder.py#L395

Key Rotation
~~~~~~~~~~~~

When a new public key is added in the future, it should have a unique "kid"
value and added to the public keys JWK set::

new_rsa_key = RSA.generate(2048)
new_rsa_jwk = jwk.RSAKey(kid="new_id", key=new_rsa_key)
public_keys.append(new_rsa_jwk)

When a JWS is created, it is signed with a certain "kid"-identified keypair. When it
is later verified, the public key with the matching "kid" in the JWK set is used.
In future if we plan to rotate the keys, we can simply add new key public key to the public keyset and remove the old private one.
Means, at any time there might be more than one public key but there will be only one private key. Considering that we are doing verification
by looping through all the available public keys, the ``kid`` parameter is not
as important as it was before. But it's still recommended to use it. It will help us to differentiate between the old and new public keys.

Consequences
------------
Expand Down
18 changes: 8 additions & 10 deletions openedx/core/djangoapps/oauth_dispatch/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import logging
from time import time

import jwt
from django.conf import settings
from edx_django_utils.monitoring import increment, set_custom_attribute
from edx_rbac.utils import create_role_auth_claim_for_user
from edx_toggles.toggles import SettingToggle
from jwkest import jwk
from jwkest.jws import JWS
from jwt import PyJWK
from jwt.utils import base64url_encode

from common.djangoapps.student.models import UserProfile, anonymous_id_for_user

Expand Down Expand Up @@ -273,17 +274,14 @@ def _attach_profile_claim(payload, user):

def _encode_and_sign(payload, use_asymmetric_key, secret):
"""Encode and sign the provided payload."""
keys = jwk.KEYS()

if use_asymmetric_key:
serialized_keypair = json.loads(settings.JWT_AUTH['JWT_PRIVATE_SIGNING_JWK'])
keys.add(serialized_keypair)
key = json.loads(settings.JWT_AUTH['JWT_PRIVATE_SIGNING_JWK'])
algorithm = settings.JWT_AUTH['JWT_SIGNING_ALGORITHM']
else:
key = secret if secret else settings.JWT_AUTH['JWT_SECRET_KEY']
keys.add({'key': key, 'kty': 'oct'})
secret = secret if secret else settings.JWT_AUTH['JWT_SECRET_KEY']
key = {'k': base64url_encode(secret.encode('utf-8')), 'kty': 'oct'}
algorithm = settings.JWT_AUTH['JWT_ALGORITHM']

data = json.dumps(payload)
jws = JWS(data, alg=algorithm)
return jws.sign_compact(keys=keys)
jwk = PyJWK(key, algorithm)
return jwt.encode(payload, jwk.key, algorithm=algorithm)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from Cryptodome.PublicKey import RSA
from django.conf import settings
from django.core.management.base import BaseCommand
from jwkest import jwk
from jwt.algorithms import get_default_algorithms

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,15 +123,23 @@ def _generate_key_id(self, size, chars=string.ascii_uppercase + string.digits):
def _generate_key_pair(self, key_size, key_id):
log.info('Generating new JWT signing keypair for key id %s.', key_id)
rsa_key = RSA.generate(key_size)
rsa_jwk = jwk.RSAKey(kid=key_id, key=rsa_key)
return rsa_jwk
algo = get_default_algorithms()['RS512']
key_data = algo.prepare_key(rsa_key.export_key('PEM').decode())
rsa_jwk = json.loads(algo.to_jwk(key_data))
public_rsa_jwk = json.loads(algo.to_jwk(key_data.public_key()))

rsa_jwk['kid'] = key_id
public_rsa_jwk['kid'] = key_id
return {'private': rsa_jwk, 'public': public_rsa_jwk}

def _output_public_keys(self, jwk_key, add_previous, strip_prefix):
public_keys = jwk.KEYS()
public_keys = {'keys': []}

if add_previous:
self._add_previous_public_keys(public_keys)
public_keys.append(jwk_key)
serialized_public_keys = public_keys.dump_jwks()

public_keys['keys'].append(jwk_key['public'])
serialized_public_keys = json.dumps(public_keys)

prefix = '' if strip_prefix else 'COMMON_'
public_signing_key = f'{prefix}JWT_PUBLIC_SIGNING_JWK_SET'
Expand All @@ -155,11 +163,10 @@ def _add_previous_public_keys(self, public_keys):
previous_signing_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')
if previous_signing_keys:
log.info('Old JWT_PUBLIC_SIGNING_JWK_SET: %s.', previous_signing_keys)
public_keys.load_jwks(previous_signing_keys)
public_keys['keys'].extend(json.loads(previous_signing_keys)['keys'])

def _output_private_keys(self, jwk_key, strip_prefix):
serialized_keypair = jwk_key.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)
serialized_keypair_json = json.dumps(jwk_key['private'])

prefix = '' if strip_prefix else 'EDXAPP_'
private_signing_key = f'{prefix}JWT_PRIVATE_SIGNING_JWK'
Expand Down
35 changes: 13 additions & 22 deletions openedx/core/djangoapps/oauth_dispatch/tests/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
"""

import pytest
import jwt
from django.conf import settings
from jwkest.jwk import KEYS
from jwkest.jws import JWS
from edx_rest_framework_extensions.auth.jwt.decoder import (
get_verification_jwk_key_set,
verify_jwk_signature_using_keyset
)
from jwt.exceptions import ExpiredSignatureError

from common.djangoapps.student.models import UserProfile, anonymous_id_for_user
Expand All @@ -33,25 +34,15 @@ def _decode_jwt(verify_expiration):
Helper method to decode a JWT with the ability to
verify the expiration of said token
"""
keys = KEYS()
if should_be_asymmetric_key:
keys.load_jwks(settings.JWT_AUTH['JWT_PUBLIC_SIGNING_JWK_SET'])
else:
keys.add({'key': secret_key, 'kty': 'oct'})

_ = JWS().verify_compact(access_token.encode('utf-8'), keys)

return jwt.decode(
access_token,
secret_key,
algorithms=[settings.JWT_AUTH['JWT_ALGORITHM']],
audience=audience,
issuer=issuer,
options={
'verify_signature': False,
"verify_exp": verify_expiration
},
)
asymmetric_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET') if should_be_asymmetric_key else None
key_set = get_verification_jwk_key_set(asymmetric_keys=asymmetric_keys, secret_key=secret_key)
data = verify_jwk_signature_using_keyset(access_token,
key_set,
iss=issuer,
aud=aud,
verify_exp=verify_expiration)

return data

# Note that if we expect the claims to have expired
# then we ask the JWT library not to verify expiration
Expand Down
16 changes: 0 additions & 16 deletions openedx/core/djangoapps/oauth_dispatch/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@

import ddt
import httpretty
from Cryptodome.PublicKey import RSA
from django.conf import settings
from django.test import RequestFactory, TestCase
from django.urls import reverse
from edx_toggles.toggles.testutils import override_waffle_switch
from jwkest import jwk
from oauth2_provider import models as dot_models

from common.djangoapps.student.tests.factories import UserFactory
Expand Down Expand Up @@ -164,20 +162,6 @@ def _post_body(self, user, client, token_type=None, scope=None, asymmetric_jwt=N

return body

def _generate_key_pair(self):
""" Generates an asymmetric key pair and returns the JWK of its public keys and keypair. """
rsa_key = RSA.generate(2048)
rsa_jwk = jwk.RSAKey(kid="key_id", key=rsa_key)

public_keys = jwk.KEYS()
public_keys.append(rsa_jwk)
serialized_public_keys_json = public_keys.dump_jwks()

serialized_keypair = rsa_jwk.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)

return serialized_public_keys_json, serialized_keypair_json

def _test_jwt_access_token(self, client_attr, token_type=None, headers=None, grant_type=None, asymmetric_jwt=False):
"""
Test response for JWT token.
Expand Down

0 comments on commit 92731be

Please sign in to comment.