Skip to content

Commit

Permalink
feat: callable key can return str, bytes, Key, and KeySet
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Oct 7, 2023
1 parent 71be09b commit 43be0a2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
38 changes: 21 additions & 17 deletions src/joserfc/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def set_kid(self, kid: str) -> None:
...


KeyCallable = t.Callable[[GuestProtocol], Key]
KeyFlexible = t.Union[str, bytes, Key, KeySet, KeyCallable]
KeyBase = t.Union[str, bytes, Key, KeySet]
KeyCallable = t.Callable[[GuestProtocol], KeyBase]
KeyFlexible = t.Union[KeyBase, KeyCallable]


def guess_key(key: KeyFlexible, obj: GuestProtocol, use_random: bool = False) -> Key:
Expand All @@ -47,32 +48,35 @@ def guess_key(key: KeyFlexible, obj: GuestProtocol, use_random: bool = False) ->
:param obj: a protocol that has ``headers`` and ``set_kid`` methods
:param use_random: pick a random key from key set
"""
headers = obj.headers()

rv_key: Key
if isinstance(key, (str, bytes)):
rv_key = OctKey.import_key(key)

elif isinstance(key, (OctKey, RSAKey, ECKey, OKPKey)):
rv_key = key
_norm_key: t.Union[Key, KeySet]
if callable(key):
_norm_key = _normalize_key(key(obj))
else:
_norm_key = _normalize_key(key)

elif isinstance(key, KeySet):
rv_key: Key
if isinstance(_norm_key, KeySet):
headers = obj.headers()
kid = headers.get("kid")
if not kid and use_random:
# choose one key by random
rv_key = key.pick_random_key(headers["alg"]) # type: ignore[assignment]
rv_key = _norm_key.pick_random_key(headers["alg"]) # type: ignore[assignment]
if rv_key is None:
raise ValueError("Invalid key")
rv_key.ensure_kid()
assert rv_key.kid is not None # for mypy
obj.set_kid(rv_key.kid)
else:
rv_key = key.get_by_kid(kid)

elif callable(key):
rv_key = key(obj)

rv_key = _norm_key.get_by_kid(kid)
elif isinstance(_norm_key, (OctKey, RSAKey, ECKey, OKPKey)):
rv_key = _norm_key
else:
raise ValueError("Invalid key")

return rv_key


def _normalize_key(key: KeyBase) -> t.Union[Key, KeySet]:
if isinstance(key, (str, bytes)):
return OctKey.import_key(key)
return key
23 changes: 20 additions & 3 deletions tests/jwk/test_key_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,29 @@ def test_guess_bytes_key(self):
self.assertIsInstance(key, OctKey)

def test_guess_callable_key(self):
def key_func(obj):
return OctKey.import_key("key")
oct_key = OctKey.generate_key(parameters={'kid': '1'})
rsa_key = RSAKey.generate_key(parameters={'kid': '2'})

key = guess_key(key_func, Guest())
def key_func1(obj):
return "key"

def key_func2(obj):
return rsa_key

def key_func3(obj):
return KeySet([oct_key, rsa_key])

key = guess_key(key_func1, Guest())
self.assertIsInstance(key, OctKey)

key = guess_key(key_func2, Guest())
self.assertIsInstance(key, RSAKey)

guest = Guest()
guest.set_kid('2')
key = guess_key(key_func3, guest)
self.assertIsInstance(key, RSAKey)

def test_guess_key_set(self):
key_set = KeySet([OctKey.generate_key(), RSAKey.generate_key()])
guest = Guest()
Expand Down

0 comments on commit 43be0a2

Please sign in to comment.