diff --git a/pyproject.toml b/pyproject.toml index c4469c9..f05bcef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ exclude_lines = [ ] [tool.mypy] +strict = true python_version = "3.8" files = ["src/joserfc"] show_error_codes = true diff --git a/src/joserfc/jws.py b/src/joserfc/jws.py index 198744b..f8d7fc3 100644 --- a/src/joserfc/jws.py +++ b/src/joserfc/jws.py @@ -227,7 +227,7 @@ def serialize_json( if registry is None: registry = construct_registry(algorithms) - def find_key(obj: Any): + def find_key(obj: Any) -> Key: return guess_key(private_key, obj, True) _payload = to_bytes(payload) @@ -271,7 +271,7 @@ def deserialize_json( if registry is None: registry = construct_registry(algorithms) - def find_key(obj: Any): + def find_key(obj: Any) -> Key: return guess_key(public_key, obj) if "signatures" in value: diff --git a/src/joserfc/rfc7516/json.py b/src/joserfc/rfc7516/json.py index b2a7326..4c7e410 100644 --- a/src/joserfc/rfc7516/json.py +++ b/src/joserfc/rfc7516/json.py @@ -93,7 +93,8 @@ def extract_flattened_json(data: FlattenedJSONSerialization) -> FlattenedJSONEnc def __extract_segments( - data: t.Union[GeneralJSONSerialization, FlattenedJSONSerialization]): # type: ignore[no-untyped-def] + data: t.Union[GeneralJSONSerialization, FlattenedJSONSerialization] +) -> t.Tuple[t.Dict[str, bytes], t.Dict[str, bytes], t.Optional[bytes]]: base64_segments: t.Dict[str, bytes] = { "iv": to_bytes(data["iv"]), "ciphertext": to_bytes(data["ciphertext"]), diff --git a/src/joserfc/rfc7516/models.py b/src/joserfc/rfc7516/models.py index 57181e5..ed44962 100644 --- a/src/joserfc/rfc7516/models.py +++ b/src/joserfc/rfc7516/models.py @@ -1,9 +1,10 @@ +from __future__ import annotations import os import typing as t from abc import ABCMeta, abstractmethod from ..registry import Header, HeaderRegistryDict from ..errors import InvalidKeyTypeError, InvalidKeyLengthError -from .._keys import Key, ECKey +from .._keys import Key, ECKey, OctKey KeyType = t.TypeVar("KeyType") @@ -12,8 +13,8 @@ class Recipient(t.Generic[KeyType]): def __init__( self, parent: t.Union["CompactEncryption", "GeneralJSONEncryption", "FlattenedJSONEncryption"], - header: t.Optional[Header] = None, - recipient_key: t.Optional[KeyType] = None): + header: Header | None = None, + recipient_key: KeyType | None = None): self.__parent = parent self.header = header self.recipient_key = recipient_key @@ -30,7 +31,7 @@ def headers(self) -> Header: rv.update(self.header) return rv - def add_header(self, k: str, v: t.Any): + def add_header(self, k: str, v: t.Any) -> None: if isinstance(self.__parent, CompactEncryption): self.__parent.protected.update({k: v}) elif self.header: @@ -38,7 +39,7 @@ def add_header(self, k: str, v: t.Any): else: self.header = {k: v} - def set_kid(self, kid: str): + def set_kid(self, kid: str) -> None: self.add_header("kid", kid) @@ -46,19 +47,19 @@ class CompactEncryption: """An object to represent the JWE Compact Serialization. It is usually returned by ``decrypt_compact`` method. """ - def __init__(self, protected: Header, plaintext: t.Optional[bytes] = None): + def __init__(self, protected: Header, plaintext: bytes | None = None): #: protected header in dict self.protected = protected #: the plaintext in bytes self.plaintext = plaintext - self.recipient: t.Optional[Recipient] = None + self.recipient: Recipient[t.Any] | None = None self.bytes_segments: t.Dict[str, bytes] = {} # store the decoded segments self.base64_segments: t.Dict[str, bytes] = {} # store the encoded segments def headers(self) -> Header: return self.protected - def attach_recipient(self, key: Key, header: t.Optional[Header] = None): + def attach_recipient(self, key: Key, header: Header | None = None) -> None: """Add a recipient to the JWE Compact Serialization. Please add a key that comply with the given "alg" value. @@ -71,7 +72,7 @@ def attach_recipient(self, key: Key, header: t.Optional[Header] = None): self.recipient = recipient @property - def recipients(self) -> t.List[Recipient]: + def recipients(self) -> list[Recipient[t.Any]]: if self.recipient is not None: return [self.recipient] return [] @@ -89,14 +90,14 @@ class BaseJSONEncryption(metaclass=ABCMeta): #: an optional additional authenticated data aad: t.Optional[bytes] #: a list of recipients - recipients: t.List[Recipient] + recipients: t.List[Recipient[t.Any]] def __init__( self, protected: Header, - plaintext: t.Optional[bytes] = None, - unprotected: t.Optional[Header] = None, - aad: t.Optional[bytes] = None): + plaintext: bytes | None = None, + unprotected: Header | None = None, + aad: bytes | None = None): self.protected = protected self.plaintext = plaintext self.unprotected = unprotected @@ -106,7 +107,7 @@ def __init__( self.base64_segments: t.Dict[str, bytes] = {} # store the encoded segments @abstractmethod - def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None): + def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None: """Add a recipient to the JWE JSON Serialization. Please add a key that comply with the "alg" to this recipient. @@ -131,7 +132,7 @@ class GeneralJSONEncryption(BaseJSONEncryption): """ flattened = False - def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None): + def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None: recipient = Recipient(self, header, key) self.recipients.append(recipient) @@ -152,7 +153,7 @@ class FlattenedJSONEncryption(BaseJSONEncryption): """ flattened = True - def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None): + def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None: self.recipients = [Recipient(self, header, key)] @@ -178,7 +179,7 @@ def check_iv(self, iv: bytes) -> bytes: return iv @abstractmethod - def encrypt(self, plaintext: bytes, cek: bytes, iv: bytes, aad: bytes) -> t.Tuple[bytes, bytes]: + def encrypt(self, plaintext: bytes, cek: bytes, iv: bytes, aad: bytes) -> tuple[bytes, bytes]: pass @abstractmethod @@ -216,11 +217,11 @@ class KeyManagement: def direct_mode(self) -> bool: return self.key_size is None - def check_key_type(self, key: Key): + def check_key_type(self, key: Key) -> None: if key.key_type not in self.key_types: raise InvalidKeyTypeError() - def prepare_recipient_header(self, recipient: Recipient): + def prepare_recipient_header(self, recipient: Recipient[t.Any]) -> None: raise NotImplementedError() @@ -228,7 +229,7 @@ class JWEDirectEncryption(KeyManagement, metaclass=ABCMeta): key_types = ["oct"] @abstractmethod - def compute_cek(self, size: int, recipient: Recipient) -> bytes: + def compute_cek(self, size: int, recipient: Recipient[OctKey]) -> bytes: pass @@ -238,11 +239,11 @@ def direct_mode(self) -> bool: return False @abstractmethod - def encrypt_cek(self, cek: bytes, recipient: Recipient) -> bytes: + def encrypt_cek(self, cek: bytes, recipient: Recipient[t.Any]) -> bytes: pass @abstractmethod - def decrypt_cek(self, recipient: Recipient) -> bytes: + def decrypt_cek(self, recipient: Recipient[t.Any]) -> bytes: pass @@ -254,7 +255,7 @@ class JWEKeyWrapping(KeyManagement, metaclass=ABCMeta): def direct_mode(self) -> bool: return False - def check_op_key(self, op_key: bytes): + def check_op_key(self, op_key: bytes) -> None: if len(op_key) * 8 != self.key_size: raise InvalidKeyLengthError(f"A key of size {self.key_size} bits MUST be used") @@ -267,11 +268,11 @@ def unwrap_cek(self, ek: bytes, key: bytes) -> bytes: pass @abstractmethod - def encrypt_cek(self, cek: bytes, recipient: Recipient) -> bytes: + def encrypt_cek(self, cek: bytes, recipient: Recipient[OctKey]) -> bytes: pass @abstractmethod - def decrypt_cek(self, recipient: Recipient) -> bytes: + def decrypt_cek(self, recipient: Recipient[OctKey]) -> bytes: pass @@ -280,7 +281,7 @@ class JWEKeyAgreement(KeyManagement, metaclass=ABCMeta): tag_aware: bool = False key_wrapping: t.Optional[JWEKeyWrapping] - def prepare_ephemeral_key(self, recipient: Recipient[ECKey]): + def prepare_ephemeral_key(self, recipient: Recipient[ECKey]) -> None: recipient_key = recipient.recipient_key assert recipient_key is not None self.check_key_type(recipient_key) diff --git a/src/joserfc/rfc7516/registry.py b/src/joserfc/rfc7516/registry.py index 92e6b1b..cf54e0e 100644 --- a/src/joserfc/rfc7516/registry.py +++ b/src/joserfc/rfc7516/registry.py @@ -51,12 +51,12 @@ def __init__( self.strict_check_header = strict_check_header @classmethod - def register(cls, model: JWEAlgorithm): + def register(cls, model: JWEAlgorithm) -> None: cls.algorithms[model.algorithm_location][model.name] = model # type: ignore if model.recommended: cls.recommended.append(model.name) - def check_header(self, header: Header, check_more=False): + def check_header(self, header: Header, check_more: bool = False) -> None: """Check and validate the fields in header part of a JWS object.""" check_crit_header(header) validate_registry_header(self.header_registry, header) @@ -77,24 +77,29 @@ def get_alg(self, name: str) -> JWEAlgModel: :param name: value of the ``alg``, e.g. ``ECDH-ES``, ``A128KW`` """ - return self._get_algorithm("alg", name) + registry = self.algorithms["alg"] + self._check_algorithm(name, registry) + return registry[name] def get_enc(self, name: str) -> JWEEncModel: """Get the allowed ("enc") algorithm instance of the given name. :param name: value of the ``enc``, e.g. ``A128CBC-HS256``, ``A128GCM`` """ - return self._get_algorithm("enc", name) + registry = self.algorithms["enc"] + self._check_algorithm(name, registry) + return registry[name] def get_zip(self, name: str) -> JWEZipModel: """Get the allowed ("zip") algorithm instance of the given name. :param name: value of the ``zip``, e.g. ``DEF`` """ - return self._get_algorithm("zip", name) + registry = self.algorithms["zip"] + self._check_algorithm(name, registry) + return registry[name] - def _get_algorithm(self, location: t.Literal["alg", "enc", "zip"], name: str): - registry: t.Dict[str, JWEAlgorithm] = self.algorithms[location] # type: ignore + def _check_algorithm(self, name: str, registry: dict[str, t.Any]) -> None: if name not in registry: raise ValueError(f'Algorithm of "{name}" is not supported') @@ -105,7 +110,6 @@ def _get_algorithm(self, location: t.Literal["alg", "enc", "zip"], name: str): if name not in allowed: raise ValueError(f'Algorithm of "{name}" is not allowed') - return registry[name] default_registry = JWERegistry() diff --git a/src/joserfc/rfc7517/models.py b/src/joserfc/rfc7517/models.py index 0643c69..46cce47 100644 --- a/src/joserfc/rfc7517/models.py +++ b/src/joserfc/rfc7517/models.py @@ -1,5 +1,6 @@ +from __future__ import annotations import typing as t -from typing import overload +from collections.abc import KeysView from abc import ABCMeta, abstractmethod from .types import DictKey, AnyKey, KeyParameters from ..registry import ( @@ -82,7 +83,7 @@ class BaseKey(t.Generic[NativePrivateKey, NativePublicKey], metaclass=ABCMeta): def __init__( self, - raw_value: t.Union[NativePrivateKey, NativePublicKey], + raw_value: NativePrivateKey | NativePublicKey, original_value: t.Any, parameters: t.Optional[KeyParameters] = None): self._raw_value = raw_value @@ -97,13 +98,13 @@ def __init__( self.validate_dict_key(data) self._dict_value = data - def keys(self): + def keys(self) -> KeysView[str]: return self.dict_value.keys() - def __getitem__(self, k: str): + def __getitem__(self, k: str) -> str | list[str]: return self.dict_value[k] - def get(self, k: str, default=None): + def get(self, k: str, default: str | None = None) -> str | list[str] | None: return self.dict_value.get(k, default) def ensure_kid(self) -> None: @@ -114,17 +115,17 @@ def ensure_kid(self) -> None: self._dict_value["kid"] = self.thumbprint() @property - def kid(self) -> t.Optional[str]: + def kid(self) -> str | None: """The "kid" value of the JSON Web Key.""" - return self.get("kid") + return t.cast(t.Optional[str], self.get("kid")) @property - def alg(self) -> t.Optional[str]: + def alg(self) -> str | None: """The "alg" value of the JSON Web Key.""" - return self.get("alg") + return t.cast(t.Optional[str], self.get("alg")) @property - def raw_value(self): + def raw_value(self) -> t.Any: raise NotImplementedError() @property @@ -220,13 +221,13 @@ def check_key_op(self, operation: str) -> None: if reg.private and not self.is_private: raise UnsupportedKeyOperationError(f'Invalid key_op "{operation}" for public key') - @overload + @t.overload def get_op_key(self, operation: t.Literal["verify", "encrypt", "wrapKey", "deriveKey"]) -> NativePublicKey: ... - @overload + @t.overload def get_op_key(self, operation: t.Literal["sign", "decrypt", "unwrapKey"]) -> NativePrivateKey: ... - def get_op_key(self, operation: str) -> t.Union[NativePublicKey, NativePrivateKey]: + def get_op_key(self, operation: str) -> NativePublicKey | NativePrivateKey: self.check_key_op(operation) reg = self.operation_registry[operation] if reg.private: @@ -235,7 +236,7 @@ def get_op_key(self, operation: str) -> t.Union[NativePublicKey, NativePrivateKe return self.public_key @classmethod - def validate_dict_key(cls, data: DictKey): + def validate_dict_key(cls, data: DictKey) -> None: cls.binding.validate_dict_key_registry(data, cls.param_registry) cls.binding.validate_dict_key_registry(data, cls.value_registry) cls.binding.validate_dict_key_use_operations(data) @@ -257,7 +258,7 @@ def import_key( @classmethod def generate_key( cls: t.Type[GenericKey], - size_or_crv, + size_or_crv: t.Any, parameters: t.Optional[KeyParameters] = None, private: bool = True, auto_kid: bool = False) -> GenericKey: @@ -312,5 +313,5 @@ def curve_name(self) -> str: pass @abstractmethod - def exchange_derive_key(self, key) -> bytes: + def exchange_derive_key(self, key: t.Any) -> bytes: pass diff --git a/src/joserfc/rfc7517/pem.py b/src/joserfc/rfc7517/pem.py index 6d05e9c..778a2f2 100644 --- a/src/joserfc/rfc7517/pem.py +++ b/src/joserfc/rfc7517/pem.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Literal +from typing import Any, Literal, cast from abc import ABCMeta, abstractmethod from cryptography.hazmat.primitives.serialization import ( load_pem_private_key, @@ -73,15 +73,17 @@ def dump_pem_key( encryption_algorithm = NoEncryption() else: encryption_algorithm = BestAvailableEncryption(to_bytes(password)) - return key.private_bytes( + value = key.private_bytes( encoding=encoding_enum, format=PrivateFormat.PKCS8, encryption_algorithm=encryption_algorithm, ) - return key.public_bytes( - encoding=encoding_enum, - format=PublicFormat.SubjectPublicKeyInfo, - ) + else: + value = key.public_bytes( + encoding=encoding_enum, + format=PublicFormat.SubjectPublicKeyInfo, + ) + return cast(bytes, value) class CryptographyBinding(NativeKeyBinding, metaclass=ABCMeta): @@ -90,8 +92,10 @@ class CryptographyBinding(NativeKeyBinding, metaclass=ABCMeta): @classmethod def convert_raw_key_to_dict(cls, raw_key: Any, private: bool) -> DictKey: if private: - return cls.export_private_key(raw_key) - return cls.export_public_key(raw_key) + value = cls.export_private_key(raw_key) + else: + value = cls.export_public_key(raw_key) + return cast(DictKey, value) @classmethod def import_from_dict(cls, value: DictKey) -> Any: @@ -122,20 +126,20 @@ def as_bytes( @staticmethod @abstractmethod - def import_private_key(value) -> Any: + def import_private_key(value: Any) -> Any: pass @staticmethod @abstractmethod - def import_public_key(value) -> Any: + def import_public_key(value: Any) -> Any: pass @staticmethod @abstractmethod - def export_private_key(value) -> Any: + def export_private_key(value: Any) -> Any: pass @staticmethod @abstractmethod - def export_public_key(value) -> Any: + def export_public_key(value: Any) -> Any: pass diff --git a/src/joserfc/rfc7518/ec_key.py b/src/joserfc/rfc7518/ec_key.py index 17a35e6..6e71cf3 100644 --- a/src/joserfc/rfc7518/ec_key.py +++ b/src/joserfc/rfc7518/ec_key.py @@ -22,7 +22,7 @@ from ..registry import KeyParameter ECDictKey = t.TypedDict("ECDictKey", { - "crv": t.Literal["P-256", "P-384", "P-521"], + "crv": str, "x": str, "y": str, "d": str, # optional @@ -56,7 +56,7 @@ def import_private_key(obj: ECDictKey) -> EllipticCurvePrivateKey: return private_numbers.private_key(default_backend()) @staticmethod - def export_private_key(key: EllipticCurvePrivateKey) -> dict[str, str]: + def export_private_key(key: EllipticCurvePrivateKey) -> ECDictKey: numbers = key.private_numbers() return { "crv": CURVES_DSS[key.curve.name], @@ -76,7 +76,7 @@ def import_public_key(obj: ECDictKey) -> EllipticCurvePublicKey: return public_numbers.public_key(default_backend()) @staticmethod - def export_public_key(key: EllipticCurvePublicKey) -> dict[str, str]: + def export_public_key(key: EllipticCurvePublicKey) -> ECDictKey: numbers = key.public_numbers() return { "crv": CURVES_DSS[numbers.curve.name], diff --git a/src/joserfc/rfc7797/json.py b/src/joserfc/rfc7797/json.py index 632413c..a78e1b2 100644 --- a/src/joserfc/rfc7797/json.py +++ b/src/joserfc/rfc7797/json.py @@ -11,7 +11,7 @@ serialize_json as _serialize_json, deserialize_json as _deserialize_json, ) -from ..jwk import KeyFlexible, guess_key +from ..jwk import Key, KeyFlexible, guess_key from ..util import ( to_bytes, to_str, @@ -84,7 +84,7 @@ def deserialize_json( payload_segment = obj.segments["payload"] - def find_key(d: t.Any): + def find_key(d: t.Any) -> Key: return guess_key(public_key, d) assert obj.signature is not None