diff --git a/supertokens_python/always_initialised_recipes.py b/supertokens_python/always_initialised_recipes.py deleted file mode 100644 index 32fdee607..000000000 --- a/supertokens_python/always_initialised_recipes.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. -# -# This software is licensed under the Apache License, Version 2.0 (the -# "License") as published by the Apache Software Foundation. -# -# You may not use this file except in compliance with the License. You may -# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -from __future__ import annotations -from typing import Callable, Optional, TYPE_CHECKING - -if TYPE_CHECKING: - from supertokens_python.recipe_module import RecipeModule - from supertokens_python import AppInfo - -from supertokens_python.recipe.multitenancy import init - -DEFAULT_MULTITENANCY_RECIPE: Optional[Callable[[AppInfo], RecipeModule]] = init() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py index 9636f5d84..7d9a20ec4 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py @@ -17,7 +17,7 @@ async def handle_sessions_get( _api_interface: APIInterface, - _tenant_id: str, + tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], ) -> UserSessionsGetAPIResponse: @@ -26,9 +26,10 @@ async def handle_sessions_get( if user_id is None: raise_bad_input_exception("Missing required parameter 'userId'") - # TODO: Pass tenant id here + # Passing tenant id as None sets fetch_across_all_tenants to True + # which is what we want here. session_handles = await get_all_session_handles_for_user( - user_id, "pass-tenant-id", user_context + user_id, None, user_context ) sessions: List[Optional[SessionInfo]] = [None for _ in session_handles] diff --git a/supertokens_python/recipe/emailpassword/asyncio/__init__.py b/supertokens_python/recipe/emailpassword/asyncio/__init__.py index df0003405..68677f34c 100644 --- a/supertokens_python/recipe/emailpassword/asyncio/__init__.py +++ b/supertokens_python/recipe/emailpassword/asyncio/__init__.py @@ -22,8 +22,8 @@ CreateResetPasswordWrongUserIdError, CreateResetPasswordLinkUknownUserIdError, CreateResetPasswordLinkOkResult, - CreateResetPasswordEmailOkResult, - CreateResetPasswordEmailUnknownUserIdError, + SendResetPasswordEmailOkResult, + SendResetPasswordEmailUnknownUserIdError, ) from supertokens_python.recipe.emailpassword.utils import get_password_reset_link from supertokens_python.recipe.emailpassword.types import ( @@ -157,7 +157,7 @@ async def send_reset_password_email( ): link = await create_reset_password_link(tenant_id, user_id, user_context) if isinstance(link, CreateResetPasswordLinkUknownUserIdError): - return CreateResetPasswordEmailUnknownUserIdError() + return SendResetPasswordEmailUnknownUserIdError() user = await get_user_by_id(user_id, user_context) assert user is not None @@ -171,4 +171,4 @@ async def send_reset_password_email( user_context, ) - return CreateResetPasswordEmailOkResult() + return SendResetPasswordEmailOkResult() diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 4c0fd65f1..cc6201c37 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -66,11 +66,11 @@ class CreateResetPasswordLinkUknownUserIdError: pass -class CreateResetPasswordEmailOkResult: +class SendResetPasswordEmailOkResult: pass -class CreateResetPasswordEmailUnknownUserIdError: +class SendResetPasswordEmailUnknownUserIdError: pass diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 638e63927..fe6ba0f2e 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -230,7 +230,7 @@ def callback(): PostSTInitCallbacks.add_post_init_callback(callback) return EmailVerificationRecipe.__instance - return raise_general_exception( + raise_general_exception( "Emailverification recipe has already been initialised. Please check your code for bugs." ) diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index 8443e8363..3e8ae4871 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -152,8 +152,8 @@ async def validate_claims_for_session_handle( ) global_claim_validators = await resolve( recipe_impl.get_global_claim_validators( - session_info.user_id, session_info.tenant_id, + session_info.user_id, claim_validators_added_by_other_recipes, user_context, ) @@ -188,8 +188,8 @@ async def validate_claims_for_session_handle( async def validate_claims_in_jwt_payload( - user_id: str, tenant_id: str, + user_id: str, jwt_payload: JSONObject, override_global_claim_validators: Optional[ Callable[ @@ -213,8 +213,8 @@ async def validate_claims_in_jwt_payload( ) global_claim_validators = await resolve( recipe_impl.get_global_claim_validators( - user_id, tenant_id, + user_id, claim_validators_added_by_other_recipes, user_context, ) @@ -444,7 +444,7 @@ async def revoke_all_sessions_for_user( async def get_all_session_handles_for_user( user_id: str, - tenant_id: Optional[str], + tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: if user_context is None: diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 5b2d156da..cf4dc11b9 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -154,8 +154,8 @@ async def create_new_session( @abstractmethod def get_global_claim_validators( self, - user_id: str, tenant_id: str, + user_id: str, claim_validators_added_by_other_recipes: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> MaybeAwaitable[List[SessionClaimValidator]]: @@ -220,7 +220,7 @@ async def revoke_all_sessions_for_user( self, user_id: str, tenant_id: str, - revoke_across_all_tenants: Optional[bool], + revoke_across_all_tenants: bool, user_context: Dict[str, Any], ) -> List[str]: pass @@ -230,7 +230,7 @@ async def get_all_session_handles_for_user( self, user_id: str, tenant_id: str, - fetch_across_all_tenants: Optional[bool], + fetch_across_all_tenants: bool, user_context: Dict[str, Any], ) -> List[str]: pass diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index 1feab414a..f487dcb56 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -330,7 +330,7 @@ async def revoke_all_sessions_for_user( self, user_id: str, tenant_id: Optional[str], - revoke_across_all_tenants: Optional[bool], + revoke_across_all_tenants: bool, user_context: Dict[str, Any], ) -> List[str]: return await session_functions.revoke_all_sessions_for_user( @@ -341,7 +341,7 @@ async def get_all_session_handles_for_user( self, user_id: str, tenant_id: Optional[str], - fetch_across_all_tenants: Optional[bool], + fetch_across_all_tenants: bool, user_context: Dict[str, Any], ) -> List[str]: return await session_functions.get_all_session_handles_for_user( @@ -437,8 +437,8 @@ async def get_claim_value( def get_global_claim_validators( self, - user_id: str, tenant_id: str, + user_id: str, claim_validators_added_by_other_recipes: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> MaybeAwaitable[List[SessionClaimValidator]]: diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index 2bad06285..e904f14c4 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -223,8 +223,9 @@ async def fetch_and_set_claim( if user_context is None: user_context = {} - # TODO: Pass tenant id - update = await claim.build(self.get_user_id(), "pass-tenant-id", user_context) + update = await claim.build( + self.get_user_id(), self.get_tenant_id(), user_context + ) return await self.merge_into_access_token_payload(update, user_context) async def set_claim_value( diff --git a/supertokens_python/recipe/session/session_functions.py b/supertokens_python/recipe/session/session_functions.py index 10cb03833..0e391fc86 100644 --- a/supertokens_python/recipe/session/session_functions.py +++ b/supertokens_python/recipe/session/session_functions.py @@ -389,7 +389,7 @@ async def revoke_all_sessions_for_user( recipe_implementation: RecipeImplementation, user_id: str, tenant_id: Optional[str], - revoke_across_all_tenants: Optional[bool], + revoke_across_all_tenants: bool, ) -> List[str]: if tenant_id is None: tenant_id = DEFAULT_TENANT_ID @@ -405,7 +405,7 @@ async def get_all_session_handles_for_user( recipe_implementation: RecipeImplementation, user_id: str, tenant_id: Optional[str], - fetch_across_all_tenants: Optional[bool], + fetch_across_all_tenants: bool, ) -> List[str]: if tenant_id is None: tenant_id = DEFAULT_TENANT_ID diff --git a/supertokens_python/recipe/session/syncio/__init__.py b/supertokens_python/recipe/session/syncio/__init__.py index 6c3ecea0f..3a7557d6e 100644 --- a/supertokens_python/recipe/session/syncio/__init__.py +++ b/supertokens_python/recipe/session/syncio/__init__.py @@ -385,8 +385,8 @@ def validate_claims_for_session_handle( def validate_claims_in_jwt_payload( - user_id: str, tenant_id: str, + user_id: str, jwt_payload: JSONObject, override_global_claim_validators: Optional[ Callable[ @@ -402,8 +402,8 @@ def validate_claims_in_jwt_payload( return sync( async_validate_claims_in_jwt_payload( - user_id, tenant_id, + user_id, jwt_payload, override_global_claim_validators, user_context, diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index eab3c39cc..fcfe6a330 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -480,8 +480,8 @@ async def get_required_claim_validators( ) global_claim_validators = await resolve( SessionRecipe.get_instance().recipe_implementation.get_global_claim_validators( - session.get_user_id(), session.get_tenant_id(), + session.get_user_id(), claim_validators_added_by_other_recipes, user_context, ) diff --git a/supertokens_python/recipe/thirdparty/api/implementation.py b/supertokens_python/recipe/thirdparty/api/implementation.py index 4caec19b0..7135067ce 100644 --- a/supertokens_python/recipe/thirdparty/api/implementation.py +++ b/supertokens_python/recipe/thirdparty/api/implementation.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations +from supertokens_python.utils import utf_base64decode from base64 import b64decode import json diff --git a/supertokens_python/recipe/thirdpartyemailpassword/asyncio/__init__.py b/supertokens_python/recipe/thirdpartyemailpassword/asyncio/__init__.py index 438a2fd16..b6aeb4a78 100644 --- a/supertokens_python/recipe/thirdpartyemailpassword/asyncio/__init__.py +++ b/supertokens_python/recipe/thirdpartyemailpassword/asyncio/__init__.py @@ -25,9 +25,8 @@ CreateResetPasswordWrongUserIdError, CreateResetPasswordLinkUknownUserIdError, CreateResetPasswordLinkOkResult, - CreateResetPasswordEmailUnknownUserIdError, - CreateResetPasswordEmailOkResult, - RawUserInfoFromProvider, + SendResetPasswordEmailUnknownUserIdError, + SendResetPasswordEmailEmailOkResult ) from supertokens_python.recipe.emailpassword.utils import get_password_reset_link @@ -62,30 +61,6 @@ async def get_user_by_third_party_info( user_context, ) - -async def thirdparty_sign_in_up( - tenant_id: str, - third_party_id: str, - third_party_user_id: str, - email: str, - oauth_tokens: Dict[str, Any], - raw_user_info_from_provider: RawUserInfoFromProvider, - user_context: Optional[Dict[str, Any]] = None, -): - if user_context is None: - user_context = {} - - return await ThirdPartyEmailPasswordRecipe.get_instance().recipe_implementation.thirdparty_sign_in_up( - third_party_id, - third_party_user_id, - email, - oauth_tokens, - raw_user_info_from_provider, - tenant_id or DEFAULT_TENANT_ID, - user_context, - ) - - async def thirdparty_manually_create_or_update_user( tenant_id: str, third_party_id: str, @@ -239,7 +214,7 @@ async def send_reset_password_email( ): link = await create_reset_password_link(tenant_id, user_id, user_context) if isinstance(link, CreateResetPasswordLinkUknownUserIdError): - return CreateResetPasswordEmailUnknownUserIdError() + return SendResetPasswordEmailUnknownUserIdError() user = await get_user_by_id(user_id, user_context) assert user is not None @@ -253,4 +228,4 @@ async def send_reset_password_email( user_context, ) - return CreateResetPasswordEmailOkResult() + return SendResetPasswordEmailEmailOkResult() diff --git a/supertokens_python/recipe/thirdpartyemailpassword/interfaces.py b/supertokens_python/recipe/thirdpartyemailpassword/interfaces.py index 2aa4bc2ac..9fe2fc8bb 100644 --- a/supertokens_python/recipe/thirdpartyemailpassword/interfaces.py +++ b/supertokens_python/recipe/thirdpartyemailpassword/interfaces.py @@ -21,9 +21,9 @@ CreateResetPasswordLinkUknownUserIdError = ( EPInterfaces.CreateResetPasswordLinkUknownUserIdError ) -CreateResetPasswordEmailOkResult = EPInterfaces.CreateResetPasswordEmailOkResult -CreateResetPasswordEmailUnknownUserIdError = ( - EPInterfaces.CreateResetPasswordEmailUnknownUserIdError +SendResetPasswordEmailEmailOkResult = EPInterfaces.SendResetPasswordEmailOkResult +SendResetPasswordEmailUnknownUserIdError = ( + EPInterfaces.SendResetPasswordEmailUnknownUserIdError ) EmailPasswordEmailExistsGetOkResult = EPInterfaces.EmailExistsGetOkResult GeneratePasswordResetTokenPostOkResult = ( diff --git a/supertokens_python/recipe/thirdpartyemailpassword/syncio/__init__.py b/supertokens_python/recipe/thirdpartyemailpassword/syncio/__init__.py index 1f061e771..5844263ac 100644 --- a/supertokens_python/recipe/thirdpartyemailpassword/syncio/__init__.py +++ b/supertokens_python/recipe/thirdpartyemailpassword/syncio/__init__.py @@ -19,7 +19,6 @@ from ..interfaces import ( EmailPasswordSignInOkResult, EmailPasswordSignInWrongCredentialsError, - RawUserInfoFromProvider, ) from ..types import EmailTemplateVars, User @@ -49,32 +48,6 @@ def get_user_by_third_party_info( ) -def thirdparty_sign_in_up( - tenant_id: str, - third_party_id: str, - third_party_user_id: str, - email: str, - oauth_tokens: Dict[str, Any], - raw_user_info_from_provider: RawUserInfoFromProvider, - user_context: Optional[Dict[str, Any]] = None, -): - from supertokens_python.recipe.thirdpartyemailpassword.asyncio import ( - thirdparty_sign_in_up, - ) - - return sync( - thirdparty_sign_in_up( - tenant_id, - third_party_id, - third_party_user_id, - email, - oauth_tokens, - raw_user_info_from_provider, - user_context, - ) - ) - - def thirdparty_manually_create_or_update_user( tenant_id: str, third_party_id: str, diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index 32502ee17..f4bae1d7f 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -152,7 +152,7 @@ async def fetch_value( recipe = UserRolesRecipe.get_instance() user_roles = await recipe.recipe_implementation.get_roles_for_user( - tenant_id, user_id, user_context + user_id, tenant_id, user_context ) user_permissions: Set[str] = set() @@ -186,7 +186,7 @@ async def fetch_value( ) -> List[str]: recipe = UserRolesRecipe.get_instance() res = await recipe.recipe_implementation.get_roles_for_user( - tenant_id, user_id, user_context + user_id, tenant_id, user_context ) return res.roles diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 1a8c3afcc..244afcd8f 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -46,7 +46,6 @@ send_non_200_response_with_message, ) - if TYPE_CHECKING: from .recipe_module import RecipeModule from supertokens_python.framework.request import BaseRequest @@ -152,8 +151,6 @@ def __init__( mode: Union[Literal["asgi", "wsgi"], None], telemetry: Union[bool, None], ): - from .always_initialised_recipes import DEFAULT_MULTITENANCY_RECIPE - if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") @@ -189,21 +186,21 @@ def __init__( "Please provide at least one recipe to the supertokens.init function call" ) + from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + multitenancy_found = False def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: nonlocal multitenancy_found recipe_module = recipe(self.app_info) - if recipe_module.get_recipe_id() == "multitenancy": + if recipe_module.get_recipe_id() == MultitenancyRecipe.recipe_id: multitenancy_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) - if callable(DEFAULT_MULTITENANCY_RECIPE) and not multitenancy_found: - recipe = DEFAULT_MULTITENANCY_RECIPE( # pylint: disable=not-callable - self.app_info - ) + if not multitenancy_found: + recipe = MultitenancyRecipe.init()(self.app_info) self.recipe_modules.append(recipe) self.telemetry = ( diff --git a/tests/Django/test_django.py b/tests/Django/test_django.py index cb71fd8ea..c22a2b1f2 100644 --- a/tests/Django/test_django.py +++ b/tests/Django/test_django.py @@ -16,6 +16,7 @@ from urllib.parse import urlencode from datetime import datetime from inspect import isawaitable +from base64 import b64encode from typing import Any, Dict, Union from django.http import HttpRequest, HttpResponse, JsonResponse @@ -455,10 +456,10 @@ async def test_thirdparty_parsing_works(self): start_st() - data = { - "state": "afc596274293e1587315c", - "code": "c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA", - } + state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode() + code = "testing" + + data = { "state": state, "code": code} request = self.factory.post( "/auth/callback/apple", @@ -470,11 +471,9 @@ async def test_thirdparty_parsing_works(self): raise Exception("Should never come here") response = await temp - self.assertEqual(response.status_code, 200) - self.assertEqual( - response.content, - b'', - ) + self.assertEqual(response.status_code, 303) + self.assertEqual(response.content, b'') + self.assertEqual(response.headers['location'], f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}") @pytest.mark.asyncio async def test_search_with_multiple_emails(self): diff --git a/tests/Flask/test_flask.py b/tests/Flask/test_flask.py index 75c08a4f3..03964c96a 100644 --- a/tests/Flask/test_flask.py +++ b/tests/Flask/test_flask.py @@ -14,6 +14,7 @@ import json from typing import Any, Dict, Union +from base64 import b64encode import pytest from _pytest.fixtures import fixture @@ -477,17 +478,15 @@ def test_thirdparty_parsing_works(driver_config_app: Any): start_st() test_client = driver_config_app.test_client() - data = { - "state": "afc596274293e1587315c", - "code": "c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA", - } - response = test_client.post("/auth/callback/apple", data=data) + state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode() + code = "testing" - assert response.status_code == 200 - assert ( - response.data - == b'' - ) + data = { "state": state, "code": code} + res = test_client.post("/auth/callback/apple", data=data) + + assert res.status_code == 303 + assert res.data == b'' + assert res.headers["location"] == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}" from flask.wrappers import Response diff --git a/tests/sessions/claims/test_create_new_session.py b/tests/sessions/claims/test_create_new_session.py index 5d32beb6c..89ab9fbe7 100644 --- a/tests/sessions/claims/test_create_new_session.py +++ b/tests/sessions/claims/test_create_new_session.py @@ -69,6 +69,6 @@ async def test_should_merge_claims_and_passed_access_token_payload_obj(timestamp s = await create_new_session("public", dummy_req, "someId") payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"v": True, "t": timestamp} assert payload["user-custom-claim"] == "foo" diff --git a/tests/sessions/claims/test_get_claim_value.py b/tests/sessions/claims/test_get_claim_value.py index 0fa5bafdb..2a2a8cab0 100644 --- a/tests/sessions/claims/test_get_claim_value.py +++ b/tests/sessions/claims/test_get_claim_value.py @@ -56,5 +56,5 @@ async def test_should_work_for_non_existing_handle(): init(**new_st_init) # type: ignore start_st() - res = await get_claim_value("non_existing_handle", TrueClaim) + res = await get_claim_value("non-existing-handle", TrueClaim) assert isinstance(res, SessionDoesNotExistError) diff --git a/tests/sessions/claims/test_primitive_array_claim.py b/tests/sessions/claims/test_primitive_array_claim.py index 63c0859f2..ac1469b56 100644 --- a/tests/sessions/claims/test_primitive_array_claim.py +++ b/tests/sessions/claims/test_primitive_array_claim.py @@ -81,7 +81,7 @@ async def test_primitive_claim_fetch_value_params_correct(): user_id, ctx = "user_id", {} await claim.build(user_id, DEFAULT_TENANT_ID, ctx) assert sync_fetch_value.call_count == 1 - assert (user_id, ctx) == sync_fetch_value.call_args_list[0][ + assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][ 0 ] # extra [0] refers to call params diff --git a/tests/sessions/claims/test_primitive_claim.py b/tests/sessions/claims/test_primitive_claim.py index 30f672dc7..2ee8d98ba 100644 --- a/tests/sessions/claims/test_primitive_claim.py +++ b/tests/sessions/claims/test_primitive_claim.py @@ -48,7 +48,7 @@ async def test_primitive_claim_fetch_value_params_correct(): user_id, ctx = "user_id", {} await claim.build(user_id, DEFAULT_TENANT_ID, ctx) assert sync_fetch_value.call_count == 1 - assert (user_id, ctx) == sync_fetch_value.call_args_list[0][ + assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][ 0 ] # extra [0] refers to call params diff --git a/tests/sessions/claims/test_set_claim_value.py b/tests/sessions/claims/test_set_claim_value.py index 9e958b9ad..6fd70c643 100644 --- a/tests/sessions/claims/test_set_claim_value.py +++ b/tests/sessions/claims/test_set_claim_value.py @@ -60,14 +60,14 @@ async def test_should_overwrite_claim_value(timestamp: int): s = await create_new_session("public", dummy_req, "someId") payload = s.get_access_token_payload() - assert len(payload) == 9 + assert len(payload) == 10 assert payload["st-true"] == {"t": timestamp, "v": True} await s.set_claim_value(TrueClaim, False) # Payload should be updated now: payload = s.get_access_token_payload() - assert len(payload) == 9 + assert len(payload) == 10 assert payload["st-true"] == {"t": timestamp, "v": False} @@ -79,7 +79,7 @@ async def test_should_overwrite_claim_value_using_session_handle(timestamp: int) s = await create_new_session("public", dummy_req, "someId") payload = s.get_access_token_payload() - assert len(payload) == 9 + assert len(payload) == 10 assert payload["st-true"] == {"t": timestamp, "v": True} await set_claim_value(s.get_handle(), TrueClaim, False) diff --git a/tests/sessions/claims/test_validate_claims_for_session_handle.py b/tests/sessions/claims/test_validate_claims_for_session_handle.py index 35edaa6b8..86370efcc 100644 --- a/tests/sessions/claims/test_validate_claims_for_session_handle.py +++ b/tests/sessions/claims/test_validate_claims_for_session_handle.py @@ -59,6 +59,6 @@ async def test_should_work_for_not_existing_handle(): start_st() res = await validate_claims_for_session_handle( - "non_existing_handle", lambda _, __, ___: [] + "non-existing-handle", lambda _, __, ___: [] ) assert isinstance(res, SessionDoesNotExistError) diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index 881a4b020..fd6ae7d65 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -8,8 +8,8 @@ from supertokens_python.recipe.session.interfaces import RecipeInterface from tests.utils import st_init_common_args -TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __: True) # type: ignore -NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __: None) # type: ignore +TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __, ___: True) # type: ignore +NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __, ___: None) # type: ignore def session_functions_override_with_claim( diff --git a/tests/test_session.py b/tests/test_session.py index fd803e312..35a691248 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -202,7 +202,7 @@ async def test_creating_many_sessions_for_one_user_and_looping(): assert len(session_handles) == 7 - for i, handle in enumerate(session_handles): + for handle in session_handles: info = await get_session_information(handle) assert info is not None assert info.user_id == "someUser" @@ -224,19 +224,22 @@ async def test_creating_many_sessions_for_one_user_and_looping(): assert info.custom_claims_in_access_token_payload == {"someKey2": "someValue"} assert info.session_data_in_database == {"foo": "bar"} + regenerated_session_handles: List[str] = [] # Regenerate access token with new access_token_payload - for i, token in enumerate(access_tokens): + for token in access_tokens: result = await regenerate_access_token(token, {"bar": "baz"}) assert result is not None - assert ( - result.session.handle == session_handles[i] - ) # Session handle should remain the same + regenerated_session_handles.append(result.session.handle) # Confirm that update worked: info = await get_session_information(result.session.handle) assert info is not None assert info.custom_claims_in_access_token_payload == {"bar": "baz"} + # Session handle should remain the same session handle should remain the same + # but order isn't guaranteed so we should sort them + assert sorted(regenerated_session_handles) == sorted(session_handles) + # Try updating invalid handles: is_updated = await merge_into_access_token_payload("invalidHandle", {"foo": "bar"}) assert is_updated is False diff --git a/tests/thirdparty/test_thirdparty.py b/tests/thirdparty/test_thirdparty.py index 212e012ce..7e8584e25 100644 --- a/tests/thirdparty/test_thirdparty.py +++ b/tests/thirdparty/test_thirdparty.py @@ -1,4 +1,5 @@ import respx +import json from pytest import fixture, mark from fastapi import FastAPI @@ -7,6 +8,7 @@ from supertokens_python.recipe import session, thirdparty from supertokens_python import init +from base64 import b64encode from tests.utils import ( setup_function, @@ -65,15 +67,12 @@ async def test_thirdpary_parsing_works(fastapi_client: TestClient): init(**st_init_args) # type: ignore start_st() - data = { - "state": "afc596274293e1587315c", - "code": "c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA", - } + state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode() + code = "testing" + data = { "state": state, "code": code} res = fastapi_client.post("/auth/callback/apple", data=data) - assert res.status_code == 200 - assert ( - res.content - == b'' - ) + assert res.status_code == 303 + assert res.content == b'' + assert res.headers["location"] == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"