From 77fecad9974672808af9038840693bbf1b739c8a Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Wed, 19 Jul 2023 00:22:34 +0200 Subject: [PATCH] feat: improvements on type hints --------- Co-authored-by: Hsiaoming Yang --- src/joserfc/jws.py | 9 ++++++--- src/joserfc/registry.py | 3 +-- src/joserfc/rfc7515/compact.py | 3 ++- src/joserfc/rfc7515/json.py | 5 +++-- src/joserfc/rfc7515/model.py | 9 +++++---- src/joserfc/rfc7515/registry.py | 2 +- src/joserfc/rfc7515/types.py | 2 ++ src/joserfc/rfc7517/keyset.py | 6 +++--- src/joserfc/rfc7517/pem.py | 24 ++++++++++++++++-------- src/joserfc/rfc7518/jws_algs.py | 17 +++++++++-------- src/joserfc/util.py | 11 ++++++----- 11 files changed, 54 insertions(+), 37 deletions(-) diff --git a/src/joserfc/jws.py b/src/joserfc/jws.py index 7a8c5aa..ddcb204 100644 --- a/src/joserfc/jws.py +++ b/src/joserfc/jws.py @@ -268,18 +268,21 @@ def deserialize_json( find_key = lambda d: guess_key(public_key, d) if "signatures" in value: - general_obj = extract_general_json(value) # type: ignore + general_obj = extract_general_json(value) if not verify_general_json(general_obj, registry, find_key): raise BadSignatureError() return general_obj else: - flattened_obj = extract_flattened_json(value) # type: ignore + flattened_obj = extract_flattened_json(value) if not verify_flattened_json(flattened_obj, registry, find_key): raise BadSignatureError() return flattened_obj -def detach_content(value: t.Union[str, t.Dict]): +DetachValue = t.TypeVar("DetachValue", str, t.Dict[str, t.Any]) + + +def detach_content(value: DetachValue) -> DetachValue: """In some contexts, it is useful to integrity-protect content that is not itself contained in a JWS. This method is an implementation of https://www.rfc-editor.org/rfc/rfc7515#appendix-F diff --git a/src/joserfc/registry.py b/src/joserfc/registry.py index e84cb1e..5bda369 100644 --- a/src/joserfc/registry.py +++ b/src/joserfc/registry.py @@ -181,8 +181,7 @@ def validate_registry_header( registry: HeaderRegistryDict, header: Header, check_required: bool = True): - for key in registry: - reg = registry[key] + for key, reg in registry.items(): if check_required and reg.required and key not in header: raise ValueError(f'Required "{key}" is missing in header') if key in header: diff --git a/src/joserfc/rfc7515/compact.py b/src/joserfc/rfc7515/compact.py index 00b3035..63344f4 100644 --- a/src/joserfc/rfc7515/compact.py +++ b/src/joserfc/rfc7515/compact.py @@ -1,4 +1,5 @@ import binascii +import typing as t from .model import JWSAlgModel, CompactSignature from ..errors import DecodeError, MissingAlgorithmError from ..util import ( @@ -57,7 +58,7 @@ def detach_compact_content(value: str) -> str: return ".".join(parts) -def decode_header(header_segment: bytes): +def decode_header(header_segment: bytes) -> t.Dict[str, t.Any]: try: protected = json_b64decode(header_segment) if "alg" not in protected: diff --git a/src/joserfc/rfc7515/json.py b/src/joserfc/rfc7515/json.py index d09f3ec..98f9c9a 100644 --- a/src/joserfc/rfc7515/json.py +++ b/src/joserfc/rfc7515/json.py @@ -1,5 +1,6 @@ import typing as t import binascii +import copy from .model import ( HeaderMember, GeneralJSONSignature, @@ -164,9 +165,9 @@ def verify_signature( return alg.verify(signing_input, sig, key) -def detach_json_content(value): +def detach_json_content(value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: # https://www.rfc-editor.org/rfc/rfc7515#appendix-F - rv = value.copy() # don't alter original value + rv = copy.deepcopy(value) # don't alter original value if "payload" in rv: del rv["payload"] return rv diff --git a/src/joserfc/rfc7515/model.py b/src/joserfc/rfc7515/model.py index b189f53..fe6773a 100644 --- a/src/joserfc/rfc7515/model.py +++ b/src/joserfc/rfc7515/model.py @@ -3,6 +3,7 @@ from .types import SegmentsDict, JSONSignatureDict from ..errors import InvalidKeyTypeError from ..registry import Header +from ..rfc7517.models import BaseKey class HeaderMember: @@ -50,7 +51,7 @@ def __init__(self, member: HeaderMember, payload: bytes): self.segments: SegmentsDict = {} @property - def members(self): + def members(self) -> t.List[HeaderMember]: return [self.member] def headers(self) -> Header: @@ -79,12 +80,12 @@ class JWSAlgModel(object, metaclass=ABCMeta): algorithm_type = "JWS" algorithm_location = "sig" - def check_key_type(self, key): + def check_key_type(self, key: BaseKey): if key.key_type != self.key_type: raise InvalidKeyTypeError(f'Algorithm "{self.name}" requires "{self.key_type}" key') @abstractmethod - def sign(self, msg: bytes, key) -> bytes: + def sign(self, msg: bytes, key: t.Any) -> bytes: """Sign the text msg with a private/sign key. :param msg: message bytes to be signed @@ -94,7 +95,7 @@ def sign(self, msg: bytes, key) -> bytes: pass @abstractmethod - def verify(self, msg: bytes, sig: bytes, key) -> bool: + def verify(self, msg: bytes, sig: bytes, key: t.Any) -> bool: """Verify the signature of text msg with a public/verify key. :param msg: message bytes to be signed diff --git a/src/joserfc/rfc7515/registry.py b/src/joserfc/rfc7515/registry.py index d97a1a4..31e0872 100644 --- a/src/joserfc/rfc7515/registry.py +++ b/src/joserfc/rfc7515/registry.py @@ -10,7 +10,7 @@ ) -class JWSRegistry(object): +class JWSRegistry: """A registry for JSON Web Signature to keep all the supported algorithms. An instance of ``JWSRegistry`` is usually used together with methods in ``joserfc.jws``. diff --git a/src/joserfc/rfc7515/types.py b/src/joserfc/rfc7515/types.py index d96bae9..b54a1bf 100644 --- a/src/joserfc/rfc7515/types.py +++ b/src/joserfc/rfc7515/types.py @@ -27,11 +27,13 @@ class JSONSignatureDict(t.TypedDict, total=False): signature: str +@t.final class GeneralJSONSerialization(t.TypedDict): payload: str signatures: t.List[JSONSignatureDict] +@t.final class FlattenedJSONSerialization(t.TypedDict, total=False): payload: str protected: str diff --git a/src/joserfc/rfc7517/keyset.py b/src/joserfc/rfc7517/keyset.py index fe8f471..adbef63 100644 --- a/src/joserfc/rfc7517/keyset.py +++ b/src/joserfc/rfc7517/keyset.py @@ -1,7 +1,7 @@ import typing as t import random from .models import BaseKey as Key, SymmetricKey -from .types import KeySetDict, KeyParameters +from .types import KeyDict, KeySetDict, KeyParameters from .registry import JWKRegistry @@ -18,8 +18,8 @@ def __init__(self, keys: t.List[Key]): def __iter__(self) -> t.Iterator[Key]: return iter(self.keys) - def as_dict(self, private=None, **params): - keys = [] + def as_dict(self, private: t.Optional[bool] = None, **params: t.Any) -> KeySetDict: + keys: t.List[KeyDict] = [] for key in self.keys: # trigger key to generate kid via thumbprint diff --git a/src/joserfc/rfc7517/pem.py b/src/joserfc/rfc7517/pem.py index 375386a..ec5dfe2 100644 --- a/src/joserfc/rfc7517/pem.py +++ b/src/joserfc/rfc7517/pem.py @@ -15,7 +15,7 @@ NoEncryption, ) from cryptography.hazmat.backends import default_backend -from .models import NativeKeyBinding +from .models import NativeKeyBinding, BaseKey from .types import KeyDict from ..util import to_bytes @@ -41,7 +41,11 @@ def load_pem_key(raw: bytes, ssh_type: t.Optional[bytes] = None, password: t.Opt return key -def dump_pem_key(key, encoding=None, private=False, password=None) -> bytes: +def dump_pem_key( + key: t.Any, + encoding: t.Optional[t.Literal["PEM", "DER"]] = None, + private: t.Optional[bool] = False, + password: t.Optional[t.Any] = None) -> bytes: """Export key into PEM/DER format bytes. :param key: native cryptography key @@ -52,9 +56,9 @@ def dump_pem_key(key, encoding=None, private=False, password=None) -> bytes: """ if encoding is None or encoding == "PEM": - encoding = Encoding.PEM + encoding_enum = Encoding.PEM elif encoding == "DER": - encoding = Encoding.DER + encoding_enum = Encoding.DER else: # pragma: no cover raise ValueError("Invalid encoding: {!r}".format(encoding)) @@ -65,12 +69,12 @@ def dump_pem_key(key, encoding=None, private=False, password=None) -> bytes: else: encryption_algorithm = BestAvailableEncryption(to_bytes(password)) return key.private_bytes( - encoding=encoding, + encoding=encoding_enum, format=PrivateFormat.PKCS8, encryption_algorithm=encryption_algorithm, ) return key.public_bytes( - encoding=encoding, + encoding=encoding_enum, format=PublicFormat.SubjectPublicKeyInfo, ) @@ -91,13 +95,17 @@ def import_from_dict(cls, value: KeyDict): return cls.import_public_key(value) @classmethod - def import_from_bytes(cls, value: bytes, password=None): + def import_from_bytes(cls, value: bytes, password: t.Optional[t.Any] = None): if password is not None: password = to_bytes(password) return load_pem_key(value, cls.ssh_type, password) @staticmethod - def as_bytes(key, encoding=None, private=None, password=None) -> bytes: + def as_bytes( + key: BaseKey, + encoding: t.Optional[t.Literal["PEM", "DER"]] = None, + private: t.Optional[bool] = False, + password: t.Optional[t.Any] = None) -> bytes: if private is True: return dump_pem_key(key.private_key, encoding, private, password) elif private is False: diff --git a/src/joserfc/rfc7518/jws_algs.py b/src/joserfc/rfc7518/jws_algs.py index 36ef60c..1373921 100644 --- a/src/joserfc/rfc7518/jws_algs.py +++ b/src/joserfc/rfc7518/jws_algs.py @@ -9,9 +9,9 @@ .. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3 """ -import typing as t import hmac import hashlib +import typing as t from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric.utils import ( decode_dss_signature, @@ -21,6 +21,7 @@ from cryptography.hazmat.primitives.asymmetric import padding from cryptography.exceptions import InvalidSignature from ..rfc7515.model import JWSAlgModel +from ..rfc7517.models import BaseKey from .oct_key import OctKey from .rsa_key import RSAKey from .ec_key import ECKey @@ -31,10 +32,10 @@ class NoneAlgModel(JWSAlgModel): name = "none" description = "No digital signature or MAC performed" - def sign(self, msg, key): + def sign(self, msg: bytes, key: BaseKey): return b"" - def verify(self, msg, sig, key): + def verify(self, msg: bytes, sig: bytes, key: BaseKey): return False @@ -50,7 +51,7 @@ class HMACAlgModel(JWSAlgModel): SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - def __init__(self, sha_type: int, recommended: bool = False): + def __init__(self, sha_type: t.Literal[256, 384, 512], recommended: bool = False): self.name = f"HS{sha_type}" self.description = f"HMAC using SHA-{sha_type}" self.recommended = recommended @@ -81,7 +82,7 @@ class RSAAlgModel(JWSAlgModel): SHA512 = hashes.SHA512 padding = padding.PKCS1v15() - def __init__(self, sha_type: int, recommended: bool = False): + def __init__(self, sha_type: t.Literal[256, 384, 512], recommended: bool = False): self.name = f"RS{sha_type}" self.description = f"RSASSA-PKCS1-v1_5 using SHA-{sha_type}" self.recommended = recommended @@ -113,14 +114,14 @@ class ECAlgModel(JWSAlgModel): SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 - def __init__(self, name: str, curve: str, sha_type: int, recommended: bool = False): + def __init__(self, name: str, curve: str, sha_type: t.Literal[256, 384, 512], recommended: bool = False): self.name = name self.curve = curve self.description = f"ECDSA using {self.curve} and SHA-{sha_type}" self.recommended = recommended self.hash_alg = getattr(self, f"SHA{sha_type}") - def _check_key(self, key: ECKey): + def _check_key(self, key: ECKey) -> ECKey: if key.curve_name != self.curve: raise ValueError(f'Key for "{self.name}" not supported, only "{self.curve}" allowed') return key @@ -166,7 +167,7 @@ class RSAPSSAlgModel(JWSAlgModel): SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 - def __init__(self, sha_type: int): + def __init__(self, sha_type: t.Literal[256, 384, 512]): self.name = f"PS{sha_type}" self.description = f"RSASSA-PSS using SHA-{sha_type} and MGF1 with SHA-{sha_type}" self.hash_alg = getattr(self, f"SHA{sha_type}") diff --git a/src/joserfc/util.py b/src/joserfc/util.py index 59c0c5e..bba76c2 100644 --- a/src/joserfc/util.py +++ b/src/joserfc/util.py @@ -1,9 +1,10 @@ +import typing as t import base64 import struct import json -def to_bytes(x, charset="utf-8", errors="strict") -> bytes: +def to_bytes(x: t.Any, charset: str = "utf-8", errors: str = "strict") -> bytes: if isinstance(x, bytes): return x if isinstance(x, str): @@ -13,13 +14,13 @@ def to_bytes(x, charset="utf-8", errors="strict") -> bytes: return bytes(x) -def to_unicode(x, charset="utf-8") -> str: +def to_unicode(x, charset: str = "utf-8") -> str: if isinstance(x, bytes): return x.decode(charset) return x -def json_dumps(data: dict, ensure_ascii=False): +def json_dumps(data: t.Any, ensure_ascii: bool = False) -> str: return json.dumps(data, ensure_ascii=ensure_ascii, separators=(",", ":")) @@ -46,11 +47,11 @@ def int_to_base64(num: int) -> str: return urlsafe_b64encode(s).decode("utf-8", "strict") -def json_b64encode(text, charset="utf-8") -> bytes: +def json_b64encode(text, charset: str = "utf-8") -> bytes: if isinstance(text, dict): text = json_dumps(text) return urlsafe_b64encode(to_bytes(text, charset)) -def json_b64decode(text, charset="utf-8"): +def json_b64decode(text, charset="utf-8") -> t.Dict[str, t.Any]: return json.loads(urlsafe_b64decode(to_bytes(text, charset)))