From 09b0320ec0b057c7d2259ee8b24d71e8687edb4a Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 13 Jul 2023 19:05:26 +0530 Subject: [PATCH] refactor: Add tenant_id variable in session functions --- .../recipe/emailverification/recipe.py | 4 +- .../multitenancy/allowed_domains_claim.py | 65 ------------------- .../recipe/multitenancy/recipe.py | 48 +++++++------- .../claim_base_classes/boolean_claim.py | 2 +- .../recipe/session/recipe_implementation.py | 2 +- .../recipe/session/session_class.py | 2 +- .../session/session_request_functions.py | 2 +- 7 files changed, 32 insertions(+), 93 deletions(-) delete mode 100644 supertokens_python/recipe/multitenancy/allowed_domains_claim.py diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 998bbe4b0..89631ad3b 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -312,7 +312,9 @@ class EmailVerificationClaimClass(BooleanClaim): def __init__(self): default_max_age_in_sec = 300 - async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> bool: + async def fetch_value( + user_id: str, _tenant_id: str, user_context: Dict[str, Any] + ) -> bool: recipe = EmailVerificationRecipe.get_instance() email_info = await recipe.get_email_for_user_id(user_id, user_context) diff --git a/supertokens_python/recipe/multitenancy/allowed_domains_claim.py b/supertokens_python/recipe/multitenancy/allowed_domains_claim.py deleted file mode 100644 index 2fda624e7..000000000 --- a/supertokens_python/recipe/multitenancy/allowed_domains_claim.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. -# -# This software is licensed under the Apache License, Version 2.0 (the -# "License") as published by the Apache Software Foundation. -# -# You may not use this file except in compliance with the License. You may -# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -from typing import Any, Callable, Dict, List, Optional, Union -from supertokens_python.recipe.session.claim_base_classes.primitive_array_claim import ( - PrimitiveArrayClaim, -) -from supertokens_python.recipe.session.interfaces import JSONObject -from supertokens_python.utils import get_timestamp_ms - -from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe - - -class AllowedDomainsClaimClass(PrimitiveArrayClaim[List[str]]): - def __init__(self): - default_max_age_in_sec = 60 * 60 * 24 * 7 - - async def fetch_value( - _: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Optional[List[str]]: - recipe = MultitenancyRecipe.get_instance() - - if recipe.get_allowed_domains_for_tenant_id is None: - # User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default - return None - - return await recipe.get_allowed_domains_for_tenant_id( - tenant_id, user_context - ) - - super().__init__("st-t-dmns", fetch_value, default_max_age_in_sec) - - def get_value_from_payload( - self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None - ) -> Optional[List[str]]: - _ = user_context - - res = payload.get(self.key, {}).get("v") - if res is None: - return [] - return res - - def get_last_refetch_time( - self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None - ) -> Optional[int]: - _ = user_context - - res = payload.get(self.key, {}).get("t") - if res is None: - return get_timestamp_ms() - - return res - - -AllowedDomainsClaim = AllowedDomainsClaimClass() diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index 4184a8a0e..9b41b6115 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -259,41 +259,43 @@ async def login_methods_get( class AllowedDomainsClaimClass(PrimitiveArrayClaim[List[str]]): def __init__(self): - async def fetch_value(_user_id: str, user_context: Dict[str, Any]) -> List[str]: + default_max_age_in_sec = 60 * 60 * 24 * 7 + + async def fetch_value( + _: str, tenant_id: str, user_context: Dict[str, Any] + ) -> Optional[List[str]]: recipe = MultitenancyRecipe.get_instance() - tenant_id = ( - None # TODO fetch value will be passed with tenant_id as well later - ) - if recipe.config.get_allowed_domains_for_tenant_id is None: - return ( - [] - ) # User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default + if recipe.get_allowed_domains_for_tenant_id is None: + # User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default + return None - domains_res = await recipe.config.get_allowed_domains_for_tenant_id( + return await recipe.get_allowed_domains_for_tenant_id( tenant_id, user_context ) - return domains_res - super().__init__( - key="st-tenant-domains", - fetch_value=fetch_value, - default_max_age_in_sec=3600, - ) + super().__init__("st-t-dmns", fetch_value, default_max_age_in_sec) def get_value_from_payload( - self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ) -> Union[List[str], None]: - if self.key not in payload: + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None + ) -> Optional[List[str]]: + _ = user_context + + res = payload.get(self.key, {}).get("v") + if res is None: return [] - return super().get_value_from_payload(payload, user_context) + return res def get_last_refetch_time( - self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ) -> Union[int, None]: - if self.key not in payload: + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None + ) -> Optional[int]: + _ = user_context + + res = payload.get(self.key, {}).get("t") + if res is None: return get_timestamp_ms() - return super().get_last_refetch_time(payload, user_context) + + return res AllowedDomainsClaim = AllowedDomainsClaimClass() diff --git a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py index 14c40aefc..aec6c6f71 100644 --- a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py @@ -31,7 +31,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, Dict[str, Any]], + [str, str, Dict[str, Any]], MaybeAwaitable[Optional[bool]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index c4c1327d7..a032df806 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -384,7 +384,7 @@ async def fetch_and_set_claim( return False access_token_payload_update = await claim.build( - session_info.user_id, user_context + session_info.user_id, tenant_id, user_context ) return await self.merge_into_access_token_payload( session_handle, access_token_payload_update, user_context diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index 2dceaa178..5209f3ae8 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -220,7 +220,7 @@ async def fetch_and_set_claim( if user_context is None: user_context = {} - update = await claim.build(self.get_user_id(), user_context) + update = await claim.build(self.get_user_id(), tenant_id, user_context) return await self.merge_into_access_token_payload(update, user_context) async def set_claim_value( diff --git a/supertokens_python/recipe/session/session_request_functions.py b/supertokens_python/recipe/session/session_request_functions.py index f80e36df1..828037c7d 100644 --- a/supertokens_python/recipe/session/session_request_functions.py +++ b/supertokens_python/recipe/session/session_request_functions.py @@ -238,7 +238,7 @@ async def create_new_session_in_request( final_access_token_payload = {**access_token_payload, "iss": issuer} for claim in claims_added_by_other_recipes: - update = await claim.build(user_id, user_context) + update = await claim.build(user_id, tenant_id, user_context) final_access_token_payload = {**final_access_token_payload, **update} log_debug_message("createNewSession: Access token payload built")