Skip to content

Commit

Permalink
Encoding EC keys with a fixed bit length
Browse files Browse the repository at this point in the history
  • Loading branch information
way-dave committed Oct 10, 2024
1 parent 6c7cc61 commit 91a20e2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Changed
Fixed
~~~~~

- Encode EC keys with a fixed bit length by @etianen in `#990 <https://github.com/jpadilla/pyjwt/pull/990>`__

Added
~~~~~

Expand Down
13 changes: 10 additions & 3 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,13 +581,20 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
obj: dict[str, Any] = {
"kty": "EC",
"crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(),
"y": to_base64url_uint(public_numbers.y).decode(),
"x": to_base64url_uint(
public_numbers.x,
bit_length=key_obj.curve.key_size,
).decode(),
"y": to_base64url_uint(
public_numbers.y,
bit_length=key_obj.curve.key_size,
).decode(),
}

if isinstance(key_obj, EllipticCurvePrivateKey):
obj["d"] = to_base64url_uint(
key_obj.private_numbers().private_value
key_obj.private_numbers().private_value,
bit_length=key_obj.curve.key_size,
).decode()

if as_dict:
Expand Down
17 changes: 7 additions & 10 deletions jwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import binascii
import re
from typing import Union
from typing import Optional, Union

try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
Expand Down Expand Up @@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes:
return base64.urlsafe_b64encode(input).replace(b"=", b"")


def to_base64url_uint(val: int) -> bytes:
def to_base64url_uint(val: int, *, bit_length: Optional[int] = None) -> bytes:
if val < 0:
raise ValueError("Must be a positive integer")

int_bytes = bytes_from_int(val)
int_bytes = bytes_from_int(val, bit_length=bit_length)

if len(int_bytes) == 0:
int_bytes = b"\x00"
Expand All @@ -63,13 +63,10 @@ def bytes_to_number(string: bytes) -> int:
return int(binascii.b2a_hex(string), 16)


def bytes_from_int(val: int) -> bytes:
remaining = val
byte_length = 0

while remaining != 0:
remaining >>= 8
byte_length += 1
def bytes_from_int(val: int, *, bit_length: Optional[int] = None) -> bytes:
if bit_length is None:
bit_length = val.bit_length()
byte_length = (bit_length + 7) // 8

return val.to_bytes(byte_length, "big", signed=False)

Expand Down

0 comments on commit 91a20e2

Please sign in to comment.