Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate 'sub' and 'jti' claims for the token #991

Closed
wants to merge 8 commits into from
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ var/
*.egg-info/
.installed.cfg
*.egg

.idea/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Changed
jwt.encode({"payload":"abc"}, key=None, algorithm='none')
```

- Added validation for 'sub' (subject) and 'jti' (JWT ID) claims in tokens

Fixed
~~~~~

Expand Down
56 changes: 55 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidJTIError,
InvalidSubjectError,
MissingRequiredClaimError,
)
from .warnings import RemovedInPyjwt3Warning
Expand All @@ -39,6 +41,8 @@ def _get_default_options() -> dict[str, bool | list[str]]:
"verify_iat": True,
"verify_aud": True,
"verify_iss": True,
"verify_sub": True,
"verify_jti": True,
"require": [],
}

Expand Down Expand Up @@ -112,6 +116,7 @@ def decode_complete(
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | Sequence[str] | None = None,
subject: str | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
Expand Down Expand Up @@ -145,6 +150,8 @@ def decode_complete(
options.setdefault("verify_iat", False)
options.setdefault("verify_aud", False)
options.setdefault("verify_iss", False)
options.setdefault("verify_sub", False)
options.setdefault("verify_jti", False)

decoded = api_jws.decode_complete(
jwt,
Expand All @@ -158,7 +165,12 @@ def decode_complete(

merged_options = {**self.options, **options}
self._validate_claims(
payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
payload,
merged_options,
audience=audience,
issuer=issuer,
leeway=leeway,
subject=subject,
)

decoded["payload"] = payload
Expand Down Expand Up @@ -193,6 +205,7 @@ def decode(
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
subject: str | None = None,
issuer: str | Sequence[str] | None = None,
leeway: float | timedelta = 0,
# kwargs
Expand All @@ -215,6 +228,7 @@ def decode(
detached_payload=detached_payload,
audience=audience,
issuer=issuer,
subject=subject,
leeway=leeway,
)
return decoded["payload"]
Expand All @@ -225,6 +239,7 @@ def _validate_claims(
options: dict[str, Any],
audience=None,
issuer=None,
subject: str | None = None,
leeway: float | timedelta = 0,
) -> None:
if isinstance(leeway, timedelta):
Expand Down Expand Up @@ -254,6 +269,12 @@ def _validate_claims(
payload, audience, strict=options.get("strict_aud", False)
)

if options["verify_sub"]:
self._validate_sub(payload, subject)

if options["verify_jti"]:
self._validate_jti(payload)

def _validate_required_claims(
self,
payload: dict[str, Any],
Expand All @@ -263,6 +284,39 @@ def _validate_required_claims(
if payload.get(claim) is None:
raise MissingRequiredClaimError(claim)

def _validate_sub(self, payload: dict[str, Any], subject=None) -> None:
"""
Checks whether "sub" if in the payload is valid ot not.
This is an Optional claim

:param payload(dict): The payload which needs to be validated
:param subject(str): The subject of the token
"""

if "sub" not in payload:
return

if not isinstance(payload["sub"], str):
raise InvalidSubjectError("Subject must be a string")

if subject is not None:
if payload.get("sub") != subject:
raise InvalidSubjectError("Invalid subject")

def _validate_jti(self, payload: dict[str, Any]) -> None:
"""
Checks whether "jti" if in the payload is valid ot not
This is an Optional claim

:param payload(dict): The payload which needs to be validated
"""

if "jti" not in payload:
return

if not isinstance(payload.get("jti"), str):
raise InvalidJTIError("JWT ID must be a string")

def _validate_iat(
self,
payload: dict[str, Any],
Expand Down
8 changes: 8 additions & 0 deletions jwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ class InvalidAlgorithmError(InvalidTokenError):
pass


class InvalidSubjectError(InvalidTokenError):
pass


class InvalidJTIError(InvalidTokenError):
pass


class MissingRequiredClaimError(InvalidTokenError):
def __init__(self, claim: str) -> None:
self.claim = claim
Expand Down
120 changes: 120 additions & 0 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidJTIError,
InvalidSubjectError,
MissingRequiredClaimError,
)
from jwt.utils import base64url_decode
Expand Down Expand Up @@ -816,3 +818,121 @@ def test_decode_strict_ok(self, jwt, payload):
options={"strict_aud": True},
algorithms=["HS256"],
)

# -------------------- Sub Claim Tests --------------------

def test_encode_decode_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")
decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded["sub"] == "user123"

def test_decode_without_and_not_required_sub_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert "sub" not in decoded

def test_decode_missing_sub_but_required_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(MissingRequiredClaimError):
jwt.decode(
token, secret, algorithms=["HS256"], options={"require": ["sub"]}
)

def test_decode_invalid_int_sub_claim(self, jwt):
payload = {
"sub": 1224344,
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidSubjectError):
jwt.decode(token, secret, algorithms=["HS256"])

def test_decode_with_valid_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"], subject="user123")

assert decoded["sub"] == "user123"

def test_decode_with_invalid_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidSubjectError) as exc_info:
jwt.decode(token, secret, algorithms=["HS256"], subject="user456")

assert "Invalid subject" in str(exc_info.value)

def test_decode_with_sub_claim_and_none_subject(self, jwt):
payload = {
"sub": "user789",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"], subject=None)
assert decoded["sub"] == "user789"

# -------------------- JTI Claim Tests --------------------

def test_encode_decode_with_valid_jti_claim(self, jwt):
payload = {
"jti": "unique-id-456",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")
decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded["jti"] == "unique-id-456"

def test_decode_missing_jti_when_required_claim(self, jwt):
payload = {"name": "Bob", "admin": False}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(MissingRequiredClaimError) as exc_info:
jwt.decode(
token, secret, algorithms=["HS256"], options={"require": ["jti"]}
)

assert "jti" in str(exc_info.value)

def test_decode_missing_jti_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded.get("jti") is None

def test_jti_claim_with_invalid_int_value(self, jwt):
special_jti = 12223
payload = {
"jti": special_jti,
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidJTIError):
jwt.decode(token, secret, algorithms=["HS256"])