Skip to content

Commit

Permalink
fix(jwk): re-define type Key
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Jul 17, 2023
1 parent e1be85d commit cd78f6c
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 30 deletions.
10 changes: 7 additions & 3 deletions src/joserfc/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
SymmetricKey,
AsymmetricKey,
CurveKey,
Key,
JWKRegistry,
KeySet,
)
Expand All @@ -16,7 +15,8 @@
from .registry import Header


KeyCallable = t.Callable[[t.Any, bool], Key]
Key = t.Union[OctKey, RSAKey, ECKey, OKPKey]
KeyCallable = t.Callable[[t.Any], Key]
KeyFlexible = t.Union[t.AnyStr, Key, KeySet, KeyCallable]

__all__ = [
Expand Down Expand Up @@ -66,6 +66,8 @@ def guess_key(key: KeyFlexible, obj: GuestProtocol) -> Key:
"""
headers = obj.headers()

rv_key: Key

if isinstance(key, (str, bytes)):
rv_key = OctKey.import_key(key)

Expand All @@ -76,7 +78,9 @@ def guess_key(key: KeyFlexible, obj: GuestProtocol) -> Key:
kid = headers.get("kid")
if not kid:
# choose one key by random
rv_key: Key = key.pick_random_key(headers["alg"])
rv_key = key.pick_random_key(headers["alg"])
if rv_key is None:
raise ValueError("Invalid key")
# use side effect to add kid information
obj.set_kid(rv_key.kid)
else:
Expand Down
12 changes: 6 additions & 6 deletions src/joserfc/rfc7516/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import typing as t
from abc import ABCMeta, abstractmethod
from ..rfc7517.models import Key, CurveKey
from ..rfc7517.models import BaseKey, CurveKey
from ..registry import Header, HeaderRegistryDict
from ..errors import InvalidKeyTypeError, InvalidKeyLengthError

Expand All @@ -11,11 +11,11 @@ def __init__(
self,
parent: t.Union["CompactEncryption", "JSONEncryption"],
header: t.Optional[Header] = None,
recipient_key: t.Optional[Key] = None):
recipient_key: t.Optional[BaseKey] = None):
self.__parent = parent
self.header = header
self.recipient_key = recipient_key
self.sender_key: t.Optional[Key] = None
self.sender_key: t.Optional[BaseKey] = None
self.encrypted_key: t.Optional[bytes] = None
self.ephemeral_key: t.Optional[CurveKey] = None
self.segments = {} # store temporary segments
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(self, protected: Header, plaintext: t.Optional[bytes] = None):
def headers(self):
return self.protected

def attach_recipient(self, key: Key, header: t.Optional[Header] = None):
def attach_recipient(self, key: BaseKey, header: t.Optional[Header] = None):
"""Add a recipient to the JWE Compact Serialization. Please add a key that
comply with the given "alg" value.
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(
self.bytes_segments: t.Dict[str, bytes] = {} # store the decoded segments
self.base64_segments: t.Dict[str, bytes] = {} # store the encoded segments

def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[BaseKey] = None):
"""Add a recipient to the JWE JSON Serialization. Please add a key that
comply with the "alg" to this recipient.
Expand Down Expand Up @@ -183,7 +183,7 @@ class KeyManagement:
def direct_mode(self) -> bool:
return self.key_size is None

def check_key_type(self, key: Key):
def check_key_type(self, key: BaseKey):
if key.key_type not in self.key_types:
raise InvalidKeyTypeError()

Expand Down
3 changes: 1 addition & 2 deletions src/joserfc/rfc7517/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from .models import SymmetricKey, AsymmetricKey, CurveKey, Key
from .models import SymmetricKey, AsymmetricKey, CurveKey
from .registry import JWKRegistry
from .keyset import KeySet

__all__ = [
"SymmetricKey",
"AsymmetricKey",
"CurveKey",
"Key",
"JWKRegistry",
"KeySet",
]
8 changes: 4 additions & 4 deletions src/joserfc/rfc7517/keyset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import typing as t
import random
from .models import Key, SymmetricKey
from .models import BaseKey, SymmetricKey
from .types import KeySetDict, KeyParameters
from .registry import JWKRegistry


class KeySet:
def __init__(self, keys: t.List[Key]):
def __init__(self, keys: t.List[BaseKey]):
self.keys = keys

def as_dict(self, private=None, **params):
Expand All @@ -21,7 +21,7 @@ def as_dict(self, private=None, **params):
keys.append(key.as_dict(private=private, **params))
return {"keys": keys}

def get_by_kid(self, kid: t.Optional[str] = None) -> Key:
def get_by_kid(self, kid: t.Optional[str] = None) -> BaseKey:
if kid is None and len(self.keys) == 1:
return self.keys[0]

Expand All @@ -30,7 +30,7 @@ def get_by_kid(self, kid: t.Optional[str] = None) -> Key:
return key
raise ValueError(f'No key for kid: "{kid}"')

def pick_random_key(self, algorithm: str) -> t.Optional[Key]:
def pick_random_key(self, algorithm: str) -> t.Optional[BaseKey]:
key_types = JWKRegistry.algorithm_key_types.get(algorithm)
if key_types:
keys = [k for k in self.keys if k.key_type in key_types]
Expand Down
8 changes: 2 additions & 6 deletions src/joserfc/rfc7517/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,15 @@ def __init__(self, raw_value: t.Any, original_value: t.Any, parameters: t.Option
self._raw_value = raw_value
self.original_value = original_value
self.extra_parameters = parameters
self._dict_value: KeyDict = {}
if isinstance(original_value, dict):
data = original_value.copy()
data["kty"] = self.key_type
if parameters:
data.update(dict(parameters))

self.validate_dict_key(data)
self._dict_value: t.Optional[KeyDict] = data
else:
self._dict_value = None
self._dict_value = data

def keys(self):
return self.dict_value.keys()
Expand Down Expand Up @@ -288,6 +287,3 @@ def curve_name(self) -> str:
@abstractmethod
def exchange_derive_key(self, key: "CurveKey") -> bytes:
pass


Key = t.Union[SymmetricKey, AsymmetricKey, CurveKey]
23 changes: 14 additions & 9 deletions src/joserfc/rfc7517/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing as t
from .models import Key
from .models import BaseKey
from .types import KeyAny, KeyParameters
from ..util import to_bytes

Expand All @@ -17,15 +17,19 @@ class JWKRegistry:
data = {"kty": "oct", "k": "..."}
key = JWKRegistry.import_key(data)
"""
key_types: t.Dict[str, t.Type[Key]] = {}
key_types: t.Dict[str, t.Type[BaseKey]] = {}
algorithm_key_types: t.Dict[str, t.List[str]] = {}

@classmethod
def register(cls, model: t.Type[Key]):
def register(cls, model: t.Type[BaseKey]):
cls.key_types[model.key_type] = model

@classmethod
def import_key(cls, data: KeyAny, key_type: t.Optional[str] = None, parameters: KeyParameters = None) -> Key:
def import_key(
cls,
data: KeyAny,
key_type: t.Optional[str] = None,
parameters: t.Optional[KeyParameters] = None) -> BaseKey:
"""A class method for importing a key from bytes, string, and dict.
When ``value`` is a dict, this method can tell the key type automatically,
otherwise, developers SHOULD pass the ``key_type`` themselves.
Expand All @@ -36,9 +40,10 @@ def import_key(cls, data: KeyAny, key_type: t.Optional[str] = None, parameters:
:return: OctKey, RSAKey, ECKey, or OKPKey
"""
if isinstance(data, dict) and key_type is None:
if "kty" not in data:
if "kty" in data:
key_type = data["kty"]
else:
raise ValueError("Missing key type")
key_type = data["kty"]

if key_type not in cls.key_types:
raise ValueError(f'Invalid key type: "{key_type}"')
Expand All @@ -54,8 +59,8 @@ def generate_key(
cls,
key_type: str,
crv_or_size: t.Union[str, int],
parameters: KeyParameters = None,
private: bool = True) -> Key:
parameters: t.Optional[KeyParameters] = None,
private: bool = True) -> BaseKey:
"""A class method for generating key according to the given key type.
When ``key_type`` is "oct" and "RSA", the second parameter SHOULD be
a key size in bits. When ``key_type`` is "EC" and "OKP", the second
Expand All @@ -70,4 +75,4 @@ def generate_key(
raise ValueError(f'Invalid key type: "{key_type}"')

key_cls = cls.key_types[key_type]
return key_cls.generate_key(crv_or_size, parameters, private) # type: ignore
return key_cls.generate_key(crv_or_size, parameters, private)

0 comments on commit cd78f6c

Please sign in to comment.