Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make mypy configuration stricter and improve typing #830

Merged
merged 14 commits into from
Dec 10, 2022
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sphinx_rtd_theme


def read(*parts):
def read(*parts) -> str:
"""
Build an absolute path from *parts* and and return the contents of the
resulting file. Assume UTF-8 encoding.
Expand All @@ -14,7 +14,7 @@ def read(*parts):
return f.read()


def find_version(*file_paths):
def find_version(*file_paths) -> str:
"""
Build a path from *file_paths* and search for a ``__version__``
string inside.
Expand Down
99 changes: 57 additions & 42 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import hashlib
import hmac
import json
from typing import Any, Dict, Union

from .exceptions import InvalidKeyError
from .types import JWKDict
from .utils import (
base64url_decode,
base64url_encode,
Expand All @@ -20,10 +22,18 @@
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.ec import (
ECDSA,
SECP256K1,
SECP256R1,
SECP384R1,
SECP521R1,
EllipticCurve,
EllipticCurvePrivateKey,
EllipticCurvePrivateNumbers,
EllipticCurvePublicKey,
EllipticCurvePublicNumbers,
)
from cryptography.hazmat.primitives.asymmetric.ed448 import (
Ed448PrivateKey,
Expand Down Expand Up @@ -73,7 +83,7 @@
}


def get_default_algorithms():
def get_default_algorithms() -> Dict[str, "Algorithm"]:
"""
Returns the algorithms that are implemented by the library.
"""
Expand Down Expand Up @@ -130,40 +140,44 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return digest.finalize()
return bytes(digest.finalize())
else:
return hash_alg(bytestr).digest()
return bytes(hash_alg(bytestr).digest())

def prepare_key(self, key):
# TODO: all key-related `Any`s in this class should optimally be made
# variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
# that may still be poorly supported.

def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
raise NotImplementedError

def sign(self, msg, key):
def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
raise NotImplementedError

def verify(self, msg, key, sig):
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
raise NotImplementedError

@staticmethod
def to_jwk(key_obj):
def to_jwk(key_obj) -> JWKDict:
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError

@staticmethod
def from_jwk(jwk):
def from_jwk(jwk: JWKDict):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
Expand Down Expand Up @@ -202,7 +216,7 @@ class HMACAlgorithm(Algorithm):
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512

def __init__(self, hash_alg):
def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key):
Expand Down Expand Up @@ -242,7 +256,7 @@ def from_jwk(jwk):

return base64url_decode(obj["k"])

def sign(self, msg, key):
def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()

def verify(self, msg, key, sig):
Expand All @@ -261,7 +275,7 @@ class RSAAlgorithm(Algorithm):
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512

def __init__(self, hash_alg):
def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key):
Expand All @@ -271,16 +285,15 @@ def prepare_key(self, key):
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")

key = force_bytes(key)
key_bytes = force_bytes(key)

try:
if key.startswith(b"ssh-rsa"):
key = load_ssh_public_key(key)
if key_bytes.startswith(b"ssh-rsa"):
return load_ssh_public_key(key_bytes)
else:
key = load_pem_private_key(key, password=None)
return load_pem_private_key(key_bytes, password=None)
except ValueError:
key = load_pem_public_key(key)
return key
return load_pem_public_key(key_bytes)

@staticmethod
def to_jwk(key_obj):
Expand Down Expand Up @@ -383,12 +396,10 @@ def from_jwk(jwk):
return numbers.private_key()
elif "n" in obj and "e" in obj:
# Public key
numbers = RSAPublicNumbers(
return RSAPublicNumbers(
from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]),
)

return numbers.public_key()
).public_key()
else:
raise InvalidKeyError("Not a public or private key")

Expand All @@ -412,7 +423,7 @@ class ECAlgorithm(Algorithm):
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512

def __init__(self, hash_alg):
def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key):
Expand All @@ -422,18 +433,18 @@ def prepare_key(self, key):
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")

key = force_bytes(key)
key_bytes = force_bytes(key)

# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
# the Verifying Key first.
try:
if key.startswith(b"ecdsa-sha2-"):
key = load_ssh_public_key(key)
if key_bytes.startswith(b"ecdsa-sha2-"):
key = load_ssh_public_key(key_bytes)
else:
key = load_pem_public_key(key)
key = load_pem_public_key(key_bytes)
except ValueError:
key = load_pem_private_key(key, password=None)
key = load_pem_private_key(key_bytes, password=None)

# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
Expand All @@ -444,7 +455,7 @@ def prepare_key(self, key):
return key

def sign(self, msg, key):
der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
der_sig = key.sign(msg, ECDSA(self.hash_alg()))

return der_to_raw_signature(der_sig, key.curve)

Expand All @@ -457,7 +468,7 @@ def verify(self, msg, key, sig):
try:
if isinstance(key, EllipticCurvePrivateKey):
key = key.public_key()
key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
Expand All @@ -472,13 +483,13 @@ def to_jwk(key_obj):
else:
raise InvalidKeyError("Not a public or private key")

