diff --git a/src/joserfc/jws.py b/src/joserfc/jws.py index ee42a7e..a934ed8 100644 --- a/src/joserfc/jws.py +++ b/src/joserfc/jws.py @@ -1,10 +1,12 @@ import typing as t +from typing import overload from .rfc7515 import types from .rfc7515.model import ( JWSAlgModel, HeaderMember, CompactSignature, - JSONSignature, + GeneralJSONSignature, + FlattenedJSONSignature, ) from .rfc7515.registry import ( JWSRegistry, @@ -17,15 +19,18 @@ detach_compact_content, ) from .rfc7515.json import ( - construct_json_signature, - sign_json, - verify_json, - extract_json, + sign_general_json, + sign_flattened_json, + verify_general_json, + verify_flattened_json, + extract_general_json, + extract_flattened_json, detach_json_content, ) from .rfc7515.types import ( HeaderDict, - JSONSerialization, + GeneralJSONSerialization, + FlattenedJSONSerialization, ) from .rfc7518.jws_algs import JWS_ALGORITHMS from .rfc8037.jws_eddsa import EdDSA @@ -39,9 +44,13 @@ "types", "JWSAlgModel", "JWSRegistry", + "HeaderDict", "HeaderMember", "CompactSignature", - "JSONSignature", + "GeneralJSONSignature", + "FlattenedJSONSignature", + "GeneralJSONSerialization", + "FlattenedJSONSerialization", "serialize_compact", "deserialize_compact", "extract_compact", @@ -166,12 +175,30 @@ def deserialize_compact( return obj +@overload +def serialize_json( + members: t.List[HeaderDict], + payload: t.AnyStr, + private_key: KeyFlexible, + algorithms: t.Optional[t.List[str]] = None, + registry: t.Optional[JWSRegistry] = None) -> GeneralJSONSignature: ... + + +@overload +def serialize_json( + members: HeaderDict, + payload: t.AnyStr, + private_key: KeyFlexible, + algorithms: t.Optional[t.List[str]] = None, + registry: t.Optional[JWSRegistry] = None) -> FlattenedJSONSignature: ... + + def serialize_json( members: t.Union[HeaderDict, t.List[HeaderDict]], payload: t.AnyStr, private_key: KeyFlexible, algorithms: t.Optional[t.List[str]] = None, - registry: t.Optional[JWSRegistry] = None) -> JSONSerialization: + registry: t.Optional[JWSRegistry] = None): """Generate a JWS JSON Serialization (in dict). The JWS JSON Serialization represents digitally signed or MACed content as a JSON object. This representation is neither optimized for compactness nor URL-safe. @@ -201,14 +228,16 @@ def serialize_json( if registry is None: registry = construct_registry(algorithms) - obj = construct_json_signature(members, payload, registry) - obj.segments["payload"] = urlsafe_b64encode(obj.payload) find_key = lambda d: guess_key(private_key, d) - return sign_json(obj, registry, find_key) + _payload = to_bytes(payload) + if isinstance(members, list): + return sign_general_json(members, _payload, registry, find_key) + else: + return sign_flattened_json(members, _payload, registry, find_key) def validate_json( - obj: JSONSignature, + obj: t.Union[GeneralJSONSignature, FlattenedJSONSignature], public_key: KeyFlexible, algorithms: t.Optional[t.List[str]] = None, registry: t.Optional[JWSRegistry] = None): @@ -223,16 +252,22 @@ def validate_json( """ if registry is None: registry = construct_registry(algorithms) + find_key = lambda d: guess_key(public_key, d) - if not verify_json(obj, registry, find_key): + valid: bool + if isinstance(obj, GeneralJSONSignature): + valid = verify_general_json(obj, registry, find_key) + else: + valid = verify_flattened_json(obj, registry, find_key) + if not valid: raise BadSignatureError() def deserialize_json( - value: JSONSerialization, + value: t.Union[GeneralJSONSerialization, FlattenedJSONSerialization], public_key: KeyFlexible, algorithms: t.Optional[t.List[str]] = None, - registry: t.Optional[JWSRegistry] = None) -> JSONSignature: + registry: t.Optional[JWSRegistry] = None) -> t.Union[GeneralJSONSignature, FlattenedJSONSignature]: """Extract and validate the JWS (in string) with the given key. :param value: a dict of the JSON signature @@ -241,12 +276,15 @@ def deserialize_json( :param registry: a JWSRegistry to use :return: object of the SignatureData """ - obj = extract_json(value) + if "signatures" in value: + obj = extract_general_json(value) + else: + obj = extract_flattened_json(value) validate_json(obj, public_key, algorithms, registry) return obj -def detach_content(value: t.Union[str, JSONSerialization]): +def detach_content(value: t.Union[str, t.Dict]): """In some contexts, it is useful to integrity-protect content that is not itself contained in a JWS. This method is an implementation of https://www.rfc-editor.org/rfc/rfc7515#appendix-F diff --git a/src/joserfc/rfc7515/json.py b/src/joserfc/rfc7515/json.py index 8c85003..a398f9c 100644 --- a/src/joserfc/rfc7515/json.py +++ b/src/joserfc/rfc7515/json.py @@ -1,78 +1,67 @@ import typing as t import binascii -from .model import JWSAlgModel, HeaderMember, JSONSignature +from .model import ( + HeaderMember, + GeneralJSONSignature, + FlattenedJSONSignature, +) from .types import ( HeaderDict, JSONSignatureDict, - JSONSerialization, GeneralJSONSerialization, FlattenedJSONSerialization, ) from .registry import JWSRegistry from ..util import ( - to_bytes, json_b64encode, json_b64decode, urlsafe_b64encode, urlsafe_b64decode, ) from ..errors import DecodeError +from ..rfc7517.models import BaseKey -def construct_json_signature( - members: t.Union[HeaderDict, t.List[HeaderDict]], - payload: t.AnyStr, - registry: JWSRegistry) -> JSONSignature: - if isinstance(members, dict): - flattened = True - __check_member(registry, members) - members = [members] - else: - flattened = False - for member in members: - __check_member(registry, member) - - members = [HeaderMember(**member) for member in members] - payload = to_bytes(payload) - obj = JSONSignature(members, payload) - obj.flattened = flattened - return obj - - -def __check_member(registry: JWSRegistry, member: HeaderDict): - header = {} - if "protected" in member: - header.update(member["protected"]) - if "header" in member: - header.update(member["header"]) - registry.check_header(header) +FindKey = t.Callable[[HeaderMember], BaseKey] -def sign_json(obj: JSONSignature, registry: JWSRegistry, find_key) -> JSONSerialization: - signatures: t.List[JSONSignatureDict] = [] - - payload_segment = obj.segments["payload"] - for member in obj.members: - headers = member.headers() - registry.check_header(headers) - alg = registry.get_alg(headers["alg"]) - key = find_key(member) - key.check_use("sig") - alg.check_key_type(key) - signature = __sign_member(payload_segment, member, alg, key) - signatures.append(signature) - - rv = {"payload": payload_segment.decode("utf-8")} - if obj.flattened and len(signatures) == 1: - rv.update(dict(signatures[0])) - else: - rv["signatures"] = signatures +def sign_general_json( + members: t.List[HeaderDict], + payload: t.AnyStr, + registry: JWSRegistry, + find_key: FindKey) -> GeneralJSONSerialization: - obj.signatures = signatures - return rv + payload_segment = urlsafe_b64encode(payload) + signatures: t.List[JSONSignatureDict] = [ + __sign_member(payload_segment, HeaderMember(**member), registry, find_key) + for member in members + ] + payload = payload_segment.decode("utf-8") + return {"payload": payload, "signatures": signatures} -def __sign_member(payload_segment, member: HeaderMember, alg: JWSAlgModel, key) -> JSONSignatureDict: +def sign_flattened_json( + member: HeaderDict, + payload: t.AnyStr, + registry: JWSRegistry, + find_key: FindKey) -> FlattenedJSONSerialization: + payload_segment = urlsafe_b64encode(payload) + signature = __sign_member(payload_segment, HeaderMember(**member), registry, find_key) + payload = payload_segment.decode("utf-8") + return {"payload": payload, **signature} + + +def __sign_member( + payload_segment: bytes, + member: HeaderMember, + registry: JWSRegistry, + find_key: FindKey) -> JSONSignatureDict: + headers = member.headers() + registry.check_header(headers) + alg = registry.get_alg(headers["alg"]) + key = find_key(member) + key.check_use("sig") + alg.check_key_type(key) if member.protected: protected_segment = json_b64encode(member.protected) else: @@ -87,73 +76,83 @@ def __sign_member(payload_segment, member: HeaderMember, alg: JWSAlgModel, key) return rv -def extract_json(value: JSONSerialization) -> JSONSignature: - """Extract the JWS JSON Serialization from dict to object. - - :param value: JWS in dict - """ +def extract_general_json(value: GeneralJSONSerialization) -> GeneralJSONSignature: payload_segment: bytes = value["payload"].encode("utf-8") - try: payload = urlsafe_b64decode(payload_segment) except (TypeError, ValueError, binascii.Error): raise DecodeError("Invalid payload") - if "signatures" in value: - flattened = False - value: GeneralJSONSerialization - signatures: t.List[JSONSignatureDict] = value["signatures"] - else: - flattened = True - value: FlattenedJSONSerialization - _sig: JSONSignatureDict = { - "protected": value["protected"], - "signature": value["signature"], - } - if "header" in value: - _sig["header"] = value["header"] - signatures = [_sig] - - members = [] - for sig in signatures: - member = HeaderMember() - if "protected" in sig: - protected_segment = sig["protected"] - member.protected = json_b64decode(protected_segment) - if "header" in sig: - member.header = sig["header"] - members.append(member) - - obj = JSONSignature(members, payload) - obj.segments.update({"payload": payload_segment}) - obj.flattened = flattened + signatures: t.List[JSONSignatureDict] = value["signatures"] + members = [__signature_to_member(sig) for sig in signatures] + obj = GeneralJSONSignature(members, payload) obj.signatures = signatures + obj.segments.update({"payload": payload_segment}) + return obj + + +def extract_flattened_json(value: FlattenedJSONSerialization) -> FlattenedJSONSignature: + payload_segment: bytes = value["payload"].encode("utf-8") + try: + payload = urlsafe_b64decode(payload_segment) + except (TypeError, ValueError, binascii.Error): + raise DecodeError("Invalid payload") + _sig: JSONSignatureDict = { + "protected": value["protected"], + "signature": value["signature"], + } + if "header" in value: + _sig["header"] = value["header"] + + member = __signature_to_member(_sig) + obj = FlattenedJSONSignature(member, payload) + obj.signature = _sig + obj.segments.update({"payload": payload_segment}) return obj -def verify_json(obj: JSONSignature, registry: JWSRegistry, find_key) -> bool: - """Verify the signature of this JSON serialization with the given - algorithm and key. +def __signature_to_member(sig: JSONSignatureDict) -> HeaderMember: + member = HeaderMember() + if "protected" in sig: + protected_segment = sig["protected"] + member.protected = json_b64decode(protected_segment) + if "header" in sig: + member.header = sig["header"] + return member + - :param obj: instance of the SignatureData - :param find_alg: a function to return "alg" model - :param find_key: a function to return public key - """ +def verify_general_json( + obj: GeneralJSONSignature, + registry: JWSRegistry, + find_key: FindKey) -> bool: payload_segment = obj.segments["payload"] for index, signature in enumerate(obj.signatures): member = obj.members[index] - headers = member.headers() - registry.check_header(headers) - alg = registry.get_alg(headers["alg"]) - key = find_key(member) - key.check_use("sig") - alg.check_key_type(key) - if not _verify_signature(signature, payload_segment, alg, key): + if not __verify_signature(member, signature, payload_segment, registry, find_key): return False return True -def _verify_signature(signature: JSONSignatureDict, payload_segment, alg: JWSAlgModel, key) -> bool: +def verify_flattened_json( + obj: FlattenedJSONSignature, + registry: JWSRegistry, + find_key: FindKey) -> bool: + payload_segment = obj.segments["payload"] + return __verify_signature(obj.member, obj.signature, payload_segment, registry, find_key) + + +def __verify_signature( + member: HeaderMember, + signature: JSONSignatureDict, + payload_segment: bytes, + registry: JWSRegistry, + find_key: FindKey) -> bool: + headers = member.headers() + registry.check_header(headers) + alg = registry.get_alg(headers["alg"]) + key = find_key(member) + key.check_use("sig") + alg.check_key_type(key) if "protected" in signature: protected_segment = signature["protected"].encode("utf-8") else: @@ -163,7 +162,7 @@ def _verify_signature(signature: JSONSignatureDict, payload_segment, alg: JWSAlg return alg.verify(signing_input, sig, key) -def detach_json_content(value: JSONSerialization) -> JSONSerialization: +def detach_json_content(value): # https://www.rfc-editor.org/rfc/rfc7515#appendix-F rv = value.copy() # don't alter original value if "payload" in rv: diff --git a/src/joserfc/rfc7515/model.py b/src/joserfc/rfc7515/model.py index 4da1a35..b189f53 100644 --- a/src/joserfc/rfc7515/model.py +++ b/src/joserfc/rfc7515/model.py @@ -40,22 +40,32 @@ def set_kid(self, kid: str): self.protected["kid"] = kid -class JSONSignature: - """JSON Web Signature object for JSON mode. This object is used to - represent the JWS instance. - """ +class FlattenedJSONSignature: + flattened = True + + def __init__(self, member: HeaderMember, payload: bytes): + self.member = member + self.payload = payload + self.signature: t.Optional[JSONSignatureDict] = None + self.segments: SegmentsDict = {} + + @property + def members(self): + return [self.member] + + def headers(self) -> Header: + return self.member.headers() + + +class GeneralJSONSignature: + flattened = False + def __init__(self, members: t.List[HeaderMember], payload: bytes): self.members = members self.payload = payload self.signatures: t.List[JSONSignatureDict] = [] - self.flattened: bool = False self.segments: SegmentsDict = {} - def headers(self) -> Header: - if self.flattened and len(self.members) == 1: - return self.members[0].headers() - raise ValueError("Only flattened JSON Serialization has .headers() method.") - class JWSAlgModel(object, metaclass=ABCMeta): """Interface for JWS algorithm. JWA specification (RFC7518) SHOULD diff --git a/src/joserfc/rfc7515/types.py b/src/joserfc/rfc7515/types.py index 353d508..5f08f0e 100644 --- a/src/joserfc/rfc7515/types.py +++ b/src/joserfc/rfc7515/types.py @@ -10,35 +10,31 @@ "JSONSerialization", ] -SegmentsDict = t.TypedDict("SegmentsDict", { - "header": bytes, - "payload": bytes, - "signature": bytes, -}, total=False) - -HeaderDict = t.TypedDict("HeaderDict", { - "protected": Header, - "header": Header, -}, total=False) - - -JSONSignatureDict = t.TypedDict("JSONSignatureDict", { - "protected": str, - "header": Header, - "signature": str, -}, total=False) - - -GeneralJSONSerialization = t.TypedDict("GeneralJSONSerialization", { - "payload": str, - "signatures": t.List[JSONSignatureDict], -}) - -FlattenedJSONSerialization = t.TypedDict("FlattenedJSONSerialization", { - "payload": str, - "protected": str, - "header": Header, - "signature": str, -}, total=False) - -JSONSerialization = t.Union[GeneralJSONSerialization, FlattenedJSONSerialization] + +class SegmentsDict(t.TypedDict, total=False): + header: bytes + payload: bytes + signature: bytes + + +class HeaderDict(t.TypedDict, total=False): + protected: Header + header: Header + + +class JSONSignatureDict(t.TypedDict, total=False): + protected: str + header: Header + signature: str + + +class GeneralJSONSerialization(t.TypedDict): + payload: str + signatures: t.List[JSONSignatureDict] + + +class FlattenedJSONSerialization(t.TypedDict, total=False): + payload: str + protected: str + header: Header + signature: str diff --git a/src/joserfc/rfc7797/json.py b/src/joserfc/rfc7797/json.py index 1f68b64..3c0f664 100644 --- a/src/joserfc/rfc7797/json.py +++ b/src/joserfc/rfc7797/json.py @@ -2,8 +2,8 @@ from ..jws import ( HeaderDict, HeaderMember, - JSONSignature, - JSONSerialization, + FlattenedJSONSignature, + FlattenedJSONSerialization, JWSRegistry as _JWSRegistry, serialize_json as _serialize_json, deserialize_json as _deserialize_json, @@ -26,7 +26,7 @@ def serialize_json( payload: t.AnyStr, private_key: KeyFlexible, algorithms: t.Optional[t.List[str]] = None, - registry: t.Optional[_JWSRegistry] = None) -> JSONSerialization: + registry: t.Optional[_JWSRegistry] = None) -> FlattenedJSONSerialization: _member = HeaderMember(**member) headers = _member.headers() @@ -65,10 +65,10 @@ def serialize_json( def deserialize_json( - value: JSONSerialization, + value: FlattenedJSONSerialization, public_key: KeyFlexible, algorithms: t.Optional[t.List[str]] = None, - registry: t.Optional[_JWSRegistry] = None) -> JSONSignature: + registry: t.Optional[_JWSRegistry] = None) -> FlattenedJSONSignature: obj = _extract_json(value) if obj is None: return _deserialize_json(value, public_key, algorithms, registry) @@ -76,13 +76,12 @@ def deserialize_json( if registry is None: registry = JWSRegistry(algorithms=algorithms) - member = obj.members[0] - headers = member.headers() + headers = obj.headers() if headers["b64"] is True: return _deserialize_json(value, public_key, registry=registry) registry.check_header(headers) - key = guess_key(public_key, member) + key = guess_key(public_key, obj.member) alg = registry.get_alg(headers["alg"]) signing_input = b".".join([obj.segments["header"], obj.payload]) sig = urlsafe_b64decode(obj.segments["signature"]) @@ -91,7 +90,7 @@ def deserialize_json( return obj -def _extract_json(value: JSONSerialization): +def _extract_json(value: FlattenedJSONSignature): if "signatures" in value: return None @@ -109,10 +108,9 @@ def _extract_json(value: JSONSerialization): return None payload = to_bytes(value["payload"]) - obj = JSONSignature([member], payload) + obj = FlattenedJSONSignature(member, payload) obj.segments = { "header": protected_segment, "signature": to_bytes(value["signature"]), } - obj.flattened = True return obj diff --git a/tests/jws/test_json.py b/tests/jws/test_json.py index 8e7c6ad..98bc302 100644 --- a/tests/jws/test_json.py +++ b/tests/jws/test_json.py @@ -28,7 +28,6 @@ def test_serialize_json(self): obj = deserialize_json(value, key) self.assertEqual(obj.payload, b'hello') - self.assertRaises(ValueError, obj.headers) def test_serialize_with_unprotected_header(self): key: RSAKey = load_key('rsa-openssl-private.pem')