Skip to content

Commit

Permalink
Update jwt-api to accept either a string or list of strings for issue…
Browse files Browse the repository at this point in the history
…r validation (#913)

* Update jwt-api to accept either a string or list of strings for issuer validation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mattpollak and pre-commit-ci[bot] authored Sep 3, 2023
1 parent 719a9f5 commit 332774f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
14 changes: 9 additions & 5 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from calendar import timegm
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, List

from . import api_jws
from .exceptions import (
Expand Down Expand Up @@ -110,7 +110,7 @@ def decode_complete(
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | None = None,
issuer: str | List[str] | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
Expand Down Expand Up @@ -195,7 +195,7 @@ def decode(
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | None = None,
issuer: str | List[str] | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
Expand Down Expand Up @@ -362,8 +362,12 @@ def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
if "iss" not in payload:
raise MissingRequiredClaimError("iss")

if payload["iss"] != issuer:
raise InvalidIssuerError("Invalid issuer")
if isinstance(issuer, list):
if payload["iss"] not in issuer:
raise InvalidIssuerError("Invalid issuer")
else:
if payload["iss"] != issuer:
raise InvalidIssuerError("Invalid issuer")


_jwt_global_obj = PyJWT()
Expand Down
16 changes: 16 additions & 0 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,12 @@ def test_check_issuer_when_valid(self, jwt):
token = jwt.encode(payload, "secret")
jwt.decode(token, "secret", issuer=issuer, algorithms=["HS256"])

def test_check_issuer_list_when_valid(self, jwt):
issuer = ["urn:foo", "urn:bar"]
payload = {"some": "payload", "iss": "urn:foo"}
token = jwt.encode(payload, "secret")
jwt.decode(token, "secret", issuer=issuer, algorithms=["HS256"])

def test_raise_exception_invalid_issuer(self, jwt):
issuer = "urn:wrong"

Expand All @@ -496,6 +502,16 @@ def test_raise_exception_invalid_issuer(self, jwt):
with pytest.raises(InvalidIssuerError):
jwt.decode(token, "secret", issuer=issuer, algorithms=["HS256"])

def test_raise_exception_invalid_issuer_list(self, jwt):
issuer = ["urn:wrong", "urn:bar", "urn:baz"]

payload = {"some": "payload", "iss": "urn:foo"}

token = jwt.encode(payload, "secret")

with pytest.raises(InvalidIssuerError):
jwt.decode(token, "secret", issuer=issuer, algorithms=["HS256"])

def test_skip_check_audience(self, jwt):
payload = {"some": "payload", "aud": "urn:me"}
token = jwt.encode(payload, "secret")
Expand Down

0 comments on commit 332774f

Please sign in to comment.