From b90b4dcecada3c4827156d4f05f87c703f87986b Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 24 Jul 2023 12:08:28 +0530 Subject: [PATCH] fix: Suggested changes and test with static analysis --- .../framework/django/django_middleware.py | 3 +-- .../framework/fastapi/fastapi_middleware.py | 2 +- .../framework/flask/flask_middleware.py | 2 +- .../session/framework/django/asyncio/__init__.py | 5 ++--- .../session/framework/django/syncio/__init__.py | 5 ++--- .../recipe/session/framework/fastapi/__init__.py | 5 ++--- .../recipe/session/framework/flask/__init__.py | 5 ++--- tests/test_user_context.py | 15 +++++++++++++++ 8 files changed, 26 insertions(+), 16 deletions(-) diff --git a/supertokens_python/framework/django/django_middleware.py b/supertokens_python/framework/django/django_middleware.py index f4d26ce73..d35fcac47 100644 --- a/supertokens_python/framework/django/django_middleware.py +++ b/supertokens_python/framework/django/django_middleware.py @@ -18,8 +18,6 @@ from asgiref.sync import async_to_sync -from supertokens_python.utils import default_user_context - def middleware(get_response: Any): from supertokens_python import Supertokens @@ -30,6 +28,7 @@ def middleware(get_response: Any): from supertokens_python.supertokens import manage_session_post_response from django.http import HttpRequest + from supertokens_python.utils import default_user_context if asyncio.iscoroutinefunction(get_response): diff --git a/supertokens_python/framework/fastapi/fastapi_middleware.py b/supertokens_python/framework/fastapi/fastapi_middleware.py index cc473f0ce..8ee6a9885 100644 --- a/supertokens_python/framework/fastapi/fastapi_middleware.py +++ b/supertokens_python/framework/fastapi/fastapi_middleware.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Union from supertokens_python.framework import BaseResponse -from supertokens_python.utils import default_user_context if TYPE_CHECKING: from fastapi import FastAPI, Request @@ -24,6 +23,7 @@ def get_middleware(): from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint + from supertokens_python.utils import default_user_context class Middleware(BaseHTTPMiddleware): def __init__(self, app: FastAPI): diff --git a/supertokens_python/framework/flask/flask_middleware.py b/supertokens_python/framework/flask/flask_middleware.py index fcb1f7917..ed3479dd1 100644 --- a/supertokens_python/framework/flask/flask_middleware.py +++ b/supertokens_python/framework/flask/flask_middleware.py @@ -18,7 +18,6 @@ from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.framework import BaseResponse -from supertokens_python.utils import default_user_context if TYPE_CHECKING: from flask import Flask @@ -35,6 +34,7 @@ def set_before_after_request(self): from supertokens_python.framework.flask.flask_request import FlaskRequest from supertokens_python.framework.flask.flask_response import FlaskResponse from supertokens_python.supertokens import manage_session_post_response + from supertokens_python.utils import default_user_context from flask.wrappers import Response diff --git a/supertokens_python/recipe/session/framework/django/asyncio/__init__.py b/supertokens_python/recipe/session/framework/django/asyncio/__init__.py index d98bf9beb..9451e7b70 100644 --- a/supertokens_python/recipe/session/framework/django/asyncio/__init__.py +++ b/supertokens_python/recipe/session/framework/django/asyncio/__init__.py @@ -20,7 +20,7 @@ from supertokens_python.framework.django.django_response import DjangoResponse from supertokens_python.recipe.session import SessionContainer, SessionRecipe from supertokens_python.recipe.session.interfaces import SessionClaimValidator -from supertokens_python.utils import default_user_context +from supertokens_python.utils import set_request_in_user_context_if_not_defined from supertokens_python.types import MaybeAwaitable _T = TypeVar("_T", bound=Callable[..., Any]) @@ -50,8 +50,7 @@ async def wrapped_function(request: HttpRequest, *args: Any, **kwargs: Any): try: baseRequest = DjangoRequest(request) - if user_context is None: - user_context = default_user_context(baseRequest) + user_context = set_request_in_user_context_if_not_defined(user_context, baseRequest) recipe = SessionRecipe.get_instance() session = await recipe.verify_session( diff --git a/supertokens_python/recipe/session/framework/django/syncio/__init__.py b/supertokens_python/recipe/session/framework/django/syncio/__init__.py index 13715178a..245dd82d9 100644 --- a/supertokens_python/recipe/session/framework/django/syncio/__init__.py +++ b/supertokens_python/recipe/session/framework/django/syncio/__init__.py @@ -21,7 +21,7 @@ from supertokens_python.framework.django.django_response import DjangoResponse from supertokens_python.recipe.session import SessionRecipe, SessionContainer from supertokens_python.recipe.session.interfaces import SessionClaimValidator -from supertokens_python.utils import default_user_context +from supertokens_python.utils import set_request_in_user_context_if_not_defined from supertokens_python.types import MaybeAwaitable _T = TypeVar("_T", bound=Callable[..., Any]) @@ -51,8 +51,7 @@ def wrapped_function(request: HttpRequest, *args: Any, **kwargs: Any): try: baseRequest = DjangoRequest(request) - if user_context is None: - user_context = default_user_context(baseRequest) + user_context = set_request_in_user_context_if_not_defined(user_context, baseRequest) recipe = SessionRecipe.get_instance() session = sync( diff --git a/supertokens_python/recipe/session/framework/fastapi/__init__.py b/supertokens_python/recipe/session/framework/fastapi/__init__.py index c561d87f9..97b22873e 100644 --- a/supertokens_python/recipe/session/framework/fastapi/__init__.py +++ b/supertokens_python/recipe/session/framework/fastapi/__init__.py @@ -18,7 +18,7 @@ from supertokens_python.types import MaybeAwaitable from ...interfaces import SessionContainer, SessionClaimValidator -from supertokens_python.utils import default_user_context +from supertokens_python.utils import set_request_in_user_context_if_not_defined def verify_session( @@ -39,8 +39,7 @@ def verify_session( async def func(request: Request) -> Union[SessionContainer, None]: nonlocal user_context baseRequest = FastApiRequest(request) - if user_context is None: - user_context = default_user_context(baseRequest) + user_context = set_request_in_user_context_if_not_defined(user_context, baseRequest) recipe = SessionRecipe.get_instance() session = await recipe.verify_session( diff --git a/supertokens_python/recipe/session/framework/flask/__init__.py b/supertokens_python/recipe/session/framework/flask/__init__.py index bc9211d91..4ca58f074 100644 --- a/supertokens_python/recipe/session/framework/flask/__init__.py +++ b/supertokens_python/recipe/session/framework/flask/__init__.py @@ -18,7 +18,7 @@ from supertokens_python.framework.flask.flask_request import FlaskRequest from supertokens_python.recipe.session import SessionRecipe, SessionContainer from supertokens_python.recipe.session.interfaces import SessionClaimValidator -from supertokens_python.utils import default_user_context +from supertokens_python.utils import set_request_in_user_context_if_not_defined from supertokens_python.types import MaybeAwaitable _T = TypeVar("_T", bound=Callable[..., Any]) @@ -45,8 +45,7 @@ def wrapped_function(*args: Any, **kwargs: Any): from flask import make_response, request baseRequest = FlaskRequest(request) - if user_context is None: - user_context = default_user_context(baseRequest) + user_context = set_request_in_user_context_if_not_defined(user_context, baseRequest) recipe = SessionRecipe.get_instance() session = sync( diff --git a/tests/test_user_context.py b/tests/test_user_context.py index dd2e36f15..157effba1 100644 --- a/tests/test_user_context.py +++ b/tests/test_user_context.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict, List, Optional +from pathlib import Path from fastapi import FastAPI from fastapi.testclient import TestClient @@ -400,3 +401,17 @@ async def create_new_session( create_new_session_context_works, ] ) + + + + +async def test_default_user_context_func_calls(): + # Tests run in the root directory of the repo + root_dir = Path("supertokens_python") + file_occurences: List[str] = [] + for path in root_dir.rglob("*.py"): + with open(path) as f: + file_occurences.extend([str(path)] * f.read().count("user_context = set_request_in_user_context_if_not_defined(")) + file_occurences.extend([str(path)] * f.read().count("user_context = default_user_context(")) + + assert len(file_occurences) == 7