diff --git a/jwt/api_jws.py b/jwt/api_jws.py index ab8490f9..c914db19 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -101,6 +101,7 @@ def encode( headers: dict[str, Any] | None = None, json_encoder: Type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, + sort_headers: bool = True, ) -> str: segments = [] @@ -133,9 +134,8 @@ def encode( # True is the standard value for b64, so no need for it del header["b64"] - # Fix for headers misorder - issue #715 json_header = json.dumps( - header, separators=(",", ":"), cls=json_encoder, sort_keys=True + header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers ).encode() segments.append(base64url_encode(json_header)) diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index e1df3c7b..5d21afc3 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -45,6 +45,7 @@ def encode( algorithm: Optional[str] = "HS256", headers: Optional[Dict[str, Any]] = None, json_encoder: Optional[Type[json.JSONEncoder]] = None, + sort_headers: bool = True, ) -> str: # Check that we get a mapping if not isinstance(payload, Mapping): @@ -66,7 +67,14 @@ def encode( json_encoder=json_encoder, ) - return api_jws.encode(json_payload, key, algorithm, headers, json_encoder) + return api_jws.encode( + json_payload, + key, + algorithm, + headers, + json_encoder, + sort_headers=sort_headers, + ) def _encode_payload( self, diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index cfbbe212..fe6c2d4f 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -414,6 +414,17 @@ def test_bytes_secret(self, jws, payload): assert decoded_payload == payload + @pytest.mark.parametrize("sort_headers", (False, True)) + def test_sorting_of_headers(self, jws, payload, sort_headers): + jws_message = jws.encode( + payload, + key="\xc2", + headers={"b": "1", "a": "2"}, + sort_headers=sort_headers, + ) + header_json = base64url_decode(jws_message.split(".")[0]) + assert sort_headers == (header_json.index(b'"a"') < header_json.index(b'"b"')) + def test_decode_invalid_header_padding(self, jws): example_jws = ( "aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"