Skip to content

Commit

Permalink
Fix issue with nbf claim (#177)
Browse files Browse the repository at this point in the history
Fix issue with nbf claim not being honored + tests
  • Loading branch information
lannuttia authored Aug 26, 2023
1 parent 5d83c86 commit 28d7a55
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.11.5
hooks:
- id: isort
args:
Expand All @@ -29,7 +29,7 @@ repos:
- --follow-imports
- skip
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.0
rev: v2.1.1
hooks:
- id: autoflake
args:
Expand Down
15 changes: 13 additions & 2 deletions starlette_authlib/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from collections import namedtuple

from authlib.jose import jwt
from authlib.jose.errors import BadSignatureError, DecodeError, ExpiredTokenError
from authlib.jose.errors import (
BadSignatureError,
DecodeError,
ExpiredTokenError,
InvalidTokenError,
)
from starlette.config import Config
from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
Expand Down Expand Up @@ -86,9 +91,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
),
)
jwt_payload.validate_exp(time.time(), 0)
jwt_payload.validate_nbf(time.time(), 0)
scope["session"] = jwt_payload
initial_session_was_empty = False
except (BadSignatureError, ExpiredTokenError, DecodeError):
except (
BadSignatureError,
ExpiredTokenError,
DecodeError,
InvalidTokenError,
):
scope["session"] = {}
else:
scope["session"] = {}
Expand Down
75 changes: 75 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
from datetime import datetime, timedelta

import pytest
from starlette.applications import Starlette
Expand Down Expand Up @@ -121,6 +122,80 @@ def test_session_expires():
assert response.json() == {"session": {}}


def test_session_futue_nbf():
now = datetime.now()
nbf = datetime.timestamp(now + timedelta(days=1))
claims = {"nbf": nbf, "some": "data"}
for jwt_alg, secret_key in (
("HS256", "example"),
(
"RS256",
SecretKey(
Secret(open(os.path.join(KEYS_DIR, "rsa.key")).read()),
Secret(open(os.path.join(KEYS_DIR, "rsa.pub")).read()),
),
),
):
app = create_app()
app.add_middleware(
SessionMiddleware, jwt_alg=jwt_alg, secret_key=secret_key, https_only=True
)
secure_client = TestClient(app, base_url="https://testserver")

response = secure_client.get("/view_session")
assert response.json() == {"session": {}}

response = secure_client.post("/update_session", json=claims.copy())
assert response.json() == {"session": claims.copy()}

response = secure_client.get("/view_session").json()
assert response == {"session": {}}

response = secure_client.post("/clear_session")
assert response.json() == {"session": {}}

response = secure_client.get("/view_session")
assert response.json() == {"session": {}}


def test_session_past_nbf():
now = datetime.now()
nbf = datetime.timestamp(now - timedelta(seconds=1))
claims = {"nbf": nbf, "some": "data"}
for jwt_alg, secret_key in (
("HS256", "example"),
(
"RS256",
SecretKey(
Secret(open(os.path.join(KEYS_DIR, "rsa.key")).read()),
Secret(open(os.path.join(KEYS_DIR, "rsa.pub")).read()),
),
),
):
app = create_app()
app.add_middleware(
SessionMiddleware, jwt_alg=jwt_alg, secret_key=secret_key, https_only=True
)
secure_client = TestClient(app, base_url="https://testserver")

response = secure_client.get("/view_session")
assert response.json() == {"session": {}}

response = secure_client.post("/update_session", json=claims.copy())
assert response.json() == {"session": claims.copy()}

response = secure_client.get("/view_session").json()
assert "exp" in response["session"]
del response["session"]["exp"]
assert response == {"session": claims.copy()}

response = secure_client.post("/clear_session")
assert response.json() == {"session": {}}

response = secure_client.get("/view_session")
assert response.json() == {"session": {}}


def test_secure_session():
for jwt_alg, secret_key in (
("HS256", "example"),
Expand Down

0 comments on commit 28d7a55

Please sign in to comment.