Skip to content

Commit

Permalink
fix: jwt encode and decode methods only works for JWS by default
Browse files Browse the repository at this point in the history
MUST provide JWERegistry to encode and decode JWT.
  • Loading branch information
lepture committed May 13, 2024
1 parent 2470d72 commit 29d391d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 20 deletions.
10 changes: 6 additions & 4 deletions docs/guide/jwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,20 @@ The ``JWTClaimsRegistry`` has built-in validators for timing related fields:
JWS & JWE
---------

JWT is built on top of JWS and JWE, all of the above examples are in JWS. Here
is an example of JWE:
JWT is built on top of JWS and JWE, all of the above examples are in JWS. By default
``jwt.encode`` and ``jwt.decode`` work for **JWS**. To use **JWE**, you need to specify
a ``registry`` parameter with ``JWERegistry``. Here is an example of JWE:

.. code-block:: python
from joserfc import jwt
from joserfc import jwt, jwe
from joserfc.jwk import OctKey
header = {"alg": "A128KW", "enc": "A128GCM"}
claims = {"iss": "https://authlib.org"}
key = OctKey.generate_key(128) # the algorithm requires key of 128 bit size
jwt.encode(header, claims, key)
registry = jwe.JWERegistry() # YOU MUST USE A JWERegistry
jwt.encode(header, claims, key, registry=registry)
The JWE formatted result contains 5 parts, while JWS only contains 3 parts,
a JWE example would be something like this (line breaks for display only):
Expand Down
12 changes: 2 additions & 10 deletions src/joserfc/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,9 @@ def encode(
# add ``typ`` in header
_header = {"typ": "JWT", **header}
payload = convert_claims(claims)
if "enc" in _header:
if registry is not None:
assert isinstance(registry, JWERegistry)
if isinstance(registry, JWERegistry):
return encrypt_compact(_header, payload, key, algorithms, registry)
else:
if registry is not None:
assert isinstance(registry, JWSRegistry)
return serialize_compact(_header, payload, key, algorithms, registry)


Expand All @@ -87,13 +83,9 @@ def decode(
_value = to_bytes(value)
header: Header
payload: bytes
if _value.count(b".") == 4:
if registry is not None:
assert isinstance(registry, JWERegistry)
if isinstance(registry, JWERegistry):
header, payload = _decode_jwe(_value, key, algorithms, registry)
else:
if registry is not None:
assert isinstance(registry, JWSRegistry)
header, payload = _decode_jws(_value, key, algorithms, registry)

try:
Expand Down
13 changes: 7 additions & 6 deletions tests/jwt/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def test_jwe_format(self):
header = {"alg": "A128KW", "enc": "A128GCM"}
claims = {"iss": "https://authlib.org"}
key = OctKey.generate_key(128)
result = jwt.encode(header, claims, key)
registry = jwe.JWERegistry()
result = jwt.encode(header, claims, key, registry=registry)
self.assertEqual(result.count('.'), 4)

token = jwt.decode(result, key)
token = jwt.decode(result, key, registry=registry)
self.assertEqual(token.claims, claims)

def test_using_registry(self):
Expand All @@ -54,26 +55,26 @@ def test_using_registry(self):
jwt.decode(value2, key, registry=jwe.JWERegistry())

self.assertRaises(
AssertionError,
KeyError,
jwt.encode,
{"alg": "HS256"},
{"sub": "a"},
key, registry=jwe.JWERegistry(),
)
self.assertRaises(
AssertionError,
ValueError,
jwt.encode,
{"alg": "A128KW", "enc": "A128GCM"},
{"sub": "a"},
key, registry=jws.JWSRegistry(),
)
self.assertRaises(
AssertionError,
ValueError,
jwt.decode,
value1, key, registry=jwe.JWERegistry(),
)
self.assertRaises(
AssertionError,
ValueError,
jwt.decode,
value2, key, registry=jws.JWSRegistry(),
)

0 comments on commit 29d391d

Please sign in to comment.