From cb9f19f5c73b5c0ddc68748ba88a2dbeb6f147e4 Mon Sep 17 00:00:00 2001 From: Matt Pollak Date: Wed, 23 Aug 2023 12:56:37 -0400 Subject: [PATCH 1/2] Update jwt-api to accept either a string or list of strings for issuer validation --- jwt/api_jwt.py | 14 +++++++++----- tests/test_api_jwt.py | 17 ++++++++++++++++- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 48d739ad..9d035598 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -5,7 +5,7 @@ from calendar import timegm from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List from . import api_jws from .exceptions import ( @@ -110,7 +110,7 @@ def decode_complete( # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | None = None, + issuer: str | List[str] | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -195,7 +195,7 @@ def decode( # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | None = None, + issuer: str | List[str] | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -362,8 +362,12 @@ def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None: if "iss" not in payload: raise MissingRequiredClaimError("iss") - if payload["iss"] != issuer: - raise InvalidIssuerError("Invalid issuer") + if isinstance(issuer, list): + if payload["iss"] not in issuer: + raise InvalidIssuerError("Invalid issuer") + else: + if payload["iss"] != issuer: + raise InvalidIssuerError("Invalid issuer") _jwt_global_obj = PyJWT() diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index b51170b3..0ca93dd1 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -5,7 +5,6 @@ from decimal import Decimal import pytest - from jwt.api_jwt import PyJWT from jwt.exceptions import ( DecodeError, @@ -486,6 +485,12 @@ def test_check_issuer_when_valid(self, jwt): token = jwt.encode(payload, "secret") jwt.decode(token, "secret", issuer=issuer, algorithms=["HS256"]) + def test_check_issuer_list_when_valid(self, jwt): + issuer = ["urn:foo", "urn:bar"] + payload = {"some": "payload", "iss": "urn:foo"} + token = jwt.encode(payload, "secret") + jwt.decode(token, "secret", issuer=issuer, algorithms=["HS256"]) + def test_raise_exception_invalid_issuer(self, jwt): issuer = "urn:wrong" @@ -496,6 +501,16 @@ def test_raise_exception_invalid_issuer(self, jwt): with pytest.raises(InvalidIssuerError): jwt.decode(token, "secret", issuer=issuer, algorithms=["HS256"]) + def test_raise_exception_invalid_issuer_list(self, jwt): + issuer = ["urn:wrong", "urn:bar", "urn:baz"] + + payload = {"some": "payload", "iss": "urn:foo"} + + token = jwt.encode(payload, "secret") + + with pytest.raises(InvalidIssuerError): + jwt.decode(token, "secret", issuer=issuer, algorithms=["HS256"]) + def test_skip_check_audience(self, jwt): payload = {"some": "payload", "aud": "urn:me"} token = jwt.encode(payload, "secret") From f875826f0d965afec76ae8189bcf356dfe45b9d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Aug 2023 17:03:23 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_api_jwt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 0ca93dd1..509f788f 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -5,6 +5,7 @@ from decimal import Decimal import pytest + from jwt.api_jwt import PyJWT from jwt.exceptions import ( DecodeError,