Skip to content

Commit

Permalink
feat: improvements on type hints
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Hsiaoming Yang <me@lepture.com>
  • Loading branch information
Viicos and lepture authored Jul 18, 2023
1 parent 3dc631b commit 77fecad
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 37 deletions.
9 changes: 6 additions & 3 deletions src/joserfc/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,18 +268,21 @@ def deserialize_json(

find_key = lambda d: guess_key(public_key, d)
if "signatures" in value:
general_obj = extract_general_json(value) # type: ignore
general_obj = extract_general_json(value)
if not verify_general_json(general_obj, registry, find_key):
raise BadSignatureError()
return general_obj
else:
flattened_obj = extract_flattened_json(value) # type: ignore
flattened_obj = extract_flattened_json(value)
if not verify_flattened_json(flattened_obj, registry, find_key):
raise BadSignatureError()
return flattened_obj


def detach_content(value: t.Union[str, t.Dict]):
DetachValue = t.TypeVar("DetachValue", str, t.Dict[str, t.Any])


def detach_content(value: DetachValue) -> DetachValue:
"""In some contexts, it is useful to integrity-protect content that is
not itself contained in a JWS. This method is an implementation of
https://www.rfc-editor.org/rfc/rfc7515#appendix-F
Expand Down
3 changes: 1 addition & 2 deletions src/joserfc/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@ def validate_registry_header(
registry: HeaderRegistryDict,
header: Header,
check_required: bool = True):
for key in registry:
reg = registry[key]
for key, reg in registry.items():
if check_required and reg.required and key not in header:
raise ValueError(f'Required "{key}" is missing in header')
if key in header:
Expand Down
3 changes: 2 additions & 1 deletion src/joserfc/rfc7515/compact.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import binascii
import typing as t
from .model import JWSAlgModel, CompactSignature
from ..errors import DecodeError, MissingAlgorithmError
from ..util import (
Expand Down Expand Up @@ -57,7 +58,7 @@ def detach_compact_content(value: str) -> str:
return ".".join(parts)


def decode_header(header_segment: bytes):
def decode_header(header_segment: bytes) -> t.Dict[str, t.Any]:
try:
protected = json_b64decode(header_segment)
if "alg" not in protected:
Expand Down
5 changes: 3 additions & 2 deletions src/joserfc/rfc7515/json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as t
import binascii
import copy
from .model import (
HeaderMember,
GeneralJSONSignature,
Expand Down Expand Up @@ -164,9 +165,9 @@ def verify_signature(
return alg.verify(signing_input, sig, key)


def detach_json_content(value):
def detach_json_content(value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
# https://www.rfc-editor.org/rfc/rfc7515#appendix-F
rv = value.copy() # don't alter original value
rv = copy.deepcopy(value) # don't alter original value
if "payload" in rv:
del rv["payload"]
return rv
9 changes: 5 additions & 4 deletions src/joserfc/rfc7515/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .types import SegmentsDict, JSONSignatureDict
from ..errors import InvalidKeyTypeError
from ..registry import Header
from ..rfc7517.models import BaseKey


class HeaderMember:
Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__(self, member: HeaderMember, payload: bytes):
self.segments: SegmentsDict = {}

@property
def members(self):
def members(self) -> t.List[HeaderMember]:
return [self.member]

def headers(self) -> Header:
Expand Down Expand Up @@ -79,12 +80,12 @@ class JWSAlgModel(object, metaclass=ABCMeta):
algorithm_type = "JWS"
algorithm_location = "sig"

def check_key_type(self, key):
def check_key_type(self, key: BaseKey):
if key.key_type != self.key_type:
raise InvalidKeyTypeError(f'Algorithm "{self.name}" requires "{self.key_type}" key')

@abstractmethod
def sign(self, msg: bytes, key) -> bytes:
def sign(self, msg: bytes, key: t.Any) -> bytes:
"""Sign the text msg with a private/sign key.
:param msg: message bytes to be signed
Expand All @@ -94,7 +95,7 @@ def sign(self, msg: bytes, key) -> bytes:
pass

@abstractmethod
def verify(self, msg: bytes, sig: bytes, key) -> bool:
def verify(self, msg: bytes, sig: bytes, key: t.Any) -> bool:
"""Verify the signature of text msg with a public/verify key.
:param msg: message bytes to be signed
Expand Down
2 changes: 1 addition & 1 deletion src/joserfc/rfc7515/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)


class JWSRegistry(object):
class JWSRegistry:
"""A registry for JSON Web Signature to keep all the supported algorithms.
An instance of ``JWSRegistry`` is usually used together with methods in
``joserfc.jws``.
Expand Down
2 changes: 2 additions & 0 deletions src/joserfc/rfc7515/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ class JSONSignatureDict(t.TypedDict, total=False):
signature: str


@t.final
class GeneralJSONSerialization(t.TypedDict):
payload: str
signatures: t.List[JSONSignatureDict]


@t.final
class FlattenedJSONSerialization(t.TypedDict, total=False):
payload: str
protected: str
Expand Down
6 changes: 3 additions & 3 deletions src/joserfc/rfc7517/keyset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing as t
import random
from .models import BaseKey as Key, SymmetricKey
from .types import KeySetDict, KeyParameters
from .types import KeyDict, KeySetDict, KeyParameters
from .registry import JWKRegistry


Expand All @@ -18,8 +18,8 @@ def __init__(self, keys: t.List[Key]):
def __iter__(self) -> t.Iterator[Key]:
return iter(self.keys)

def as_dict(self, private=None, **params):
keys = []
def as_dict(self, private: t.Optional[bool] = None, **params: t.Any) -> KeySetDict:
keys: t.List[KeyDict] = []

for key in self.keys:
# trigger key to generate kid via thumbprint
Expand Down
24 changes: 16 additions & 8 deletions src/joserfc/rfc7517/pem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
NoEncryption,
)
from cryptography.hazmat.backends import default_backend
from .models import NativeKeyBinding
from .models import NativeKeyBinding, BaseKey
from .types import KeyDict
from ..util import to_bytes

