From 1ebcb2e3d74a892987a10c9d54e079fda525962d Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sat, 1 Apr 2023 12:55:11 +0200 Subject: [PATCH 01/20] Use new style typing --- jwt/api_jwk.py | 12 +++++----- jwt/api_jws.py | 8 +++---- jwt/api_jwt.py | 54 +++++++++++++++++++++---------------------- tests/test_api_jws.py | 2 +- 4 files changed, 38 insertions(+), 38 deletions(-) diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index fdcde21a..dd876f73 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -2,7 +2,7 @@ import json import time -from typing import Any, Optional +from typing import Any from .algorithms import get_default_algorithms, has_crypto, requires_cryptography from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError @@ -10,7 +10,7 @@ class PyJWK: - def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None: + def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None: self._algorithms = get_default_algorithms() self._jwk_data = jwk_data @@ -60,7 +60,7 @@ def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None: self.key = self.Algorithm.from_jwk(self._jwk_data) @staticmethod - def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK": + def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK": return PyJWK(obj, algorithm) @staticmethod @@ -69,15 +69,15 @@ def from_json(data: str, algorithm: None = None) -> "PyJWK": return PyJWK.from_dict(obj, algorithm) @property - def key_type(self) -> Optional[str]: + def key_type(self) -> str | None: return self._jwk_data.get("kty", None) @property - def key_id(self) -> Optional[str]: + def key_id(self) -> str | None: return self._jwk_data.get("kid", None) @property - def public_key_use(self) -> Optional[str]: + def public_key_use(self) -> str | None: return self._jwk_data.get("use", None) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 0a775d7a..ade8d688 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -3,7 +3,7 @@ import binascii import json import warnings -from typing import Any, List, Optional, Type +from typing import Any from .algorithms import ( Algorithm, @@ -26,8 +26,8 @@ class PyJWS: def __init__( self, - algorithms: Optional[List[str]] = None, - options: Optional[dict[str, Any]] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, ) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( @@ -103,7 +103,7 @@ def encode( key: str, algorithm: str | None = "HS256", headers: dict[str, Any] | None = None, - json_encoder: Type[json.JSONEncoder] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, sort_headers: bool = True, ) -> str: diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index d85f6e8b..f50b1f49 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -5,7 +5,7 @@ from calendar import timegm from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any from . import api_jws from .exceptions import ( @@ -21,13 +21,13 @@ class PyJWT: - def __init__(self, options: Optional[dict[str, Any]] = None) -> None: + def __init__(self, options: dict[str, Any] | None = None) -> None: if options is None: options = {} self.options: dict[str, Any] = {**self._get_default_options(), **options} @staticmethod - def _get_default_options() -> Dict[str, Union[bool, List[str]]]: + def _get_default_options() -> dict[str, bool | list[str]]: return { "verify_signature": True, "verify_exp": True, @@ -40,11 +40,11 @@ def _get_default_options() -> Dict[str, Union[bool, List[str]]]: def encode( self, - payload: Dict[str, Any], + payload: dict[str, Any], key: str, - algorithm: Optional[str] = "HS256", - headers: Optional[Dict[str, Any]] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + algorithm: str | None = "HS256", + headers: dict[str, Any] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, sort_headers: bool = True, ) -> str: # Check that we get a dict @@ -78,9 +78,9 @@ def encode( def _encode_payload( self, - payload: Dict[str, Any], - headers: Optional[Dict[str, Any]] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + payload: dict[str, Any], + headers: dict[str, Any] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, ) -> bytes: """ Encode a given payload to the bytes to be signed. @@ -98,20 +98,20 @@ def decode_complete( self, jwt: str | bytes, key: str | bytes = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 - verify: Optional[bool] = None, + verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 - detached_payload: Optional[bytes] = None, + detached_payload: bytes | None = None, # passthrough arguments to _validate_claims # consider putting in options - audience: Optional[Union[str, Iterable[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[int, float, timedelta] = 0, + audience: str | Iterable[str] | None = None, + issuer: str | None = None, + leeway: float | timedelta = 0, # kwargs **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if kwargs: warnings.warn( "passing additional kwargs to decode_complete() is deprecated " @@ -163,7 +163,7 @@ def decode_complete( decoded["payload"] = payload return decoded - def _decode_payload(self, decoded: Dict[str, Any]) -> Any: + def _decode_payload(self, decoded: dict[str, Any]) -> Any: """ Decode the payload from a JWS dictionary (payload, signature, header). @@ -183,17 +183,17 @@ def decode( self, jwt: str | bytes, key: str | bytes = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 - verify: Optional[bool] = None, + verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 - detached_payload: Optional[bytes] = None, + detached_payload: bytes | None = None, # passthrough arguments to _validate_claims # consider putting in options - audience: Optional[Union[str, Iterable[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[int, float, timedelta] = 0, + audience: str | Iterable[str] | None = None, + issuer: str | None = None, + leeway: float | timedelta = 0, # kwargs **kwargs, ) -> Any: @@ -303,7 +303,7 @@ def _validate_exp( def _validate_aud( self, payload: dict[str, Any], - audience: Optional[Union[str, Iterable[str]]], + audience: str | Iterable[str] | None, ) -> None: if audience is None: if "aud" not in payload or not payload["aud"]: diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index d2aa9159..b961dc2e 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -82,7 +82,7 @@ def test_non_object_options_dont_persist(self, jws, payload): assert jws.options["verify_signature"] - def test_options_must_be_dict(self, jws): + def test_options_must_be_dict(self): pytest.raises(TypeError, PyJWS, options=object()) pytest.raises((TypeError, ValueError), PyJWS, options=("something")) From 7cf5a5422f031dd3f1169cd41c829ff0bba60227 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sat, 1 Apr 2023 18:58:25 +0200 Subject: [PATCH 02/20] Fix type annotations to allow all keys --- jwt/algorithms.py | 176 ++++++++++++++++++++++++++++------------------ jwt/api_jws.py | 17 +++-- jwt/api_jwt.py | 15 ++-- jwt/utils.py | 4 +- 4 files changed, 129 insertions(+), 83 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index bc928fec..9ce45f5c 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -2,7 +2,18 @@ import hmac import json from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + AnyStr, + ClassVar, + Dict, + NoReturn, + Optional, + Type, + Union, + cast, +) from .exceptions import InvalidKeyError from .types import HashlibHash, JWKDict @@ -19,7 +30,6 @@ ) try: - import cryptography.exceptions from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes @@ -68,6 +78,23 @@ except ModuleNotFoundError: has_crypto = False + +if TYPE_CHECKING: + # Type aliases for convenience in algorithms method signatures + AllowedRSAKeys = Union[RSAPrivateKey, RSAPublicKey] + AllowedECKeys = Union[EllipticCurvePrivateKey, EllipticCurvePublicKey] + AllowedOKPKeys = Union[ + Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey + ] + AllowedKeys = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys] + AllowedPrivateKeys = Union[ + RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey + ] + AllowedPublicKeys = Union[ + RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey + ] + + requires_cryptography = { "RS256", "RS384", @@ -172,16 +199,16 @@ def verify(self, msg: bytes, key: Any, sig: bytes) -> bool: @staticmethod @abstractmethod - def to_jwk(key_obj) -> JWKDict: + def to_jwk(key_obj) -> str: """ - Serializes a given RSA key into a JWK + Serializes a given key into a JWK """ @staticmethod @abstractmethod - def from_jwk(jwk: JWKDict): + def from_jwk(jwk: Union[str, JWKDict]) -> Any: """ - Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + Deserializes a given key from JWK back into a key object """ @@ -191,7 +218,7 @@ class NoneAlgorithm(Algorithm): operations are required. """ - def prepare_key(self, key): + def prepare_key(self, key: Optional[str]) -> None: if key == "": key = None @@ -200,18 +227,18 @@ def prepare_key(self, key): return key - def sign(self, msg, key): + def sign(self, msg: bytes, key: None) -> bytes: return b"" - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: None, sig: bytes) -> bool: return False @staticmethod - def to_jwk(key_obj) -> JWKDict: + def to_jwk(key_obj: Any) -> NoReturn: raise NotImplementedError() @staticmethod - def from_jwk(jwk: JWKDict): + def from_jwk(jwk: Union[str, JWKDict]) -> NoReturn: raise NotImplementedError() @@ -228,19 +255,19 @@ class HMACAlgorithm(Algorithm): def __init__(self, hash_alg: HashlibHash) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): - key = force_bytes(key) + def prepare_key(self, key: AnyStr) -> bytes: + key_bytes = force_bytes(key) - if is_pem_format(key) or is_ssh_key(key): + if is_pem_format(key_bytes) or is_ssh_key(key_bytes): raise InvalidKeyError( "The specified key is an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." ) - return key + return key_bytes @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj: AnyStr) -> str: return json.dumps( { "k": base64url_encode(force_bytes(key_obj)).decode(), @@ -249,10 +276,10 @@ def to_jwk(key_obj): ) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: Union[str, JWKDict]) -> bytes: try: if isinstance(jwk, str): - obj = json.loads(jwk) + obj: JWKDict = json.loads(jwk) elif isinstance(jwk, dict): obj = jwk else: @@ -268,7 +295,7 @@ def from_jwk(jwk): def sign(self, msg: bytes, key: bytes) -> bytes: return hmac.new(key, msg, self.hash_alg).digest() - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool: return hmac.compare_digest(sig, self.sign(msg, key)) @@ -287,7 +314,7 @@ class RSAAlgorithm(Algorithm): def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): + def prepare_key(self, key: Union[AllowedRSAKeys, AnyStr]) -> AllowedRSAKeys: if isinstance(key, (RSAPrivateKey, RSAPublicKey)): return key @@ -298,15 +325,17 @@ def prepare_key(self, key): try: if key_bytes.startswith(b"ssh-rsa"): - return load_ssh_public_key(key_bytes) + return cast(RSAPublicKey, load_ssh_public_key(key_bytes)) else: - return load_pem_private_key(key_bytes, password=None) + return cast( + RSAPrivateKey, load_pem_private_key(key_bytes, password=None) + ) except ValueError: - return load_pem_public_key(key_bytes) + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) @staticmethod - def to_jwk(key_obj): - obj = None + def to_jwk(key_obj: AllowedRSAKeys) -> str: + obj: Optional[Dict[str, Any]] = None if hasattr(key_obj, "private_numbers"): # Private key @@ -341,7 +370,7 @@ def to_jwk(key_obj): return json.dumps(obj) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: Union[str, JWKDict]) -> AllowedRSAKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -412,10 +441,10 @@ def from_jwk(jwk): else: raise InvalidKeyError("Not a public or private key") - def sign(self, msg, key): + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: try: key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) return True @@ -435,7 +464,7 @@ class ECAlgorithm(Algorithm): def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): + def prepare_key(self, key: Union[AllowedECKeys, AnyStr]) -> AllowedECKeys: if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key @@ -449,41 +478,46 @@ def prepare_key(self, key): # the Verifying Key first. try: if key_bytes.startswith(b"ecdsa-sha2-"): - key = load_ssh_public_key(key_bytes) + crypto_key = load_ssh_public_key(key_bytes) else: - key = load_pem_public_key(key_bytes) + crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment] except ValueError: - key = load_pem_private_key(key_bytes, password=None) + crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography - if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): + if not isinstance( + crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey) + ): raise InvalidKeyError( "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" ) - return key + return crypto_key - def sign(self, msg, key): + def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: der_sig = key.sign(msg, ECDSA(self.hash_alg())) return der_to_raw_signature(der_sig, key.curve) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: try: der_sig = raw_to_der_signature(sig, key.curve) except ValueError: return False try: - if isinstance(key, EllipticCurvePrivateKey): - key = key.public_key() - key.verify(der_sig, msg, ECDSA(self.hash_alg())) + public_key = ( + key.public_key() + if isinstance(key, EllipticCurvePrivateKey) + else key + ) + public_key.verify(der_sig, msg, ECDSA(self.hash_alg())) return True except InvalidSignature: return False @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj: AllowedECKeys) -> str: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -502,7 +536,7 @@ def to_jwk(key_obj): else: raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") - obj = { + obj: Dict[str, Any] = { "kty": "EC", "crv": crv, "x": to_base64url_uint(public_numbers.x).decode(), @@ -518,8 +552,8 @@ def to_jwk(key_obj): @staticmethod def from_jwk( - jwk: Any, - ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]: + jwk: Union[str, JWKDict], + ) -> AllowedECKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -591,7 +625,7 @@ class RSAPSSAlgorithm(RSAAlgorithm): Performs a signature using RSASSA-PSS with MGF1 """ - def sign(self, msg, key): + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: return key.sign( msg, padding.PSS( @@ -601,7 +635,7 @@ def sign(self, msg, key): self.hash_alg(), ) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: try: key.verify( sig, @@ -623,21 +657,20 @@ class OKPAlgorithm(Algorithm): This class requires ``cryptography>=2.6`` to be installed. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: pass - def prepare_key(self, key): + def prepare_key(self, key: Union[AllowedOKPKeys, str, bytes]) -> AllowedOKPKeys: if isinstance(key, (bytes, str)): - if isinstance(key, str): - key = key.encode("utf-8") - str_key = key.decode("utf-8") + key_str = key.decode("utf-8") if isinstance(key, bytes) else key + key_bytes = key.encode("utf-8") if isinstance(key, str) else key - if "-----BEGIN PUBLIC" in str_key: - key = load_pem_public_key(key) - elif "-----BEGIN PRIVATE" in str_key: - key = load_pem_private_key(key, password=None) - elif str_key[0:4] == "ssh-": - key = load_ssh_public_key(key) + if "-----BEGIN PUBLIC" in key_str: + key = load_pem_public_key(key_bytes) # type: ignore[assignment] + elif "-----BEGIN PRIVATE" in key_str: + key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] + elif key_str[0:4] == "ssh-": + key = load_ssh_public_key(key_bytes) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography if not isinstance( @@ -650,7 +683,9 @@ def prepare_key(self, key): return key - def sign(self, msg, key): + def sign( + self, msg: Union[str, bytes], key: Union[Ed25519PrivateKey, Ed448PrivateKey] + ) -> bytes: """ Sign a message ``msg`` using the EdDSA private key ``key`` :param str|bytes msg: Message to sign @@ -658,10 +693,12 @@ def sign(self, msg, key): or :class:`.Ed448PrivateKey` isinstance :return bytes signature: The signature, as bytes """ - msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg - return key.sign(msg) + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + return key.sign(msg_bytes) - def verify(self, msg, key, sig): + def verify( + self, msg: Union[str, bytes], key: AllowedOKPKeys, sig: Union[str, bytes] + ) -> bool: """ Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` @@ -672,18 +709,21 @@ def verify(self, msg, key, sig): :return bool verified: True if signature is valid, False if not. """ try: - msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg - sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig - if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): - key = key.public_key() - key.verify(sig, msg) + public_key = ( + key.public_key() + if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)) + else key + ) + public_key.verify(sig_bytes, msg_bytes) return True # If no exception was raised, the signature is valid. - except cryptography.exceptions.InvalidSignature: + except InvalidSignature: return False @staticmethod - def to_jwk(key): + def to_jwk(key: AllowedOKPKeys) -> str: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, @@ -723,7 +763,7 @@ def to_jwk(key): raise InvalidKeyError("Not a public or private key") @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: Union[str, JWKDict]) -> AllowedOKPKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index ade8d688..fa6708cc 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -3,7 +3,7 @@ import binascii import json import warnings -from typing import Any +from typing import TYPE_CHECKING, Any from .algorithms import ( Algorithm, @@ -20,6 +20,9 @@ from .utils import base64url_decode, base64url_encode from .warnings import RemovedInPyjwt3Warning +if TYPE_CHECKING: + from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + class PyJWS: header_typ = "JWT" @@ -100,7 +103,7 @@ def get_algorithm_by_name(self, alg_name: str) -> Algorithm: def encode( self, payload: bytes, - key: str, + key: AllowedPrivateKeys | str | bytes, algorithm: str | None = "HS256", headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None, @@ -169,7 +172,7 @@ def encode( def decode_complete( self, jwt: str | bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -214,7 +217,7 @@ def decode_complete( def decode( self, jwt: str | bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -286,7 +289,7 @@ def _verify_signature( signing_input: bytes, header: dict[str, Any], signature: bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, ) -> None: try: @@ -301,9 +304,9 @@ def _verify_signature( alg_obj = self.get_algorithm_by_name(alg) except NotImplementedError as e: raise InvalidAlgorithmError("Algorithm not supported") from e - key = alg_obj.prepare_key(key) + prepared_key = alg_obj.prepare_key(key) - if not alg_obj.verify(signing_input, key, signature): + if not alg_obj.verify(signing_input, prepared_key, signature): raise InvalidSignatureError("Signature verification failed") def _validate_headers(self, headers: dict[str, Any]) -> None: diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index f50b1f49..49d1b488 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -5,7 +5,7 @@ from calendar import timegm from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import Any +from typing import TYPE_CHECKING, Any from . import api_jws from .exceptions import ( @@ -19,6 +19,9 @@ ) from .warnings import RemovedInPyjwt3Warning +if TYPE_CHECKING: + from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + class PyJWT: def __init__(self, options: dict[str, Any] | None = None) -> None: @@ -41,7 +44,7 @@ def _get_default_options() -> dict[str, bool | list[str]]: def encode( self, payload: dict[str, Any], - key: str, + key: AllowedPrivateKeys | str | bytes, algorithm: str | None = "HS256", headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None, @@ -97,7 +100,7 @@ def _encode_payload( def decode_complete( self, jwt: str | bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 @@ -110,7 +113,7 @@ def decode_complete( issuer: str | None = None, leeway: float | timedelta = 0, # kwargs - **kwargs, + **kwargs: Any, ) -> dict[str, Any]: if kwargs: warnings.warn( @@ -182,7 +185,7 @@ def _decode_payload(self, decoded: dict[str, Any]) -> Any: def decode( self, jwt: str | bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 @@ -195,7 +198,7 @@ def decode( issuer: str | None = None, leeway: float | timedelta = 0, # kwargs - **kwargs, + **kwargs: Any, ) -> Any: if kwargs: warnings.warn( diff --git a/jwt/utils.py b/jwt/utils.py index 92e3c8be..9d8c4540 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import Any, AnyStr +from typing import AnyStr try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -13,7 +13,7 @@ pass -def force_bytes(value: Any) -> bytes: +def force_bytes(value: AnyStr) -> bytes: if isinstance(value, str): return value.encode("utf-8") elif isinstance(value, bytes): From 16da314831b38061611b92d36c63cf842b1453b8 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 10 Apr 2023 17:40:11 +0200 Subject: [PATCH 03/20] Use string type annotations where required --- jwt/algorithms.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 9ce45f5c..ec85d3bf 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -314,7 +314,7 @@ class RSAAlgorithm(Algorithm): def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key: Union[AllowedRSAKeys, AnyStr]) -> AllowedRSAKeys: + def prepare_key(self, key: Union["AllowedRSAKeys", AnyStr]) -> "AllowedRSAKeys": if isinstance(key, (RSAPrivateKey, RSAPublicKey)): return key @@ -334,7 +334,7 @@ def prepare_key(self, key: Union[AllowedRSAKeys, AnyStr]) -> AllowedRSAKeys: return cast(RSAPublicKey, load_pem_public_key(key_bytes)) @staticmethod - def to_jwk(key_obj: AllowedRSAKeys) -> str: + def to_jwk(key_obj: "AllowedRSAKeys") -> str: obj: Optional[Dict[str, Any]] = None if hasattr(key_obj, "private_numbers"): @@ -370,7 +370,7 @@ def to_jwk(key_obj: AllowedRSAKeys) -> str: return json.dumps(obj) @staticmethod - def from_jwk(jwk: Union[str, JWKDict]) -> AllowedRSAKeys: + def from_jwk(jwk: Union[str, JWKDict]) -> "AllowedRSAKeys": try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -464,7 +464,7 @@ class ECAlgorithm(Algorithm): def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key: Union[AllowedECKeys, AnyStr]) -> AllowedECKeys: + def prepare_key(self, key: Union["AllowedECKeys", AnyStr]) -> "AllowedECKeys": if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key @@ -499,7 +499,7 @@ def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: return der_to_raw_signature(der_sig, key.curve) - def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: + def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: try: der_sig = raw_to_der_signature(sig, key.curve) except ValueError: @@ -517,7 +517,7 @@ def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: return False @staticmethod - def to_jwk(key_obj: AllowedECKeys) -> str: + def to_jwk(key_obj: "AllowedECKeys") -> str: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -553,7 +553,7 @@ def to_jwk(key_obj: AllowedECKeys) -> str: @staticmethod def from_jwk( jwk: Union[str, JWKDict], - ) -> AllowedECKeys: + ) -> "AllowedECKeys": try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -660,7 +660,7 @@ class OKPAlgorithm(Algorithm): def __init__(self, **kwargs: Any) -> None: pass - def prepare_key(self, key: Union[AllowedOKPKeys, str, bytes]) -> AllowedOKPKeys: + def prepare_key(self, key: Union["AllowedOKPKeys", str, bytes]) -> "AllowedOKPKeys": if isinstance(key, (bytes, str)): key_str = key.decode("utf-8") if isinstance(key, bytes) else key key_bytes = key.encode("utf-8") if isinstance(key, str) else key @@ -697,7 +697,7 @@ def sign( return key.sign(msg_bytes) def verify( - self, msg: Union[str, bytes], key: AllowedOKPKeys, sig: Union[str, bytes] + self, msg: Union[str, bytes], key: "AllowedOKPKeys", sig: Union[str, bytes] ) -> bool: """ Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` @@ -723,7 +723,7 @@ def verify( return False @staticmethod - def to_jwk(key: AllowedOKPKeys) -> str: + def to_jwk(key: "AllowedOKPKeys") -> str: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, @@ -763,7 +763,7 @@ def to_jwk(key: AllowedOKPKeys) -> str: raise InvalidKeyError("Not a public or private key") @staticmethod - def from_jwk(jwk: Union[str, JWKDict]) -> AllowedOKPKeys: + def from_jwk(jwk: Union[str, JWKDict]) -> "AllowedOKPKeys": try: if isinstance(jwk, str): obj = json.loads(jwk) From 819a86a225e9f6356c7ac441b5ea0c625d35d76e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Apr 2023 15:40:28 +0000 Subject: [PATCH 04/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- jwt/algorithms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index ec85d3bf..0299ff97 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -660,7 +660,9 @@ class OKPAlgorithm(Algorithm): def __init__(self, **kwargs: Any) -> None: pass - def prepare_key(self, key: Union["AllowedOKPKeys", str, bytes]) -> "AllowedOKPKeys": + def prepare_key( + self, key: Union["AllowedOKPKeys", str, bytes] + ) -> "AllowedOKPKeys": if isinstance(key, (bytes, str)): key_str = key.decode("utf-8") if isinstance(key, bytes) else key key_bytes = key.encode("utf-8") if isinstance(key, str) else key From 4e567752298087dc4735e7393029487040e92e18 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 10 Apr 2023 17:51:30 +0200 Subject: [PATCH 05/20] Remove outdated comment --- jwt/algorithms.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 0299ff97..ad473b8d 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -172,10 +172,6 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes: else: return bytes(hash_alg(bytestr).digest()) - # TODO: all key-related `Any`s in this class should optimally be made - # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605 - # that may still be poorly supported. - @abstractmethod def prepare_key(self, key: Any) -> Any: """ From 67c57ecf19cd3c2949ce6d1a2072b1925b738b8f Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 10 Apr 2023 17:55:30 +0200 Subject: [PATCH 06/20] Ignore `if TYPE_CHECKING:` lines in coverage --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e7920ffb..5df78b4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ source = ["jwt", ".tox/*/site-packages"] [tool.coverage.report] show_missing = true +exclude_lines = ["if TYPE_CHECKING:"] [tool.isort] profile = "black" From 8139b3dc2ad34b35bbde79252927e5c97840ed66 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 10 Apr 2023 18:10:19 +0200 Subject: [PATCH 07/20] Remove duplicate test --- tests/test_algorithms.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 8aa9ad72..0106ad0c 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -767,12 +767,6 @@ def test_okp_ed25519_should_reject_non_string_key(self): with open(key_path("testkey_ed25519.pub")) as keyfile: algo.prepare_key(keyfile.read()) - def test_okp_ed25519_should_accept_unicode_key(self): - algo = OKPAlgorithm() - - with open(key_path("testkey_ed25519")) as ec_key: - algo.prepare_key(ec_key.read()) - def test_okp_ed25519_sign_should_generate_correct_signature_value(self): algo = OKPAlgorithm() From 636820164a11cc0253441d99e256ba07a3009dd0 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 10 Apr 2023 20:57:48 +0200 Subject: [PATCH 08/20] Fix mypy errors --- jwt/utils.py | 8 ++++---- tests/test_algorithms.py | 6 +++--- tests/test_utils.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/jwt/utils.py b/jwt/utils.py index 9d8c4540..81c5ee41 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import AnyStr +from typing import Union try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -13,7 +13,7 @@ pass -def force_bytes(value: AnyStr) -> bytes: +def force_bytes(value: Union[bytes, str]) -> bytes: if isinstance(value, str): return value.encode("utf-8") elif isinstance(value, bytes): @@ -22,7 +22,7 @@ def force_bytes(value: AnyStr) -> bytes: raise TypeError("Expected a string value") -def base64url_decode(input: AnyStr) -> bytes: +def base64url_decode(input: Union[bytes, str]) -> bytes: input_bytes = force_bytes(input) rem = len(input_bytes) % 4 @@ -49,7 +49,7 @@ def to_base64url_uint(val: int) -> bytes: return base64url_encode(int_bytes) -def from_base64url_uint(val: AnyStr) -> int: +def from_base64url_uint(val: Union[bytes, str]) -> int: data = base64url_decode(force_bytes(val)) return int.from_bytes(data, byteorder="big") diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 0106ad0c..dfe09139 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -37,7 +37,7 @@ def test_hmac_should_reject_nonstring_key(self): algo = HMACAlgorithm(HMACAlgorithm.SHA256) with pytest.raises(TypeError) as context: - algo.prepare_key(object()) + algo.prepare_key(object()) # type: ignore[type-var] exception = context.value assert str(exception) == "Expected a string value" @@ -861,7 +861,7 @@ def test_okp_ed25519_jwk_fails_on_invalid_json(self): # Invalid instance type with pytest.raises(InvalidKeyError): - algo.from_jwk(123) + algo.from_jwk(123) # type: ignore[arg-type] # Invalid JSON with pytest.raises(InvalidKeyError): @@ -970,7 +970,7 @@ def test_okp_ed448_jwk_fails_on_invalid_json(self): # Invalid instance type with pytest.raises(InvalidKeyError): - algo.from_jwk(123) + algo.from_jwk(123) # type: ignore[arg-type] # Invalid JSON with pytest.raises(InvalidKeyError): diff --git a/tests/test_utils.py b/tests/test_utils.py index a089f860..122dcb4e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -36,4 +36,4 @@ def test_from_base64url_uint(inputval, expected): def test_force_bytes_raises_error_on_invalid_object(): with pytest.raises(TypeError): - force_bytes({}) + force_bytes({}) # type: ignore[arg-type] From 2e3385115eef4ae81fab77a8719e8c34e8c51e98 Mon Sep 17 00:00:00 2001 From: Asif Saif Uddin Date: Tue, 11 Apr 2023 11:30:36 +0600 Subject: [PATCH 09/20] Update algorithms.py --- jwt/algorithms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index ad473b8d..46a17dd7 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,3 +1,4 @@ +from __future__ import annotations import hashlib import hmac import json From b5a73f8a3a1265f2796d7737b80e518bfc97402d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Apr 2023 05:30:47 +0000 Subject: [PATCH 10/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- jwt/algorithms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 46a17dd7..87637a9f 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,4 +1,5 @@ from __future__ import annotations + import hashlib import hmac import json From 11cc10d7f46c7ee220f7137dfb6cee88bf55e5e1 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Tue, 11 Apr 2023 07:54:56 +0200 Subject: [PATCH 11/20] Fully switch to modern annotations --- jwt/algorithms.py | 83 +++++++++++++++++++---------------------------- 1 file changed, 34 insertions(+), 49 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 87637a9f..e4020b79 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -7,13 +7,8 @@ from typing import ( TYPE_CHECKING, Any, - AnyStr, ClassVar, - Dict, NoReturn, - Optional, - Type, - Union, cast, ) @@ -83,18 +78,12 @@ if TYPE_CHECKING: # Type aliases for convenience in algorithms method signatures - AllowedRSAKeys = Union[RSAPrivateKey, RSAPublicKey] - AllowedECKeys = Union[EllipticCurvePrivateKey, EllipticCurvePublicKey] - AllowedOKPKeys = Union[ - Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey - ] - AllowedKeys = Union[AllowedRSAKeys, AllowedECKeys, AllowedOKPKeys] - AllowedPrivateKeys = Union[ - RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey - ] - AllowedPublicKeys = Union[ - RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey - ] + AllowedRSAKeys = RSAPrivateKey | RSAPublicKey + AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey + AllowedOKPKeys = Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey + AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys + AllowedPrivateKeys = RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey + AllowedPublicKeys = RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey requires_cryptography = { @@ -113,7 +102,7 @@ } -def get_default_algorithms() -> Dict[str, "Algorithm"]: +def get_default_algorithms() -> dict[str, Algorithm]: """ Returns the algorithms that are implemented by the library. """ @@ -204,7 +193,7 @@ def to_jwk(key_obj) -> str: @staticmethod @abstractmethod - def from_jwk(jwk: Union[str, JWKDict]) -> Any: + def from_jwk(jwk: str | JWKDict) -> Any: """ Deserializes a given key from JWK back into a key object """ @@ -216,7 +205,7 @@ class NoneAlgorithm(Algorithm): operations are required. """ - def prepare_key(self, key: Optional[str]) -> None: + def prepare_key(self, key: str | None) -> None: if key == "": key = None @@ -236,7 +225,7 @@ def to_jwk(key_obj: Any) -> NoReturn: raise NotImplementedError() @staticmethod - def from_jwk(jwk: Union[str, JWKDict]) -> NoReturn: + def from_jwk(jwk: str | JWKDict) -> NoReturn: raise NotImplementedError() @@ -253,7 +242,7 @@ class HMACAlgorithm(Algorithm): def __init__(self, hash_alg: HashlibHash) -> None: self.hash_alg = hash_alg - def prepare_key(self, key: AnyStr) -> bytes: + def prepare_key(self, key: str | bytes) -> bytes: key_bytes = force_bytes(key) if is_pem_format(key_bytes) or is_ssh_key(key_bytes): @@ -265,7 +254,7 @@ def prepare_key(self, key: AnyStr) -> bytes: return key_bytes @staticmethod - def to_jwk(key_obj: AnyStr) -> str: + def to_jwk(key_obj: str | bytes) -> str: return json.dumps( { "k": base64url_encode(force_bytes(key_obj)).decode(), @@ -274,7 +263,7 @@ def to_jwk(key_obj: AnyStr) -> str: ) @staticmethod - def from_jwk(jwk: Union[str, JWKDict]) -> bytes: + def from_jwk(jwk: str | JWKDict) -> bytes: try: if isinstance(jwk, str): obj: JWKDict = json.loads(jwk) @@ -305,14 +294,14 @@ class RSAAlgorithm(Algorithm): RSASSA-PKCS-v1_5 and the specified hash function. """ - SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256 - SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384 - SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512 + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 - def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key: Union["AllowedRSAKeys", AnyStr]) -> "AllowedRSAKeys": + def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: if isinstance(key, (RSAPrivateKey, RSAPublicKey)): return key @@ -332,8 +321,8 @@ def prepare_key(self, key: Union["AllowedRSAKeys", AnyStr]) -> "AllowedRSAKeys": return cast(RSAPublicKey, load_pem_public_key(key_bytes)) @staticmethod - def to_jwk(key_obj: "AllowedRSAKeys") -> str: - obj: Optional[Dict[str, Any]] = None + def to_jwk(key_obj: AllowedRSAKeys) -> str: + obj: dict[str, Any] | None = None if hasattr(key_obj, "private_numbers"): # Private key @@ -368,7 +357,7 @@ def to_jwk(key_obj: "AllowedRSAKeys") -> str: return json.dumps(obj) @staticmethod - def from_jwk(jwk: Union[str, JWKDict]) -> "AllowedRSAKeys": + def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -455,14 +444,14 @@ class ECAlgorithm(Algorithm): ECDSA and the specified hash function """ - SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256 - SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384 - SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512 + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 - def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key: Union["AllowedECKeys", AnyStr]) -> "AllowedECKeys": + def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key @@ -515,7 +504,7 @@ def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: return False @staticmethod - def to_jwk(key_obj: "AllowedECKeys") -> str: + def to_jwk(key_obj: AllowedECKeys) -> str: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -534,7 +523,7 @@ def to_jwk(key_obj: "AllowedECKeys") -> str: else: raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") - obj: Dict[str, Any] = { + obj: dict[str, Any] = { "kty": "EC", "crv": crv, "x": to_base64url_uint(public_numbers.x).decode(), @@ -549,9 +538,7 @@ def to_jwk(key_obj: "AllowedECKeys") -> str: return json.dumps(obj) @staticmethod - def from_jwk( - jwk: Union[str, JWKDict], - ) -> "AllowedECKeys": + def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -658,9 +645,7 @@ class OKPAlgorithm(Algorithm): def __init__(self, **kwargs: Any) -> None: pass - def prepare_key( - self, key: Union["AllowedOKPKeys", str, bytes] - ) -> "AllowedOKPKeys": + def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: if isinstance(key, (bytes, str)): key_str = key.decode("utf-8") if isinstance(key, bytes) else key key_bytes = key.encode("utf-8") if isinstance(key, str) else key @@ -684,7 +669,7 @@ def prepare_key( return key def sign( - self, msg: Union[str, bytes], key: Union[Ed25519PrivateKey, Ed448PrivateKey] + self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey ) -> bytes: """ Sign a message ``msg`` using the EdDSA private key ``key`` @@ -697,7 +682,7 @@ def sign( return key.sign(msg_bytes) def verify( - self, msg: Union[str, bytes], key: "AllowedOKPKeys", sig: Union[str, bytes] + self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes ) -> bool: """ Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` @@ -723,7 +708,7 @@ def verify( return False @staticmethod - def to_jwk(key: "AllowedOKPKeys") -> str: + def to_jwk(key: AllowedOKPKeys) -> str: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, @@ -763,7 +748,7 @@ def to_jwk(key: "AllowedOKPKeys") -> str: raise InvalidKeyError("Not a public or private key") @staticmethod - def from_jwk(jwk: Union[str, JWKDict]) -> "AllowedOKPKeys": + def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) From 7a807b31ff662ad4eed5232ad51797e9b61fcd65 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Apr 2023 05:55:10 +0000 Subject: [PATCH 12/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- jwt/algorithms.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index e4020b79..4bcf81b6 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -4,13 +4,7 @@ import hmac import json from abc import ABC, abstractmethod -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - NoReturn, - cast, -) +from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast from .exceptions import InvalidKeyError from .types import HashlibHash, JWKDict @@ -80,10 +74,16 @@ # Type aliases for convenience in algorithms method signatures AllowedRSAKeys = RSAPrivateKey | RSAPublicKey AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey - AllowedOKPKeys = Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey + AllowedOKPKeys = ( + Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey + ) AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys - AllowedPrivateKeys = RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey - AllowedPublicKeys = RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey + AllowedPrivateKeys = ( + RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey + ) + AllowedPublicKeys = ( + RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey + ) requires_cryptography = { @@ -321,7 +321,7 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: return cast(RSAPublicKey, load_pem_public_key(key_bytes)) @staticmethod - def to_jwk(key_obj: AllowedRSAKeys) -> str: + def to_jwk(key_obj: AllowedRSAKeys) -> str: obj: dict[str, Any] | None = None if hasattr(key_obj, "private_numbers"): From a10f1fc573e9d96cd9338787d1e53858072cf528 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Thu, 13 Apr 2023 10:59:22 +0200 Subject: [PATCH 13/20] Update `pre-commit` mypy config --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 953360de..cd469ede 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,3 +39,4 @@ repos: rev: "v1.1.1" hooks: - id: mypy + additional_dependencies: [cryptography>=3.4.0] From 33dac85e63a380ce6914a4c4db2ee882d547deb1 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Thu, 13 Apr 2023 11:13:51 +0200 Subject: [PATCH 14/20] Use Python 3.11 for mypy --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cd469ede..9197f882 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,4 +39,5 @@ repos: rev: "v1.1.1" hooks: - id: mypy + args: [--python-version 3.11, --ignore-missing-imports] additional_dependencies: [cryptography>=3.4.0] From 57e46bf522ac34c6742db3176c030061b998b43e Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Thu, 13 Apr 2023 11:18:11 +0200 Subject: [PATCH 15/20] Update mypy Python version in `pyproject.toml` --- .pre-commit-config.yaml | 1 - pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9197f882..cd469ede 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,5 +39,4 @@ repos: rev: "v1.1.1" hooks: - id: mypy - args: [--python-version 3.11, --ignore-missing-imports] additional_dependencies: [cryptography>=3.4.0] diff --git a/pyproject.toml b/pyproject.toml index 5df78b4b..40dfb7bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ atomic = true combine_as_imports = true [tool.mypy] -python_version = 3.7 +python_version = 3.11 ignore_missing_imports = true warn_unused_ignores = true no_implicit_optional = true From db9e9a8ddf76e691e8be2124e16a8d95a48eeab6 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Thu, 13 Apr 2023 11:30:29 +0200 Subject: [PATCH 16/20] Few tests mypy fixes --- tests/test_algorithms.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index dfe09139..a8a54dff 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -37,7 +37,7 @@ def test_hmac_should_reject_nonstring_key(self): algo = HMACAlgorithm(HMACAlgorithm.SHA256) with pytest.raises(TypeError) as context: - algo.prepare_key(object()) # type: ignore[type-var] + algo.prepare_key(object()) # type: ignore[arg-type] exception = context.value assert str(exception) == "Expected a string value" @@ -112,7 +112,7 @@ def test_rsa_should_reject_non_string_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with pytest.raises(TypeError): - algo.prepare_key(None) + algo.prepare_key(None) # type: ignore[arg-type] @crypto_required def test_rsa_verify_should_return_false_if_signature_invalid(self): @@ -284,7 +284,7 @@ def test_ec_to_jwk_raises_exception_on_invalid_key(self): algo = ECAlgorithm(ECAlgorithm.SHA256) with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) + algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] @crypto_required def test_ec_to_jwk_with_valid_curves(self): @@ -505,7 +505,7 @@ def test_rsa_to_jwk_raises_exception_on_invalid_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) + algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] @crypto_required def test_rsa_from_jwk_raises_exception_on_invalid_key(self): @@ -520,7 +520,7 @@ def test_ec_should_reject_non_string_key(self): algo = ECAlgorithm(ECAlgorithm.SHA256) with pytest.raises(TypeError): - algo.prepare_key(None) + algo.prepare_key(None) # type: ignore[arg-type] @crypto_required def test_ec_should_accept_pem_private_key_bytes(self): @@ -926,7 +926,7 @@ def test_okp_to_jwk_raises_exception_on_invalid_key(self): algo = OKPAlgorithm() with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) + algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] def test_okp_ed448_jwk_private_key_should_parse_and_verify(self): algo = OKPAlgorithm() From b98c03f383a1eae6e7a391f6e9b7ad78e15931d3 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sat, 15 Apr 2023 14:02:33 +0200 Subject: [PATCH 17/20] fix mypy errors on tests --- tests/test_algorithms.py | 95 ++++++++++++++++++++++++---------------- tests/test_api_jws.py | 8 ++-- 2 files changed, 62 insertions(+), 41 deletions(-) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index a8a54dff..aa62ffc2 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,5 +1,6 @@ import base64 import json +from typing import TYPE_CHECKING, cast import pytest @@ -13,6 +14,24 @@ if has_crypto: from jwt.algorithms import ECAlgorithm, OKPAlgorithm, RSAAlgorithm, RSAPSSAlgorithm +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.ec import ( + EllipticCurvePrivateKey, + EllipticCurvePublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.ed448 import ( + Ed448PrivateKey, + Ed448PublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PrivateKey, + Ed25519PublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPublicKey, + ) + class TestAlgorithms: def test_none_algorithm_should_throw_exception_if_key_is_not_none(self): @@ -132,7 +151,7 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self): sig += b"123" # Signature is now invalid with open(key_path("testkey_rsa.pub")) as keyfile: - pub_key = algo.prepare_key(keyfile.read()) + pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read())) result = algo.verify(message, pub_key, sig) assert not result @@ -149,10 +168,10 @@ def test_ec_jwk_public_and_private_keys_should_parse_and_verify(self): algo = ECAlgorithm(hash) with open(key_path(f"jwk_ec_pub_{curve}.json")) as keyfile: - pub_key = algo.from_jwk(keyfile.read()) + pub_key = cast(EllipticCurvePublicKey, algo.from_jwk(keyfile.read())) with open(key_path(f"jwk_ec_key_{curve}.json")) as keyfile: - priv_key = algo.from_jwk(keyfile.read()) + priv_key = cast(EllipticCurvePrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", priv_key) assert algo.verify(b"Hello World!", pub_key, signature) @@ -223,9 +242,9 @@ def test_ec_private_key_to_jwk_works_with_from_jwk(self): algo = ECAlgorithm(ECAlgorithm.SHA256) with open(key_path("testkey_ec.priv")) as ec_key: - orig_key = algo.prepare_key(ec_key.read()) + orig_key = cast(EllipticCurvePrivateKey, algo.prepare_key(ec_key.read())) - parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + parsed_key = cast(EllipticCurvePrivateKey, algo.from_jwk(algo.to_jwk(orig_key))) assert parsed_key.private_numbers() == orig_key.private_numbers() assert ( parsed_key.private_numbers().public_numbers @@ -237,9 +256,9 @@ def test_ec_public_key_to_jwk_works_with_from_jwk(self): algo = ECAlgorithm(ECAlgorithm.SHA256) with open(key_path("testkey_ec.pub")) as ec_key: - orig_key = algo.prepare_key(ec_key.read()) + orig_key = cast(EllipticCurvePublicKey, algo.prepare_key(ec_key.read())) - parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + parsed_key = cast(EllipticCurvePublicKey, algo.from_jwk(algo.to_jwk(orig_key))) assert parsed_key.public_numbers() == orig_key.public_numbers() @crypto_required @@ -320,10 +339,10 @@ def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path("jwk_rsa_pub.json")) as keyfile: - pub_key = algo.from_jwk(keyfile.read()) + pub_key = cast(RSAPublicKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_rsa_key.json")) as keyfile: - priv_key = algo.from_jwk(keyfile.read()) + priv_key = cast(RSAPrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", priv_key) assert algo.verify(b"Hello World!", pub_key, signature) @@ -333,9 +352,9 @@ def test_rsa_private_key_to_jwk_works_with_from_jwk(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path("testkey_rsa.priv")) as rsa_key: - orig_key = algo.prepare_key(rsa_key.read()) + orig_key = cast(RSAPrivateKey, algo.prepare_key(rsa_key.read())) - parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + parsed_key = cast(RSAPrivateKey, algo.from_jwk(algo.to_jwk(orig_key))) assert parsed_key.private_numbers() == orig_key.private_numbers() assert ( parsed_key.private_numbers().public_numbers @@ -347,9 +366,9 @@ def test_rsa_public_key_to_jwk_works_with_from_jwk(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path("testkey_rsa.pub")) as rsa_key: - orig_key = algo.prepare_key(rsa_key.read()) + orig_key = cast(RSAPublicKey, algo.prepare_key(rsa_key.read())) - parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + parsed_key = cast(RSAPublicKey, algo.from_jwk(algo.to_jwk(orig_key))) assert parsed_key.public_numbers() == orig_key.public_numbers() @crypto_required @@ -380,14 +399,16 @@ def test_rsa_jwk_private_key_can_recover_prime_factors(self): with open(key_path("jwk_rsa_key.json")) as keyfile: keybytes = keyfile.read() - control_key = algo.from_jwk(keybytes).private_numbers() + control_key = cast(RSAPrivateKey, algo.from_jwk(keybytes)).private_numbers() keydata = json.loads(keybytes) delete_these = ["p", "q", "dp", "dq", "qi"] for field in delete_these: del keydata[field] - parsed_key = algo.from_jwk(json.dumps(keydata)).private_numbers() + parsed_key = cast( + RSAPrivateKey, algo.from_jwk(json.dumps(keydata)) + ).private_numbers() assert control_key.d == parsed_key.d assert control_key.p == parsed_key.p @@ -590,11 +611,11 @@ def test_rsa_pss_sign_then_verify_should_return_true(self): message = b"Hello World!" with open(key_path("testkey_rsa.priv")) as keyfile: - priv_key = algo.prepare_key(keyfile.read()) + priv_key = cast(RSAPrivateKey, algo.prepare_key(keyfile.read())) sig = algo.sign(message, priv_key) with open(key_path("testkey_rsa.pub")) as keyfile: - pub_key = algo.prepare_key(keyfile.read()) + pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read())) result = algo.verify(message, pub_key, sig) assert result @@ -617,7 +638,7 @@ def test_rsa_pss_verify_should_return_false_if_signature_invalid(self): jwt_sig += b"123" # Signature is now invalid with open(key_path("testkey_rsa.pub")) as keyfile: - jwt_pub_key = algo.prepare_key(keyfile.read()) + jwt_pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read())) result = algo.verify(jwt_message, jwt_pub_key, jwt_sig) assert not result @@ -678,7 +699,7 @@ def test_rsa_verify_should_return_true_for_test_vector(self): ) algo = RSAAlgorithm(RSAAlgorithm.SHA256) - key = algo.prepare_key(load_rsa_pub_key()) + key = cast(RSAPublicKey, algo.prepare_key(load_rsa_pub_key())) result = algo.verify(signing_input, key, signature) assert result @@ -709,7 +730,7 @@ def test_rsapss_verify_should_return_true_for_test_vector(self): ) algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384) - key = algo.prepare_key(load_rsa_pub_key()) + key = cast(RSAPublicKey, algo.prepare_key(load_rsa_pub_key())) result = algo.verify(signing_input, key, signature) assert result @@ -759,7 +780,7 @@ def test_okp_ed25519_should_reject_non_string_key(self): algo = OKPAlgorithm() with pytest.raises(InvalidKeyError): - algo.prepare_key(None) + algo.prepare_key(None) # type: ignore[arg-type] with open(key_path("testkey_ed25519")) as keyfile: algo.prepare_key(keyfile.read()) @@ -775,10 +796,10 @@ def test_okp_ed25519_sign_should_generate_correct_signature_value(self): expected_sig = base64.b64decode(self.hello_world_sig) with open(key_path("testkey_ed25519")) as keyfile: - jwt_key = algo.prepare_key(keyfile.read()) + jwt_key = cast(Ed25519PrivateKey, algo.prepare_key(keyfile.read())) with open(key_path("testkey_ed25519.pub")) as keyfile: - jwt_pub_key = algo.prepare_key(keyfile.read()) + jwt_pub_key = cast(Ed25519PublicKey, algo.prepare_key(keyfile.read())) algo.sign(jwt_message, jwt_key) result = algo.verify(jwt_message, jwt_pub_key, expected_sig) @@ -823,7 +844,7 @@ def test_okp_ed25519_jwk_private_key_should_parse_and_verify(self): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: - key = algo.from_jwk(keyfile.read()) + key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key.public_key(), signature) @@ -834,7 +855,7 @@ def test_okp_ed25519_jwk_private_key_should_parse_and_verify_with_private_key_as algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: - key = algo.from_jwk(keyfile.read()) + key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key, signature) @@ -843,10 +864,10 @@ def test_okp_ed25519_jwk_public_key_should_parse_and_verify(self): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: - priv_key = algo.from_jwk(keyfile.read()) + priv_key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile: - pub_key = algo.from_jwk(keyfile.read()) + pub_key = cast(Ed25519PublicKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", priv_key) assert algo.verify(b"Hello World!", pub_key, signature) @@ -907,15 +928,15 @@ def test_okp_ed25519_to_jwk_works_with_from_jwk(self): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: - priv_key_1 = algo.from_jwk(keyfile.read()) + priv_key_1 = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile: - pub_key_1 = algo.from_jwk(keyfile.read()) + pub_key_1 = cast(Ed25519PublicKey, algo.from_jwk(keyfile.read())) pub = algo.to_jwk(pub_key_1) pub_key_2 = algo.from_jwk(pub) pri = algo.to_jwk(priv_key_1) - priv_key_2 = algo.from_jwk(pri) + priv_key_2 = cast(Ed25519PrivateKey, algo.from_jwk(pri)) signature_1 = algo.sign(b"Hello World!", priv_key_1) signature_2 = algo.sign(b"Hello World!", priv_key_2) @@ -932,7 +953,7 @@ def test_okp_ed448_jwk_private_key_should_parse_and_verify(self): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: - key = algo.from_jwk(keyfile.read()) + key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key.public_key(), signature) @@ -943,7 +964,7 @@ def test_okp_ed448_jwk_private_key_should_parse_and_verify_with_private_key_as_i algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: - key = algo.from_jwk(keyfile.read()) + key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key, signature) @@ -952,10 +973,10 @@ def test_okp_ed448_jwk_public_key_should_parse_and_verify(self): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: - priv_key = algo.from_jwk(keyfile.read()) + priv_key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_okp_pub_Ed448.json")) as keyfile: - pub_key = algo.from_jwk(keyfile.read()) + pub_key = cast(Ed448PublicKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", priv_key) assert algo.verify(b"Hello World!", pub_key, signature) @@ -1016,15 +1037,15 @@ def test_okp_ed448_to_jwk_works_with_from_jwk(self): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: - priv_key_1 = algo.from_jwk(keyfile.read()) + priv_key_1 = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_okp_pub_Ed448.json")) as keyfile: - pub_key_1 = algo.from_jwk(keyfile.read()) + pub_key_1 = cast(Ed448PublicKey, algo.from_jwk(keyfile.read())) pub = algo.to_jwk(pub_key_1) pub_key_2 = algo.from_jwk(pub) pri = algo.to_jwk(priv_key_1) - priv_key_2 = algo.from_jwk(pri) + priv_key_2 = cast(Ed448PrivateKey, algo.from_jwk(pri)) signature_1 = algo.sign(b"Hello World!", priv_key_1) signature_2 = algo.sign(b"Hello World!", priv_key_2) diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index b961dc2e..33857161 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -533,11 +533,11 @@ def test_encode_decode_rsa_related_algorithms(self, jws, payload, algo): # string-formatted key with open(key_path("testkey_rsa.priv")) as rsa_priv_file: - priv_rsakey = rsa_priv_file.read() + priv_rsakey = rsa_priv_file.read() # type: ignore[assignment] jws_message = jws.encode(payload, priv_rsakey, algorithm=algo) with open(key_path("testkey_rsa.pub")) as rsa_pub_file: - pub_rsakey = rsa_pub_file.read() + pub_rsakey = rsa_pub_file.read() # type: ignore[assignment] jws.decode(jws_message, pub_rsakey, algorithms=[algo]) def test_rsa_related_algorithms(self, jws): @@ -582,11 +582,11 @@ def test_encode_decode_ecdsa_related_algorithms(self, jws, payload, algo): # string-formatted key with open(key_path("testkey_ec.priv")) as ec_priv_file: - priv_eckey = ec_priv_file.read() + priv_eckey = ec_priv_file.read() # type: ignore[assignment] jws_message = jws.encode(payload, priv_eckey, algorithm=algo) with open(key_path("testkey_ec.pub")) as ec_pub_file: - pub_eckey = ec_pub_file.read() + pub_eckey = ec_pub_file.read() # type: ignore[assignment] jws.decode(jws_message, pub_eckey, algorithms=[algo]) def test_ecdsa_related_algorithms(self, jws): From 476ca58b6098c386d4a658abb193727d233acb9a Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sat, 15 Apr 2023 14:16:29 +0200 Subject: [PATCH 18/20] Fix key imports --- tests/test_algorithms.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index aa62ffc2..b937eade 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -12,9 +12,6 @@ from .utils import crypto_required, key_path if has_crypto: - from jwt.algorithms import ECAlgorithm, OKPAlgorithm, RSAAlgorithm, RSAPSSAlgorithm - -if TYPE_CHECKING: from cryptography.hazmat.primitives.asymmetric.ec import ( EllipticCurvePrivateKey, EllipticCurvePublicKey, @@ -32,6 +29,8 @@ RSAPublicKey, ) + from jwt.algorithms import ECAlgorithm, OKPAlgorithm, RSAAlgorithm, RSAPSSAlgorithm + class TestAlgorithms: def test_none_algorithm_should_throw_exception_if_key_is_not_none(self): From 8eb7610b214aa09463944a369eade84d4b667bc8 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sat, 15 Apr 2023 14:20:05 +0200 Subject: [PATCH 19/20] Remove unused import --- tests/test_algorithms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index b937eade..7a90376b 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,6 +1,6 @@ import base64 import json -from typing import TYPE_CHECKING, cast +from typing import cast import pytest From 27e2b45ed4f8df85d416ffac42f3ba208c565d38 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sat, 15 Apr 2023 18:03:26 +0200 Subject: [PATCH 20/20] Fix randomly failing test --- tests/test_api_jwt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 3c08ea59..0d534446 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -365,8 +365,8 @@ def test_decode_with_expiration_with_leeway(self, jwt, payload): secret = "secret" jwt_message = jwt.encode(payload, secret) - # With 3 seconds leeway, should be ok - for leeway in (3, timedelta(seconds=3)): + # With 5 seconds leeway, should be ok + for leeway in (5, timedelta(seconds=5)): decoded = jwt.decode( jwt_message, secret, leeway=leeway, algorithms=["HS256"] )