diff --git a/supertokens_python/recipe/thirdpartyemailpassword/recipe.py b/supertokens_python/recipe/thirdpartyemailpassword/recipe.py index 99c089386..32e92f3c7 100644 --- a/supertokens_python/recipe/thirdpartyemailpassword/recipe.py +++ b/supertokens_python/recipe/thirdpartyemailpassword/recipe.py @@ -155,7 +155,7 @@ def apis_override_email_password(_: EmailPasswordAPIInterface): ) if third_party_recipe is not None: - self.third_party_recipe: Union[ThirdPartyRecipe, None] = third_party_recipe + self.third_party_recipe = third_party_recipe else: def func_override_third_party( @@ -168,18 +168,16 @@ def apis_override_third_party( ) -> ThirdPartyAPIInterface: return get_third_party_interface_impl(self.api_implementation) - self.third_party_recipe: Union[ThirdPartyRecipe, None] = None + # No email delivery ingredient required for third party recipe + # but we pass an object for future proofing tp_ingredients = ThirdPartyIngredients() - if len(self.config.providers) != 0: - self.third_party_recipe = ThirdPartyRecipe( - recipe_id, - app_info, - SignInAndUpFeature(self.config.providers), - tp_ingredients, - TPOverrideConfig( - func_override_third_party, apis_override_third_party - ), - ) + self.third_party_recipe = ThirdPartyRecipe( + recipe_id, + app_info, + SignInAndUpFeature(self.config.providers), + tp_ingredients, + TPOverrideConfig(func_override_third_party, apis_override_third_party), + ) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, SuperTokensError) and ( @@ -188,17 +186,13 @@ def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: err ) or ( - self.third_party_recipe is not None - and self.third_party_recipe.is_error_from_this_recipe_based_on_instance( - err - ) + self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err) ) ) def get_apis_handled(self) -> List[APIHandled]: apis_handled = self.email_password_recipe.get_apis_handled() - if self.third_party_recipe is not None: - apis_handled = apis_handled + self.third_party_recipe.get_apis_handled() + apis_handled += self.third_party_recipe.get_apis_handled() return apis_handled async def handle_api_request( @@ -221,8 +215,7 @@ async def handle_api_request( request_id, tenant_id, request, path, method, response, user_context ) if ( - self.third_party_recipe is not None - and await self.third_party_recipe.return_api_id_if_can_handle_request( + await self.third_party_recipe.return_api_id_if_can_handle_request( path, method, user_context ) is not None @@ -237,17 +230,13 @@ async def handle_error( ) -> BaseResponse: if self.email_password_recipe.is_error_from_this_recipe_based_on_instance(err): return await self.email_password_recipe.handle_error(request, err, response) - if ( - self.third_party_recipe is not None - and self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err) - ): + if self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err): return await self.third_party_recipe.handle_error(request, err, response) raise err def get_all_cors_headers(self) -> List[str]: cors_headers = self.email_password_recipe.get_all_cors_headers() - if self.third_party_recipe is not None: - cors_headers = cors_headers + self.third_party_recipe.get_all_cors_headers() + cors_headers += self.third_party_recipe.get_all_cors_headers() return cors_headers @staticmethod diff --git a/supertokens_python/recipe/thirdpartypasswordless/recipe.py b/supertokens_python/recipe/thirdpartypasswordless/recipe.py index 3dc703c61..a6a50338d 100644 --- a/supertokens_python/recipe/thirdpartypasswordless/recipe.py +++ b/supertokens_python/recipe/thirdpartypasswordless/recipe.py @@ -176,7 +176,7 @@ def apis_override_passwordless( ) if third_party_recipe is not None: - self.third_party_recipe: Union[ThirdPartyRecipe, None] = third_party_recipe + self.third_party_recipe = third_party_recipe else: def func_override_third_party( @@ -189,36 +189,30 @@ def apis_override_third_party( ) -> ThirdPartyAPIInterface: return get_third_party_interface_impl(self.api_implementation) - self.third_party_recipe: Union[ThirdPartyRecipe, None] = None - - if len(self.config.providers) != 0: - tp_ingredients = ThirdPartyIngredients() - self.third_party_recipe = ThirdPartyRecipe( - recipe_id, - app_info, - SignInAndUpFeature(self.config.providers), - tp_ingredients, - TPOverrideConfig( - func_override_third_party, apis_override_third_party - ), - ) + # Thirdparty recipe doesn't need ingredients + # as of now. But we are passing ingredients object + # so that it's future-proof. + tp_ingredients = ThirdPartyIngredients() + self.third_party_recipe = ThirdPartyRecipe( + recipe_id, + app_info, + SignInAndUpFeature(self.config.providers), + tp_ingredients, + TPOverrideConfig(func_override_third_party, apis_override_third_party), + ) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, SuperTokensError) and ( isinstance(err, SupertokensThirdPartyPasswordlessError) or self.passwordless_recipe.is_error_from_this_recipe_based_on_instance(err) or ( - self.third_party_recipe is not None - and self.third_party_recipe.is_error_from_this_recipe_based_on_instance( - err - ) + self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err) ) ) def get_apis_handled(self) -> List[APIHandled]: apis_handled = self.passwordless_recipe.get_apis_handled() - if self.third_party_recipe is not None: - apis_handled = apis_handled + self.third_party_recipe.get_apis_handled() + apis_handled += self.third_party_recipe.get_apis_handled() return apis_handled async def handle_api_request( @@ -241,8 +235,7 @@ async def handle_api_request( request_id, tenant_id, request, path, method, response, user_context ) if ( - self.third_party_recipe is not None - and await self.third_party_recipe.return_api_id_if_can_handle_request( + await self.third_party_recipe.return_api_id_if_can_handle_request( path, method, user_context ) is not None @@ -257,17 +250,13 @@ async def handle_error( ) -> BaseResponse: if self.passwordless_recipe.is_error_from_this_recipe_based_on_instance(err): return await self.passwordless_recipe.handle_error(request, err, response) - if ( - self.third_party_recipe is not None - and self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err) - ): + if self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err): return await self.third_party_recipe.handle_error(request, err, response) raise err def get_all_cors_headers(self) -> List[str]: cors_headers = self.passwordless_recipe.get_all_cors_headers() - if self.third_party_recipe is not None: - cors_headers = cors_headers + self.third_party_recipe.get_all_cors_headers() + cors_headers += self.third_party_recipe.get_all_cors_headers() return cors_headers @staticmethod diff --git a/tests/thirdparty/test_authorisation_url_feature.py b/tests/thirdparty/test_authorisation_url_feature.py new file mode 100644 index 000000000..2ddb27aa1 --- /dev/null +++ b/tests/thirdparty/test_authorisation_url_feature.py @@ -0,0 +1,110 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import json + +from fastapi import FastAPI +from pytest import mark, fixture + +from supertokens_python.framework.fastapi import get_middleware +from fastapi.testclient import TestClient +from supertokens_python.recipe import session, thirdparty +from supertokens_python import init +from supertokens_python.recipe.multitenancy.asyncio import ( + create_or_update_third_party_config, +) +from supertokens_python.recipe.thirdparty.provider import ( + ProviderConfig, + ProviderClientConfig, +) + +from tests.utils import get_st_init_args +from tests.utils import ( + setup_function, + teardown_function, + start_st, +) + + +_ = setup_function +_ = teardown_function + +pytestmark = mark.asyncio + + +@fixture(scope="function") +async def app(): + app = FastAPI() + app.add_middleware(get_middleware()) + + return TestClient(app) + + +async def test_calling_authorisation_url_api_with_empty_init(app: TestClient): + args = get_st_init_args( + [ + session.init( + get_token_transfer_method=lambda _, __, ___: "cookie", + anti_csrf="VIA_TOKEN", + ), + thirdparty.init(), + ] + ) + init(**args) # type: ignore + start_st() + + res = app.get( + "/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect" + ) + assert res.status_code == 400 + assert res.text == "the provider google could not be found in the configuration" + + +async def test_calling_authorisation_url_api_with_empty_init_with_dynamic_thirdparty_provider( + app: TestClient, +): + args = get_st_init_args( + [ + session.init( + get_token_transfer_method=lambda _, __, ___: "cookie", + anti_csrf="VIA_TOKEN", + ), + thirdparty.init(), + ] + ) + init(**args) # type: ignore + start_st() + + await create_or_update_third_party_config( + "public", + ProviderConfig( + third_party_id="google", + name="Google", + clients=[ + ProviderClientConfig( + client_id="google-client-id", + client_secret="google-client-secret", + ) + ], + ), + ) + + res = app.get( + "/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect" + ) + body = json.loads(res.text) + assert body["status"] == "OK" + assert ( + body["urlWithQueryParams"] + == "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&redirect_uri=redirect&response_type=code&scope=openid+email&included_grant_scopes=true&access_type=offline" + ) diff --git a/tests/thirdpartyemailpassword/test_authorisation_url_feature.py b/tests/thirdpartyemailpassword/test_authorisation_url_feature.py new file mode 100644 index 000000000..950a3a456 --- /dev/null +++ b/tests/thirdpartyemailpassword/test_authorisation_url_feature.py @@ -0,0 +1,110 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import json + +from fastapi import FastAPI +from pytest import mark, fixture + +from supertokens_python.framework.fastapi import get_middleware +from fastapi.testclient import TestClient +from supertokens_python.recipe import session, thirdpartyemailpassword +from supertokens_python import init +from supertokens_python.recipe.multitenancy.asyncio import ( + create_or_update_third_party_config, +) +from supertokens_python.recipe.thirdparty.provider import ( + ProviderConfig, + ProviderClientConfig, +) + +from tests.utils import get_st_init_args +from tests.utils import ( + setup_function, + teardown_function, + start_st, +) + + +_ = setup_function +_ = teardown_function + +pytestmark = mark.asyncio + + +@fixture(scope="function") +async def app(): + app = FastAPI() + app.add_middleware(get_middleware()) + + return TestClient(app) + + +async def test_calling_authorisation_url_api_with_empty_init(app: TestClient): + args = get_st_init_args( + [ + session.init( + get_token_transfer_method=lambda _, __, ___: "cookie", + anti_csrf="VIA_TOKEN", + ), + thirdpartyemailpassword.init(), + ] + ) + init(**args) # type: ignore + start_st() + + res = app.get( + "/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect" + ) + assert res.status_code == 400 + assert res.text == "the provider google could not be found in the configuration" + + +async def test_calling_authorisation_url_api_with_empty_init_with_dynamic_thirdparty_provider( + app: TestClient, +): + args = get_st_init_args( + [ + session.init( + get_token_transfer_method=lambda _, __, ___: "cookie", + anti_csrf="VIA_TOKEN", + ), + thirdpartyemailpassword.init(), + ] + ) + init(**args) # type: ignore + start_st() + + await create_or_update_third_party_config( + "public", + ProviderConfig( + third_party_id="google", + name="Google", + clients=[ + ProviderClientConfig( + client_id="google-client-id", + client_secret="google-client-secret", + ) + ], + ), + ) + + res = app.get( + "/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect" + ) + body = json.loads(res.text) + assert body["status"] == "OK" + assert ( + body["urlWithQueryParams"] + == "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&redirect_uri=redirect&response_type=code&scope=openid+email&included_grant_scopes=true&access_type=offline" + ) diff --git a/tests/thirdpartypasswordless/test_authorisation_url_feature.py b/tests/thirdpartypasswordless/test_authorisation_url_feature.py new file mode 100644 index 000000000..28e8d4167 --- /dev/null +++ b/tests/thirdpartypasswordless/test_authorisation_url_feature.py @@ -0,0 +1,116 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import json + +from fastapi import FastAPI +from pytest import mark, fixture + +from supertokens_python.framework.fastapi import get_middleware +from fastapi.testclient import TestClient +from supertokens_python.recipe import session, thirdpartypasswordless +from supertokens_python import init +from supertokens_python.recipe.multitenancy.asyncio import ( + create_or_update_third_party_config, +) +from supertokens_python.recipe.thirdparty.provider import ( + ProviderConfig, + ProviderClientConfig, +) + +from tests.utils import get_st_init_args +from tests.utils import ( + setup_function, + teardown_function, + start_st, +) + + +_ = setup_function +_ = teardown_function + +pytestmark = mark.asyncio + + +@fixture(scope="function") +async def app(): + app = FastAPI() + app.add_middleware(get_middleware()) + + return TestClient(app) + + +async def test_calling_authorisation_url_api_with_empty_init(app: TestClient): + args = get_st_init_args( + [ + session.init( + get_token_transfer_method=lambda _, __, ___: "cookie", + anti_csrf="VIA_TOKEN", + ), + thirdpartypasswordless.init( + contact_config=thirdpartypasswordless.ContactEmailOnlyConfig(), + flow_type="MAGIC_LINK", + ), + ] + ) + init(**args) # type: ignore + start_st() + + res = app.get( + "/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect" + ) + assert res.status_code == 400 + assert res.text == "the provider google could not be found in the configuration" + + +async def test_calling_authorisation_url_api_with_empty_init_with_dynamic_thirdparty_provider( + app: TestClient, +): + args = get_st_init_args( + [ + session.init( + get_token_transfer_method=lambda _, __, ___: "cookie", + anti_csrf="VIA_TOKEN", + ), + thirdpartypasswordless.init( + contact_config=thirdpartypasswordless.ContactEmailOnlyConfig(), + flow_type="MAGIC_LINK", + ), + ] + ) + init(**args) # type: ignore + start_st() + + await create_or_update_third_party_config( + "public", + ProviderConfig( + third_party_id="google", + name="Google", + clients=[ + ProviderClientConfig( + client_id="google-client-id", + client_secret="google-client-secret", + ) + ], + ), + ) + + res = app.get( + "/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect" + ) + body = json.loads(res.json()) + assert body["status"] == "OK" + assert ( + body["urlWithQueryParams"] + == "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&redirect_uri=redirect&response_type=code&scope=openid+email&included_grant_scopes=true&access_type=offline" + )