diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 432b9a428..d850d304c 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -13,11 +13,12 @@ # under the License. from __future__ import annotations +import re from os import environ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Union from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.recipe_module import APIHandled, RecipeModule, ApiIdWithTenantId from .api import ( api_key_protector, @@ -45,6 +46,7 @@ from .exceptions import SuperTokensDashboardError from .interfaces import APIInterface, APIOptions from .recipe_implementation import RecipeImplementation +from ..multitenancy.constants import DEFAULT_TENANT_ID if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -126,6 +128,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, @@ -245,15 +248,56 @@ def reset(): def return_api_id_if_can_handle_request( self, path: NormalisedURLPath, method: str - ) -> Union[str, None]: + ) -> Union[ApiIdWithTenantId, None]: dashboard_bundle_path = self.app_info.api_base_path.append( NormalisedURLPath(DASHBOARD_API) ) - if is_api_path(path, self.app_info): - return get_api_if_matched(path, method) + base_path_str = self.app_info.api_base_path.get_as_string_dangerous() + path_str = path.get_as_string_dangerous() + regex = rf"^{base_path_str}(?:/([a-zA-Z0-9-]+))?(/.*)$" + # some examples against for above regex: + # books => match = None + # public/books => match = None + # /books => match.group(1) = None, match.group(2) = /dashboard + # /public/books => match.group(1) = 'public', match.group(2) = '/books' + # /public/book/1 => match.group(1) = 'public', match.group(2) = '/book/1' + + match = re.match(regex, path_str) + match_group_1 = match.group(1) if match is not None else None + match_group_2 = match.group(2) if match is not None else None + + tenant_id: str = DEFAULT_TENANT_ID + remaining_path: Optional[NormalisedURLPath] = None + + if ( + match is not None + and isinstance(match_group_1, str) + and isinstance(match_group_2, str) + ): + tenant_id = match_group_1 + remaining_path = NormalisedURLPath(match_group_2) + + if is_api_path(path, self.app_info.api_base_path) or ( + remaining_path is not None + and is_api_path( + path, + self.app_info.api_base_path.append(NormalisedURLPath(f"/{tenant_id}")), + ) + ): + # check remainingPath first as path that contains tenantId might match as well + # since getApiIdIfMatched uses endsWith to match + if remaining_path is not None: + id_ = get_api_if_matched(remaining_path, method) + if id_ is not None: + return ApiIdWithTenantId(id_, tenant_id) + + id_ = get_api_if_matched(path, method) + if id_ is not None: + return ApiIdWithTenantId(id_, DEFAULT_TENANT_ID) if path.startswith(dashboard_bundle_path): - return DASHBOARD_API + return ApiIdWithTenantId(DASHBOARD_API, DEFAULT_TENANT_ID) + # tenantId is not supported for bundlePath, so not matching for it return None diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 49392e57d..fbc3f79ad 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -17,7 +17,6 @@ if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest - from ...supertokens import AppInfo from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailpassword.asyncio import ( @@ -195,10 +194,8 @@ def validate_and_normalise_user_input( ) -def is_api_path(path: NormalisedURLPath, app_info: AppInfo) -> bool: - dashboard_recipe_base_path = app_info.api_base_path.append( - NormalisedURLPath(DASHBOARD_API) - ) +def is_api_path(path: NormalisedURLPath, base_path: NormalisedURLPath) -> bool: + dashboard_recipe_base_path = base_path.append(NormalisedURLPath(DASHBOARD_API)) if not path.startswith(dashboard_recipe_base_path): return False diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index 7092c886b..cb04f0071 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -14,7 +14,7 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig @@ -170,6 +170,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index d19f505b8..998bbe4b0 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -165,6 +165,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/recipe/jwt/recipe.py b/supertokens_python/recipe/jwt/recipe.py index b708499ff..c7fbc5571 100644 --- a/supertokens_python/recipe/jwt/recipe.py +++ b/supertokens_python/recipe/jwt/recipe.py @@ -14,7 +14,7 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, List, Union, Optional from supertokens_python.querier import Querier from supertokens_python.recipe.jwt.api.implementation import APIImplementation @@ -80,6 +80,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/recipe/multitenancy/constants.py b/supertokens_python/recipe/multitenancy/constants.py index ee5191b55..08e1e0987 100644 --- a/supertokens_python/recipe/multitenancy/constants.py +++ b/supertokens_python/recipe/multitenancy/constants.py @@ -12,4 +12,4 @@ # License for the specific language governing permissions and limitations # under the License. LOGIN_METHODS = "/loginmethods" -DEFAULT_TENANT_ID = "defaultTenantId" +DEFAULT_TENANT_ID = "public" diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index af34792b8..27b88e1be 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -122,6 +122,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/recipe/openid/recipe.py b/supertokens_python/recipe/openid/recipe.py index 284f108ec..56559a3c7 100644 --- a/supertokens_python/recipe/openid/recipe.py +++ b/supertokens_python/recipe/openid/recipe.py @@ -14,7 +14,7 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, List, Union, Optional from supertokens_python.querier import Querier from supertokens_python.recipe.jwt import JWTRecipe @@ -89,6 +89,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, @@ -107,7 +108,7 @@ async def handle_api_request( self.api_implementation, options ) return await self.jwt_recipe.handle_api_request( - request_id, request, path, method, response + request_id, tenant_id, request, path, method, response ) async def handle_error( diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index f79751054..8fd298a31 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -14,7 +14,7 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union, Optional from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig @@ -185,6 +185,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index 3d94eb805..46535c461 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -187,6 +187,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, @@ -215,7 +216,7 @@ async def handle_api_request( ), ) return await self.openid_recipe.handle_api_request( - request_id, request, path, method, response + request_id, tenant_id, request, path, method, response ) async def handle_error( diff --git a/supertokens_python/recipe/thirdparty/providers/apple.py b/supertokens_python/recipe/thirdparty/providers/apple.py index 0d938c64e..2c30ed8a8 100644 --- a/supertokens_python/recipe/thirdparty/providers/apple.py +++ b/supertokens_python/recipe/thirdparty/providers/apple.py @@ -18,12 +18,8 @@ from jwt import encode # type: ignore from time import time -from ..provider import ( - Provider, - ProviderConfigForClientType, - ProviderInput, -) from .custom import GenericProvider, NewProvider +from ..provider import Provider, ProviderConfigForClientType, ProviderInput from .utils import get_actual_client_id_from_development_client_id @@ -54,7 +50,7 @@ async def _get_client_secret( # pylint: disable=no-self-use "Please ensure that keyId, teamId and privateKey are provided in the additionalConfig" ) - payload = { + payload: Dict[str, Any] = { "iss": config.additional_config.get("teamId"), "iat": time(), "exp": time() + (86400 * 180), # 6 months diff --git a/supertokens_python/recipe/thirdparty/providers/bitbucket.py b/supertokens_python/recipe/thirdparty/providers/bitbucket.py index 4742c0958..8fe2cfc95 100644 --- a/supertokens_python/recipe/thirdparty/providers/bitbucket.py +++ b/supertokens_python/recipe/thirdparty/providers/bitbucket.py @@ -19,12 +19,10 @@ from ..provider import Provider, ProviderInput -# TODO Implement when it's done in Node PR class BitbucketImpl(GenericProvider): pass -# TODO Finish when it's done in Node PR def Bitbucket(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin if input.config.name is None: input.config.name = "Bitbucket" @@ -40,7 +38,4 @@ def Bitbucket(input: ProviderInput) -> Provider: # pylint: disable=redefined-bu if input.config.user_info_endpoint is None: input.config.user_info_endpoint = "https://api.bitbucket.org/2.0/user" - # TODO overrides and working of this - # once done in Node PR - return NewProvider(input, BitbucketImpl) diff --git a/supertokens_python/recipe/thirdparty/providers/gitlab.py b/supertokens_python/recipe/thirdparty/providers/gitlab.py index 02723b525..373cf5c89 100644 --- a/supertokens_python/recipe/thirdparty/providers/gitlab.py +++ b/supertokens_python/recipe/thirdparty/providers/gitlab.py @@ -19,11 +19,9 @@ from ..provider import Provider, ProviderInput -# TODO Implement when it's done in Node PR class GitlabImpl(GenericProvider): pass -# TODO Implement when it's done in Node PR def Gitlab(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin return NewProvider(input, GitlabImpl) diff --git a/supertokens_python/recipe/thirdparty/recipe.py b/supertokens_python/recipe/thirdparty/recipe.py index 6573a8695..65bf187e9 100644 --- a/supertokens_python/recipe/thirdparty/recipe.py +++ b/supertokens_python/recipe/thirdparty/recipe.py @@ -14,7 +14,7 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier @@ -121,6 +121,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/recipe/thirdpartyemailpassword/recipe.py b/supertokens_python/recipe/thirdpartyemailpassword/recipe.py index a02df228e..db07329a9 100644 --- a/supertokens_python/recipe/thirdpartyemailpassword/recipe.py +++ b/supertokens_python/recipe/thirdpartyemailpassword/recipe.py @@ -14,7 +14,7 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, List, Union, Optional from supertokens_python.framework.response import BaseResponse from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig @@ -210,6 +210,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, @@ -220,7 +221,7 @@ async def handle_api_request( is not None ): return await self.email_password_recipe.handle_api_request( - request_id, request, path, method, response + request_id, tenant_id, request, path, method, response ) if ( self.third_party_recipe is not None @@ -230,7 +231,7 @@ async def handle_api_request( is not None ): return await self.third_party_recipe.handle_api_request( - request_id, request, path, method, response + request_id, tenant_id, request, path, method, response ) return None diff --git a/supertokens_python/recipe/thirdpartyemailpassword/recipeimplementation/implementation.py b/supertokens_python/recipe/thirdpartyemailpassword/recipeimplementation/implementation.py index f6574d187..9db848056 100644 --- a/supertokens_python/recipe/thirdpartyemailpassword/recipeimplementation/implementation.py +++ b/supertokens_python/recipe/thirdpartyemailpassword/recipeimplementation/implementation.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Callable +from typing import TYPE_CHECKING, Any, Dict, List, Union, Callable, Optional import supertokens_python.recipe.emailpassword.interfaces as EPInterfaces from supertokens_python.recipe.thirdparty.interfaces import GetProviderOkResult diff --git a/supertokens_python/recipe/thirdpartypasswordless/recipe.py b/supertokens_python/recipe/thirdpartypasswordless/recipe.py index b8188b3ad..14519f073 100644 --- a/supertokens_python/recipe/thirdpartypasswordless/recipe.py +++ b/supertokens_python/recipe/thirdpartypasswordless/recipe.py @@ -14,7 +14,7 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional from supertokens_python.framework.response import BaseResponse from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig @@ -224,6 +224,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, @@ -234,7 +235,7 @@ async def handle_api_request( is not None ): return await self.passwordless_recipe.handle_api_request( - request_id, request, path, method, response + request_id, tenant_id, request, path, method, response ) if ( self.third_party_recipe is not None @@ -244,7 +245,7 @@ async def handle_api_request( is not None ): return await self.third_party_recipe.handle_api_request( - request_id, request, path, method, response + request_id, tenant_id, request, path, method, response ) return None diff --git a/supertokens_python/recipe/usermetadata/recipe.py b/supertokens_python/recipe/usermetadata/recipe.py index 2e666aeec..f36f02a80 100644 --- a/supertokens_python/recipe/usermetadata/recipe.py +++ b/supertokens_python/recipe/usermetadata/recipe.py @@ -15,7 +15,7 @@ from __future__ import annotations from os import environ -from typing import List, Union +from typing import List, Union, Optional from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse @@ -66,6 +66,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index 43e0b69a3..8971c4dab 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -84,6 +84,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/recipe_module.py b/supertokens_python/recipe_module.py index eb0901173..39ea3e53e 100644 --- a/supertokens_python/recipe_module.py +++ b/supertokens_python/recipe_module.py @@ -15,7 +15,8 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, List, Union +import re +from typing import TYPE_CHECKING, List, Union, Optional from typing_extensions import Literal @@ -24,9 +25,16 @@ if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest from .supertokens import AppInfo - from .normalised_url_path import NormalisedURLPath + from .exceptions import SuperTokensError +from .normalised_url_path import NormalisedURLPath + + +class ApiIdWithTenantId: + def __init__(self, api_id: str, tenant_id: Optional[str]): + self.api_id = api_id + self.tenant_id = tenant_id class RecipeModule(abc.ABC): @@ -42,17 +50,42 @@ def get_app_info(self): def return_api_id_if_can_handle_request( self, path: NormalisedURLPath, method: str - ) -> Union[str, None]: + ) -> Union[ApiIdWithTenantId, None]: + from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID + apis_handled = self.get_apis_handled() + + base_path_str = self.app_info.api_base_path.get_as_string_dangerous() + path_str = path.get_as_string_dangerous() + regex = rf"^{base_path_str}(?:/([a-zA-Z0-9-]+))?(/.*)$" + + match = re.match(regex, path_str) + match_group_1 = match.group(1) if match is not None else None + match_group_2 = match.group(2) if match is not None else None + + tenant_id: str = DEFAULT_TENANT_ID + remaining_path: Optional[NormalisedURLPath] = None + + if ( + match is not None + and isinstance(match_group_1, str) + and isinstance(match_group_2, str) + ): + tenant_id = match_group_1 + remaining_path = NormalisedURLPath(match_group_2) + for current_api in apis_handled: - if ( - not current_api.disabled - and current_api.method == method - and self.app_info.api_base_path.append( + if not current_api.disabled and current_api.method == method: + if self.app_info.api_base_path.append( current_api.path_without_api_base_path - ).equals(path) - ): - return current_api.request_id + ).equals(path): + return ApiIdWithTenantId(current_api.request_id, DEFAULT_TENANT_ID) + + if remaining_path is not None and self.app_info.api_base_path.append( + current_api.path_without_api_base_path + ).equals(self.app_info.api_base_path.append(remaining_path)): + return ApiIdWithTenantId(current_api.request_id, tenant_id) + return None @abc.abstractmethod @@ -67,6 +100,7 @@ def get_apis_handled(self) -> List[APIHandled]: async def handle_api_request( self, request_id: str, + tenant_id: Optional[str], request: BaseRequest, path: NormalisedURLPath, method: str, diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index df8e24118..6e362a7a8 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -489,7 +489,7 @@ async def middleware( # pylint: disable=no-self-use # see # https://github.com/supertokens/supertokens-python/issues/54 request_rid = None - request_id = None + api_and_tenant_id = None matched_recipe = None if request_rid is not None: for recipe in Supertokens.get_instance().recipe_modules: @@ -501,7 +501,7 @@ async def middleware( # pylint: disable=no-self-use matched_recipe = recipe break if matched_recipe is not None: - request_id = matched_recipe.return_api_id_if_can_handle_request( + api_and_tenant_id = matched_recipe.return_api_id_if_can_handle_request( path, method ) else: @@ -510,8 +510,10 @@ async def middleware( # pylint: disable=no-self-use "middleware: Checking recipe ID for match: %s", recipe.get_recipe_id(), ) - request_id = recipe.return_api_id_if_can_handle_request(path, method) - if request_id is not None: + api_and_tenant_id = recipe.return_api_id_if_can_handle_request( + path, method + ) + if api_and_tenant_id is not None: matched_recipe = recipe break if matched_recipe is not None: @@ -520,18 +522,25 @@ async def middleware( # pylint: disable=no-self-use ) else: log_debug_message("middleware: Not handling because no recipe matched") - if matched_recipe is not None and request_id is None: + + if matched_recipe is not None and api_and_tenant_id is None: log_debug_message( "middleware: Not handling because recipe doesn't handle request path or method. Request path: %s, request method: %s", path.get_as_string_dangerous(), method, ) - if request_id is not None and matched_recipe is not None: + if api_and_tenant_id is not None and matched_recipe is not None: log_debug_message( - "middleware: Request being handled by recipe. ID is: %s", request_id + "middleware: Request being handled by recipe. ID is: %s", + api_and_tenant_id.api_id, ) api_resp = await matched_recipe.handle_api_request( - request_id, request, path, method, response + api_and_tenant_id.api_id, + api_and_tenant_id.tenant_id, + request, + path, + method, + response, ) if api_resp is None: log_debug_message("middleware: Not handled because API returned None") diff --git a/tests/multitenancy/__init__.py b/tests/multitenancy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/multitenancy/test_router.py b/tests/multitenancy/test_router.py new file mode 100644 index 000000000..a08c6c751 --- /dev/null +++ b/tests/multitenancy/test_router.py @@ -0,0 +1,107 @@ +# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from fastapi import FastAPI +from pytest import mark, fixture +from starlette.testclient import TestClient + +from supertokens_python import init +from supertokens_python.framework.fastapi import get_middleware +from supertokens_python.recipe import session, emailpassword, dashboard +from tests.utils import setup_function, teardown_function, get_st_init_args, start_st + +_ = setup_function +_ = teardown_function + +pytestmark = mark.asyncio + + +@fixture(scope="function") +async def client(): + app = FastAPI() + app.add_middleware(get_middleware()) + + return TestClient(app) + + +async def test_emailpassword_router(client: TestClient): + args = get_st_init_args( + [ + session.init(get_token_transfer_method=lambda *_: "cookie"), # type: ignore + emailpassword.init(), + ] + ) + init(**args) + start_st() + + res = client.post( + "/auth/public/signup", + headers={"Content-Type": "application/json"}, + json={ + "formFields": [ + {"id": "password", "value": "password1"}, + {"id": "email", "value": "test1@example.com"}, + ] + }, + ) + + assert res.status_code == 200 + assert res.json()["status"] == "OK" + + res = client.post( + "/auth/signup", + headers={"Content-Type": "application/json"}, + json={ + "formFields": [ + {"id": "password", "value": "password2"}, + {"id": "email", "value": "test2@example.com"}, + ] + }, + ) + + assert res.status_code == 200 + assert res.json()["status"] == "OK" + + +async def test_dashboard_apis_router(client: TestClient): + args = get_st_init_args( + [ + session.init(get_token_transfer_method=lambda *_: "cookie"), # type: ignore + emailpassword.init(), + dashboard.init(), + ] + ) + init(**args) + start_st() + + res = client.post( + "/auth/public/dashboard/api/signin", + headers={"Content-Type": "application/json"}, + json={ + "email": "test1@example.com", + "password": "password1", + }, + ) + + assert res.status_code == 200 + + res = client.post( + "/auth/dashboard/api/signin", + headers={"Content-Type": "application/json"}, + json={ + "email": "test1@example.com", + "password": "password1", + }, + ) + + assert res.status_code == 200 diff --git a/tests/thirdpartypasswordless/test_emaildelivery.py b/tests/thirdpartypasswordless/test_emaildelivery.py index d48938499..5ea27cbf0 100644 --- a/tests/thirdpartypasswordless/test_emaildelivery.py +++ b/tests/thirdpartypasswordless/test_emaildelivery.py @@ -39,9 +39,6 @@ session, thirdpartypasswordless, ) -from supertokens_python.recipe.emailverification.asyncio import ( - create_email_verification_token, -) from supertokens_python.recipe.emailverification.emaildelivery.services.smtp import ( SMTPService as EVSMTPService, ) @@ -61,6 +58,9 @@ RecipeImplementation as SessionRecipeImplementation, ) from supertokens_python.recipe.session.session_functions import create_new_session +from supertokens_python.recipe.emailverification.asyncio import ( + create_email_verification_token, +) from supertokens_python.recipe.thirdpartypasswordless.asyncio import ( passwordlessSigninup, thirdparty_manually_create_or_update_user,