Skip to content

Commit

Permalink
fix: Suggested changes and test with static analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
KShivendu committed Jul 24, 2023
1 parent 9cc0daf commit b90b4dc
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 16 deletions.
3 changes: 1 addition & 2 deletions supertokens_python/framework/django/django_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from asgiref.sync import async_to_sync

from supertokens_python.utils import default_user_context


def middleware(get_response: Any):
from supertokens_python import Supertokens
Expand All @@ -30,6 +28,7 @@ def middleware(get_response: Any):
from supertokens_python.supertokens import manage_session_post_response

from django.http import HttpRequest
from supertokens_python.utils import default_user_context

if asyncio.iscoroutinefunction(get_response):

Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/framework/fastapi/fastapi_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from typing import TYPE_CHECKING, Union

from supertokens_python.framework import BaseResponse
from supertokens_python.utils import default_user_context

if TYPE_CHECKING:
from fastapi import FastAPI, Request


def get_middleware():
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from supertokens_python.utils import default_user_context

class Middleware(BaseHTTPMiddleware):
def __init__(self, app: FastAPI):
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/framework/flask/flask_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from supertokens_python.async_to_sync_wrapper import sync
from supertokens_python.framework import BaseResponse
from supertokens_python.utils import default_user_context

if TYPE_CHECKING:
from flask import Flask
Expand All @@ -35,6 +34,7 @@ def set_before_after_request(self):
from supertokens_python.framework.flask.flask_request import FlaskRequest
from supertokens_python.framework.flask.flask_response import FlaskResponse
from supertokens_python.supertokens import manage_session_post_response
from supertokens_python.utils import default_user_context

from flask.wrappers import Response

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from supertokens_python.framework.django.django_response import DjangoResponse
from supertokens_python.recipe.session import SessionContainer, SessionRecipe
from supertokens_python.recipe.session.interfaces import SessionClaimValidator
from supertokens_python.utils import default_user_context
from supertokens_python.utils import set_request_in_user_context_if_not_defined
from supertokens_python.types import MaybeAwaitable

_T = TypeVar("_T", bound=Callable[..., Any])
Expand Down Expand Up @@ -50,8 +50,7 @@ async def wrapped_function(request: HttpRequest, *args: Any, **kwargs: Any):

try:
baseRequest = DjangoRequest(request)
if user_context is None:
user_context = default_user_context(baseRequest)
user_context = set_request_in_user_context_if_not_defined(user_context, baseRequest)

recipe = SessionRecipe.get_instance()
session = await recipe.verify_session(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from supertokens_python.framework.django.django_response import DjangoResponse
from supertokens_python.recipe.session import SessionRecipe, SessionContainer
from supertokens_python.recipe.session.interfaces import SessionClaimValidator
from supertokens_python.utils import default_user_context
from supertokens_python.utils import set_request_in_user_context_if_not_defined
from supertokens_python.types import MaybeAwaitable

_T = TypeVar("_T", bound=Callable[..., Any])
Expand Down Expand Up @@ -51,8 +51,7 @@ def wrapped_function(request: HttpRequest, *args: Any, **kwargs: Any):

try:
baseRequest = DjangoRequest(request)
if user_context is None:
user_context = default_user_context(baseRequest)
user_context = set_request_in_user_context_if_not_defined(user_context, baseRequest)

recipe = SessionRecipe.get_instance()
session = sync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from supertokens_python.types import MaybeAwaitable

from ...interfaces import SessionContainer, SessionClaimValidator
from supertokens_python.utils import default_user_context
from supertokens_python.utils import set_request_in_user_context_if_not_defined


def verify_session(
Expand All @@ -39,8 +39,7 @@ def verify_session(
async def func(request: Request) -> Union[SessionContainer, None]:
nonlocal user_context
baseRequest = FastApiRequest(request)
if user_context is None:
user_context = default_user_context(baseRequest)
user_context = set_request_in_user_context_if_not_defined(user_context, baseRequest)

recipe = SessionRecipe.get_instance()
session = await recipe.verify_session(
Expand Down
5 changes: 2 additions & 3 deletions supertokens_python/recipe/session/framework/flask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from supertokens_python.framework.flask.flask_request import FlaskRequest
from supertokens_python.recipe.session import SessionRecipe, SessionContainer
from supertokens_python.recipe.session.interfaces import SessionClaimValidator
from supertokens_python.utils import default_user_context
from supertokens_python.utils import set_request_in_user_context_if_not_defined
from supertokens_python.types import MaybeAwaitable

_T = TypeVar("_T", bound=Callable[..., Any])
Expand All @@ -45,8 +45,7 @@ def wrapped_function(*args: Any, **kwargs: Any):
from flask import make_response, request

baseRequest = FlaskRequest(request)
if user_context is None:
user_context = default_user_context(baseRequest)
user_context = set_request_in_user_context_if_not_defined(user_context, baseRequest)

recipe = SessionRecipe.get_instance()
session = sync(
Expand Down
15 changes: 15 additions & 0 deletions tests/test_user_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional
from pathlib import Path

from fastapi import FastAPI
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -400,3 +401,17 @@ async def create_new_session(
create_new_session_context_works,
]
)




async def test_default_user_context_func_calls():
# Tests run in the root directory of the repo
root_dir = Path("supertokens_python")
file_occurences: List[str] = []
for path in root_dir.rglob("*.py"):
with open(path) as f:
file_occurences.extend([str(path)] * f.read().count("user_context = set_request_in_user_context_if_not_defined("))
file_occurences.extend([str(path)] * f.read().count("user_context = default_user_context("))

assert len(file_occurences) == 7

0 comments on commit b90b4dc

Please sign in to comment.