Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complete types to take all allowed keys into account #873

Merged
merged 20 commits into from
Apr 16, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 110 additions & 72 deletions jwt/algorithms.py

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions jwt/api_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import json
import time
from typing import Any, Optional
from typing import Any

from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
from .types import JWKDict


class PyJWK:
def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None:
def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data

Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None:
self.key = self.Algorithm.from_jwk(self._jwk_data)

@staticmethod
def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK":
def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK":
return PyJWK(obj, algorithm)

@staticmethod
Expand All @@ -69,15 +69,15 @@ def from_json(data: str, algorithm: None = None) -> "PyJWK":
return PyJWK.from_dict(obj, algorithm)

@property
def key_type(self) -> Optional[str]:
def key_type(self) -> str | None:
return self._jwk_data.get("kty", None)

@property
def key_id(self) -> Optional[str]:
def key_id(self) -> str | None:
return self._jwk_data.get("kid", None)

@property
def public_key_use(self) -> Optional[str]:
def public_key_use(self) -> str | None:
return self._jwk_data.get("use", None)


Expand Down
23 changes: 13 additions & 10 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import binascii
import json
import warnings
from typing import Any, List, Optional, Type
from typing import TYPE_CHECKING, Any

from .algorithms import (
Algorithm,
Expand All @@ -20,14 +20,17 @@
from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning

if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys


class PyJWS:
header_typ = "JWT"

def __init__(
self,
algorithms: Optional[List[str]] = None,
options: Optional[dict[str, Any]] = None,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
Expand Down Expand Up @@ -100,10 +103,10 @@ def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
def encode(
self,
payload: bytes,
key: str,
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
Expand Down Expand Up @@ -169,7 +172,7 @@ def encode(
def decode_complete(
self,
jwt: str | bytes,
key: str | bytes = "",
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
Expand Down Expand Up @@ -214,7 +217,7 @@ def decode_complete(
def decode(
self,
jwt: str | bytes,
key: str | bytes = "",
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
Expand Down Expand Up @@ -286,7 +289,7 @@ def _verify_signature(
signing_input: bytes,
header: dict[str, Any],
signature: bytes,
key: str | bytes = "",
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
) -> None:
try:
Expand All @@ -301,9 +304,9 @@ def _verify_signature(
alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
key = alg_obj.prepare_key(key)
prepared_key = alg_obj.prepare_key(key)

if not alg_obj.verify(signing_input, key, signature):
if not alg_obj.verify(signing_input, prepared_key, signature):
raise InvalidSignatureError("Signature verification failed")

def _validate_headers(self, headers: dict[str, Any]) -> None:
Expand Down
67 changes: 35 additions & 32 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from calendar import timegm
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any

from . import api_jws
from .exceptions import (
Expand All @@ -19,15 +19,18 @@
)
from .warnings import RemovedInPyjwt3Warning

if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys


class PyJWT:
def __init__(self, options: Optional[dict[str, Any]] = None) -> None:
def __init__(self, options: dict[str, Any] | None = None) -> None:
if options is None:
options = {}
self.options: dict[str, Any] = {**self._get_default_options(), **options}

@staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
def _get_default_options() -> dict[str, bool | list[str]]:
return {
"verify_signature": True,
"verify_exp": True,
Expand All @@ -40,11 +43,11 @@ def _get_default_options() -> Dict[str, Union[bool, List[str]]]:

def encode(
self,
payload: Dict[str, Any],
key: str,
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
payload: dict[str, Any],
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
) -> str:
# Check that we get a dict
Expand Down Expand Up @@ -78,9 +81,9 @@ def encode(

def _encode_payload(
self,
payload: Dict[str, Any],
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
payload: dict[str, Any],
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
) -> bytes:
"""
Encode a given payload to the bytes to be signed.
Expand All @@ -97,21 +100,21 @@ def _encode_payload(
def decode_complete(
self,
jwt: str | bytes,
key: str | bytes = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
verify: Optional[bool] = None,
verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: Optional[bytes] = None,
detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
audience: Optional[Union[str, Iterable[str]]] = None,
issuer: Optional[str] = None,
leeway: Union[int, float, timedelta] = 0,
audience: str | Iterable[str] | None = None,
issuer: str | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs,
) -> Dict[str, Any]:
**kwargs: Any,
) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
Expand Down Expand Up @@ -163,7 +166,7 @@ def decode_complete(
decoded["payload"] = payload
return decoded

def _decode_payload(self, decoded: Dict[str, Any]) -> Any:
def _decode_payload(self, decoded: dict[str, Any]) -> Any:
"""
Decode the payload from a JWS dictionary (payload, signature, header).

Expand All @@ -182,20 +185,20 @@ def _decode_payload(self, decoded: Dict[str, Any]) -> Any:
def decode(
self,
jwt: str | bytes,
key: str | bytes = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
verify: Optional[bool] = None,
verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: Optional[bytes] = None,
detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
audience: Optional[Union[str, Iterable[str]]] = None,
issuer: Optional[str] = None,
leeway: Union[int, float, timedelta] = 0,
audience: str | Iterable[str] | None = None,
issuer: str | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs,
**kwargs: Any,
) -> Any:
if kwargs:
warnings.warn(
Expand Down Expand Up @@ -303,7 +306,7 @@ def _validate_exp(
def _validate_aud(
self,
payload: dict[str, Any],
audience: Optional[Union[str, Iterable[str]]],
audience: str | Iterable[str] | None,
) -> None:
if audience is None:
if "aud" not in payload or not payload["aud"]:
Expand Down
4 changes: 2 additions & 2 deletions jwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import binascii
import re
from typing import Any, AnyStr
from typing import AnyStr

try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
Expand All @@ -13,7 +13,7 @@
pass


def force_bytes(value: Any) -> bytes:
def force_bytes(value: AnyStr) -> bytes:
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, bytes):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ source = ["jwt", ".tox/*/site-packages"]

[tool.coverage.report]
show_missing = true
exclude_lines = ["if TYPE_CHECKING:"]

[tool.isort]
profile = "black"
Expand Down
6 changes: 0 additions & 6 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,12 +767,6 @@ def test_okp_ed25519_should_reject_non_string_key(self):
with open(key_path("testkey_ed25519.pub")) as keyfile:
algo.prepare_key(keyfile.read())

def test_okp_ed25519_should_accept_unicode_key(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please describe why this test is removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a duplicate of

def test_okp_ed25519_should_reject_non_string_key(self):
algo = OKPAlgorithm()
with pytest.raises(InvalidKeyError):
algo.prepare_key(None)
with open(key_path("testkey_ed25519")) as keyfile:
algo.prepare_key(keyfile.read())

L764-765

In fact I think it's better to move the two valid checks from test_okp_ed25519_should_reject_non_string_key to this removed test_okp_ed25519_should_accept_unicode_key. Tell me if you want me to do it this way

algo = OKPAlgorithm()

with open(key_path("testkey_ed25519")) as ec_key:
algo.prepare_key(ec_key.read())

def test_okp_ed25519_sign_should_generate_correct_signature_value(self):
algo = OKPAlgorithm()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_non_object_options_dont_persist(self, jws, payload):

assert jws.options["verify_signature"]

def test_options_must_be_dict(self, jws):
def test_options_must_be_dict(self):
pytest.raises(TypeError, PyJWS, options=object())
pytest.raises((TypeError, ValueError), PyJWS, options=("something"))

Expand Down