Skip to content

Commit

Permalink
Add sort_headers parameter to api_jwt.encode (#832)
Browse files Browse the repository at this point in the history
* Add `sort_headers` parameter to `api_jwt.encode`

This allows you to not sort headers, which prevents a breaking change between v2.4.0 and v2.5.0

* Add `test_sorting_headers` test

* Remove outdated comment about misordered headers

* Explicity assert sorting in `test_sorting_of_headers`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Parametrize `test_sorting_of_headers`

* Use normal dict in `test_sorting_of_headers`

* fixup! Use normal dict in `test_sorting_of_headers`

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
evroon and pre-commit-ci[bot] authored Dec 8, 2022
1 parent 2fc6aa3 commit fb9b311
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
4 changes: 2 additions & 2 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def encode(
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
segments = []

Expand Down Expand Up @@ -133,9 +134,8 @@ def encode(
# True is the standard value for b64, so no need for it
del header["b64"]

# Fix for headers misorder - issue #715
json_header = json.dumps(
header, separators=(",", ":"), cls=json_encoder, sort_keys=True
header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
).encode()

segments.append(base64url_encode(json_header))
Expand Down
10 changes: 9 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def encode(
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
sort_headers: bool = True,
) -> str:
# Check that we get a mapping
if not isinstance(payload, Mapping):
Expand All @@ -66,7 +67,14 @@ def encode(
json_encoder=json_encoder,
)

return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)
return api_jws.encode(
json_payload,
key,
algorithm,
headers,
json_encoder,
sort_headers=sort_headers,
)

def _encode_payload(
self,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,17 @@ def test_bytes_secret(self, jws, payload):

assert decoded_payload == payload

@pytest.mark.parametrize("sort_headers", (False, True))
def test_sorting_of_headers(self, jws, payload, sort_headers):
jws_message = jws.encode(
payload,
key="\xc2",
headers={"b": "1", "a": "2"},
sort_headers=sort_headers,
)
header_json = base64url_decode(jws_message.split(".")[0])
assert sort_headers == (header_json.index(b'"a"') < header_json.index(b'"b"'))

def test_decode_invalid_header_padding(self, jws):
example_jws = (
"aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
Expand Down

0 comments on commit fb9b311

Please sign in to comment.