Expand All @@ -41,7 +41,11 @@ def load_pem_key(raw: bytes, ssh_type: t.Optional[bytes] = None, password: t.Opt
return key


def dump_pem_key(key, encoding=None, private=False, password=None) -> bytes:
def dump_pem_key(
key: t.Any,
encoding: t.Optional[t.Literal["PEM", "DER"]] = None,
private: t.Optional[bool] = False,
password: t.Optional[t.Any] = None) -> bytes:
"""Export key into PEM/DER format bytes.
:param key: native cryptography key
Expand All @@ -52,9 +56,9 @@ def dump_pem_key(key, encoding=None, private=False, password=None) -> bytes:
"""

if encoding is None or encoding == "PEM":
encoding = Encoding.PEM
encoding_enum = Encoding.PEM
elif encoding == "DER":
encoding = Encoding.DER
encoding_enum = Encoding.DER
else: # pragma: no cover
raise ValueError("Invalid encoding: {!r}".format(encoding))

Expand All @@ -65,12 +69,12 @@ def dump_pem_key(key, encoding=None, private=False, password=None) -> bytes:
else:
encryption_algorithm = BestAvailableEncryption(to_bytes(password))
return key.private_bytes(
encoding=encoding,
encoding=encoding_enum,
format=PrivateFormat.PKCS8,
encryption_algorithm=encryption_algorithm,
)
return key.public_bytes(
encoding=encoding,
encoding=encoding_enum,
format=PublicFormat.SubjectPublicKeyInfo,
)

Expand All @@ -91,13 +95,17 @@ def import_from_dict(cls, value: KeyDict):
return cls.import_public_key(value)

@classmethod
def import_from_bytes(cls, value: bytes, password=None):
def import_from_bytes(cls, value: bytes, password: t.Optional[t.Any] = None):
if password is not None:
password = to_bytes(password)
return load_pem_key(value, cls.ssh_type, password)

@staticmethod
def as_bytes(key, encoding=None, private=None, password=None) -> bytes:
def as_bytes(
key: BaseKey,
encoding: t.Optional[t.Literal["PEM", "DER"]] = None,
private: t.Optional[bool] = False,
password: t.Optional[t.Any] = None) -> bytes:
if private is True:
return dump_pem_key(key.private_key, encoding, private, password)
elif private is False:
Expand Down
17 changes: 9 additions & 8 deletions src/joserfc/rfc7518/jws_algs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
.. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3
"""

import typing as t
import hmac
import hashlib
import typing as t
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.utils import (
decode_dss_signature,
Expand All @@ -21,6 +21,7 @@
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.exceptions import InvalidSignature
from ..rfc7515.model import JWSAlgModel
from ..rfc7517.models import BaseKey
from .oct_key import OctKey
from .rsa_key import RSAKey
from .ec_key import ECKey
Expand All @@ -31,10 +32,10 @@ class NoneAlgModel(JWSAlgModel):
name = "none"
description = "No digital signature or MAC performed"

def sign(self, msg, key):
def sign(self, msg: bytes, key: BaseKey):
return b""

def verify(self, msg, sig, key):
def verify(self, msg: bytes, sig: bytes, key: BaseKey):
return False


Expand All @@ -50,7 +51,7 @@ class HMACAlgModel(JWSAlgModel):
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512

def __init__(self, sha_type: int, recommended: bool = False):
def __init__(self, sha_type: t.Literal[256, 384, 512], recommended: bool = False):
self.name = f"HS{sha_type}"
self.description = f"HMAC using SHA-{sha_type}"
self.recommended = recommended
Expand Down Expand Up @@ -81,7 +82,7 @@ class RSAAlgModel(JWSAlgModel):
SHA512 = hashes.SHA512
padding = padding.PKCS1v15()

def __init__(self, sha_type: int, recommended: bool = False):
def __init__(self, sha_type: t.Literal[256, 384, 512], recommended: bool = False):
self.name = f"RS{sha_type}"
self.description = f"RSASSA-PKCS1-v1_5 using SHA-{sha_type}"
self.recommended = recommended
Expand Down Expand Up @@ -113,14 +114,14 @@ class ECAlgModel(JWSAlgModel):
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512

def __init__(self, name: str, curve: str, sha_type: int, recommended: bool = False):
def __init__(self, name: str, curve: str, sha_type: t.Literal[256, 384, 512], recommended: bool = False):
self.name = name
self.curve = curve
self.description = f"ECDSA using {self.curve} and SHA-{sha_type}"
self.recommended = recommended
self.hash_alg = getattr(self, f"SHA{sha_type}")

def _check_key(self, key: ECKey):
def _check_key(self, key: ECKey) -> ECKey:
if key.curve_name != self.curve:
raise ValueError(f'Key for "{self.name}" not supported, only "{self.curve}" allowed')
return key
Expand Down Expand Up @@ -166,7 +167,7 @@ class RSAPSSAlgModel(JWSAlgModel):
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512

def __init__(self, sha_type: int):
def __init__(self, sha_type: t.Literal[256, 384, 512]):
self.name = f"PS{sha_type}"
self.description = f"RSASSA-PSS using SHA-{sha_type} and MGF1 with SHA-{sha_type}"
self.hash_alg = getattr(self, f"SHA{sha_type}")
Expand Down
11 changes: 6 additions & 5 deletions src/joserfc/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import typing as t
import base64
import struct
import json


def to_bytes(x, charset="utf-8", errors="strict") -> bytes:
def to_bytes(x: t.Any, charset: str = "utf-8", errors: str = "strict") -> bytes:
if isinstance(x, bytes):
return x
if isinstance(x, str):
Expand All @@ -13,13 +14,13 @@ def to_bytes(x, charset="utf-8", errors="strict") -> bytes:
return bytes(x)


def to_unicode(x, charset="utf-8") -> str:
def to_unicode(x, charset: str = "utf-8") -> str:
if isinstance(x, bytes):
return x.decode(charset)
return x


def json_dumps(data: dict, ensure_ascii=False):
def json_dumps(data: t.Any, ensure_ascii: bool = False) -> str:
return json.dumps(data, ensure_ascii=ensure_ascii, separators=(",", ":"))


Expand All @@ -46,11 +47,11 @@ def int_to_base64(num: int) -> str:
return urlsafe_b64encode(s).decode("utf-8", "strict")


def json_b64encode(text, charset="utf-8") -> bytes:
def json_b64encode(text, charset: str = "utf-8") -> bytes:
if isinstance(text, dict):
text = json_dumps(text)
return urlsafe_b64encode(to_bytes(text, charset))


def json_b64decode(text, charset="utf-8"):
def json_b64decode(text, charset="utf-8") -> t.Dict[str, t.Any]:
return json.loads(urlsafe_b64decode(to_bytes(text, charset)))

0 comments on commit 77fecad

Please sign in to comment.