Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JWK support to JWT encode #979

Merged
merged 4 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Changed

- Use ``Sequence`` for parameter types rather than ``List`` where applicable by @imnotjames in `#970 <https://github.com/jpadilla/pyjwt/pull/970>`__
- Remove algorithm requirement from JWT API, instead relying on JWS API for enforcement, by @luhn in `#975 <https://github.com/jpadilla/pyjwt/pull/975>`__
- Add JWK support to JWT encode by @luhn in `#979 <https://github.com/jpadilla/pyjwt/pull/979>`__

Fixed
~~~~~
Expand Down
14 changes: 11 additions & 3 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
def encode(
self,
payload: bytes,
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
key: AllowedPrivateKeys | PyJWK | str | bytes,
algorithm: str | None = None,
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
Expand All @@ -115,7 +115,13 @@ def encode(
segments = []

# declare a new var to narrow the type for type checkers
algorithm_: str = algorithm if algorithm is not None else "none"
if algorithm is None:
if isinstance(key, PyJWK):
algorithm_ = key.algorithm_name
else:
algorithm_ = "HS256"
else:
algorithm_ = algorithm

# Prefer headers values if present to function parameters.
if headers:
Expand Down Expand Up @@ -159,6 +165,8 @@ def encode(
signing_input = b".".join(segments)

alg_obj = self.get_algorithm_by_name(algorithm_)
if isinstance(key, PyJWK):
key = key.key
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

Expand Down
4 changes: 2 additions & 2 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def _get_default_options() -> dict[str, bool | list[str]]:
def encode(
self,
payload: dict[str, Any],
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
key: AllowedPrivateKeys | PyJWK | str | bytes,
algorithm: str | None = None,
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
Expand Down
35 changes: 33 additions & 2 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ def test_decode_with_non_mapping_header_throws_exception(self, jws):
exception = context.value
assert str(exception) == "Invalid header string: must be a json object"

def test_encode_default_algorithm(self, jws, payload):
msg = jws.encode(payload, "secret")
decoded = jws.decode_complete(msg, "secret", algorithms=["HS256"])
assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": payload,
"signature": (
b"H\x8a\xf4\xdf3:\xe1\xac\x16E\xd3\xeb\x00\xcf\xfa\xd5\x05\xac"
b"e\xc8@\xb6\x00\xd5\xde\x9aa|s\xcfZB"
),
}

def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):
jws.encode(payload, "secret", algorithm="HS256")

Expand Down Expand Up @@ -193,6 +205,25 @@ def test_encode_with_alg_hs256_and_headers_alg_es256(self, jws, payload):
msg = jws.encode(payload, priv_key, algorithm="HS256", headers={"alg": "ES256"})
assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])

def test_encode_with_jwk(self, jws, payload):
jwk = PyJWK(
{
"kty": "oct",
"alg": "HS256",
"k": "c2VjcmV0", # "secret"
}
)
msg = jws.encode(payload, key=jwk)
decoded = jws.decode_complete(msg, key=jwk, algorithms=["HS256"])
assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": payload,
"signature": (
b"H\x8a\xf4\xdf3:\xe1\xac\x16E\xd3\xeb\x00\xcf\xfa\xd5\x05\xac"
b"e\xc8@\xb6\x00\xd5\xde\x9aa|s\xcfZB"
),
}

def test_decode_algorithm_param_should_be_case_sensitive(self, jws):
example_jws = (
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256
Expand Down Expand Up @@ -531,13 +562,13 @@ def test_decode_invalid_crypto_padding(self, jws):
assert "Invalid crypto padding" in str(exc.value)

def test_decode_with_algo_none_should_fail(self, jws, payload):
jws_message = jws.encode(payload, key=None, algorithm=None)
jws_message = jws.encode(payload, key=None, algorithm="none")

with pytest.raises(DecodeError):
jws.decode(jws_message, algorithms=["none"])

def test_decode_with_algo_none_and_verify_false_should_pass(self, jws, payload):
jws_message = jws.encode(payload, key=None, algorithm=None)
jws_message = jws.encode(payload, key=None, algorithm="none")
jws.decode(jws_message, options={"verify_signature": False})

def test_get_unverified_header_returns_header_values(self, jws, payload):
Expand Down
Loading