if isinstance(key_obj.curve, ec.SECP256R1):
if isinstance(key_obj.curve, SECP256R1):
crv = "P-256"
elif isinstance(key_obj.curve, ec.SECP384R1):
elif isinstance(key_obj.curve, SECP384R1):
crv = "P-384"
elif isinstance(key_obj.curve, ec.SECP521R1):
elif isinstance(key_obj.curve, SECP521R1):
crv = "P-521"
elif isinstance(key_obj.curve, ec.SECP256K1):
elif isinstance(key_obj.curve, SECP256K1):
crv = "secp256k1"
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
Expand All @@ -498,7 +509,9 @@ def to_jwk(key_obj):
return json.dumps(obj)

@staticmethod
def from_jwk(jwk):
def from_jwk(
jwk: Any,
) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
Expand All @@ -519,32 +532,34 @@ def from_jwk(jwk):
y = base64url_decode(obj.get("y"))

curve = obj.get("crv")
curve_obj: EllipticCurve

if curve == "P-256":
if len(x) == len(y) == 32:
curve_obj = ec.SECP256R1()
curve_obj = SECP256R1()
else:
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
elif curve == "P-384":
if len(x) == len(y) == 48:
curve_obj = ec.SECP384R1()
curve_obj = SECP384R1()
else:
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
elif curve == "P-521":
if len(x) == len(y) == 66:
curve_obj = ec.SECP521R1()
curve_obj = SECP521R1()
else:
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
elif curve == "secp256k1":
if len(x) == len(y) == 32:
curve_obj = ec.SECP256K1()
curve_obj = SECP256K1()
else:
raise InvalidKeyError(
"Coords should be 32 bytes for curve secp256k1"
)
else:
raise InvalidKeyError(f"Invalid curve: {curve}")

public_numbers = ec.EllipticCurvePublicNumbers(
public_numbers = EllipticCurvePublicNumbers(
x=int.from_bytes(x, byteorder="big"),
y=int.from_bytes(y, byteorder="big"),
curve=curve_obj,
Expand All @@ -559,7 +574,7 @@ def from_jwk(jwk):
"D should be {} bytes for curve {}", len(x), curve
)

return ec.EllipticCurvePrivateNumbers(
return EllipticCurvePrivateNumbers(
int.from_bytes(d, byteorder="big"), public_numbers
).private_key()

Expand Down Expand Up @@ -600,7 +615,7 @@ class OKPAlgorithm(Algorithm):
This class requires ``cryptography>=2.6`` to be installed.
"""

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also type kwargs with kwargs: Any.

pass

def prepare_key(self, key):
Expand Down
26 changes: 14 additions & 12 deletions jwt/api_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import json
import time
from typing import Any, Optional

from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
from .types import JWKDict


class PyJWK:
def __init__(self, jwk_data, algorithm=None):
def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data

Expand Down Expand Up @@ -55,29 +57,29 @@ def __init__(self, jwk_data, algorithm=None):
self.key = self.Algorithm.from_jwk(self._jwk_data)

@staticmethod
def from_dict(obj, algorithm=None):
def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK":
return PyJWK(obj, algorithm)

@staticmethod
def from_json(data, algorithm=None):
def from_json(data: str, algorithm: None = None) -> "PyJWK":
obj = json.loads(data)
return PyJWK.from_dict(obj, algorithm)

@property
def key_type(self):
def key_type(self) -> str:
return self._jwk_data.get("kty", None)

@property
def key_id(self):
def key_id(self) -> str:
return self._jwk_data.get("kid", None)

@property
def public_key_use(self):
def public_key_use(self) -> Optional[str]:
return self._jwk_data.get("use", None)


class PyJWKSet:
def __init__(self, keys: list[dict]) -> None:
def __init__(self, keys: list[JWKDict]) -> None:
self.keys = []

if not keys:
Expand All @@ -97,16 +99,16 @@ def __init__(self, keys: list[dict]) -> None:
raise PyJWKSetError("The JWK Set did not contain any usable keys")

@staticmethod
def from_dict(obj):
def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
keys = obj.get("keys", [])
return PyJWKSet(keys)

@staticmethod
def from_json(data):
def from_json(data: str) -> "PyJWKSet":
obj = json.loads(data)
return PyJWKSet.from_dict(obj)

def __getitem__(self, kid):
def __getitem__(self, kid: str) -> "PyJWK":
for key in self.keys:
if key.key_id == kid:
return key
Expand All @@ -118,8 +120,8 @@ def __init__(self, jwk_set: PyJWKSet):
self.jwk_set = jwk_set
self.timestamp = time.monotonic()

def get_jwk_set(self):
def get_jwk_set(self) -> PyJWKSet:
return self.jwk_set

def get_timestamp(self):
def get_timestamp(self) -> float:
return self.timestamp
Loading