Skip to content

Commit

Permalink
allows empty provider list
Browse files Browse the repository at this point in the history
  • Loading branch information
KShivendu committed Aug 10, 2023
1 parent e9bb86f commit 89fb919
Show file tree
Hide file tree
Showing 5 changed files with 368 additions and 54 deletions.
41 changes: 15 additions & 26 deletions supertokens_python/recipe/thirdpartyemailpassword/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def apis_override_email_password(_: EmailPasswordAPIInterface):
)

if third_party_recipe is not None:
self.third_party_recipe: Union[ThirdPartyRecipe, None] = third_party_recipe
self.third_party_recipe = third_party_recipe
else:

def func_override_third_party(
Expand All @@ -168,18 +168,16 @@ def apis_override_third_party(
) -> ThirdPartyAPIInterface:
return get_third_party_interface_impl(self.api_implementation)

self.third_party_recipe: Union[ThirdPartyRecipe, None] = None
# No email delivery ingredient required for third party recipe
# but we pass an object for future proofing
tp_ingredients = ThirdPartyIngredients()
if len(self.config.providers) != 0:
self.third_party_recipe = ThirdPartyRecipe(
recipe_id,
app_info,
SignInAndUpFeature(self.config.providers),
tp_ingredients,
TPOverrideConfig(
func_override_third_party, apis_override_third_party
),
)
self.third_party_recipe = ThirdPartyRecipe(
recipe_id,
app_info,
SignInAndUpFeature(self.config.providers),
tp_ingredients,
TPOverrideConfig(func_override_third_party, apis_override_third_party),
)

def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool:
return isinstance(err, SuperTokensError) and (
Expand All @@ -188,17 +186,13 @@ def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool:
err
)
or (
self.third_party_recipe is not None
and self.third_party_recipe.is_error_from_this_recipe_based_on_instance(
err
)
self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err)
)
)

def get_apis_handled(self) -> List[APIHandled]:
apis_handled = self.email_password_recipe.get_apis_handled()
if self.third_party_recipe is not None:
apis_handled = apis_handled + self.third_party_recipe.get_apis_handled()
apis_handled += self.third_party_recipe.get_apis_handled()
return apis_handled

async def handle_api_request(
Expand All @@ -221,8 +215,7 @@ async def handle_api_request(
request_id, tenant_id, request, path, method, response, user_context
)
if (
self.third_party_recipe is not None
and await self.third_party_recipe.return_api_id_if_can_handle_request(
await self.third_party_recipe.return_api_id_if_can_handle_request(
path, method, user_context
)
is not None
Expand All @@ -237,17 +230,13 @@ async def handle_error(
) -> BaseResponse:
if self.email_password_recipe.is_error_from_this_recipe_based_on_instance(err):
return await self.email_password_recipe.handle_error(request, err, response)
if (
self.third_party_recipe is not None
and self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err)
):
if self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err):
return await self.third_party_recipe.handle_error(request, err, response)
raise err

def get_all_cors_headers(self) -> List[str]:
cors_headers = self.email_password_recipe.get_all_cors_headers()
if self.third_party_recipe is not None:
cors_headers = cors_headers + self.third_party_recipe.get_all_cors_headers()
cors_headers += self.third_party_recipe.get_all_cors_headers()
return cors_headers

@staticmethod
Expand Down
45 changes: 17 additions & 28 deletions supertokens_python/recipe/thirdpartypasswordless/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def apis_override_passwordless(
)

if third_party_recipe is not None:
self.third_party_recipe: Union[ThirdPartyRecipe, None] = third_party_recipe
self.third_party_recipe = third_party_recipe
else:

def func_override_third_party(
Expand All @@ -189,36 +189,30 @@ def apis_override_third_party(
) -> ThirdPartyAPIInterface:
return get_third_party_interface_impl(self.api_implementation)

self.third_party_recipe: Union[ThirdPartyRecipe, None] = None

if len(self.config.providers) != 0:
tp_ingredients = ThirdPartyIngredients()
self.third_party_recipe = ThirdPartyRecipe(
recipe_id,
app_info,
SignInAndUpFeature(self.config.providers),
tp_ingredients,
TPOverrideConfig(
func_override_third_party, apis_override_third_party
),
)
# Thirdparty recipe doesn't need ingredients
# as of now. But we are passing ingredients object
# so that it's future-proof.
tp_ingredients = ThirdPartyIngredients()
self.third_party_recipe = ThirdPartyRecipe(
recipe_id,
app_info,
SignInAndUpFeature(self.config.providers),
tp_ingredients,
TPOverrideConfig(func_override_third_party, apis_override_third_party),
)

def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool:
return isinstance(err, SuperTokensError) and (
isinstance(err, SupertokensThirdPartyPasswordlessError)
or self.passwordless_recipe.is_error_from_this_recipe_based_on_instance(err)
or (
self.third_party_recipe is not None
and self.third_party_recipe.is_error_from_this_recipe_based_on_instance(
err
)
self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err)
)
)

def get_apis_handled(self) -> List[APIHandled]:
apis_handled = self.passwordless_recipe.get_apis_handled()
if self.third_party_recipe is not None:
apis_handled = apis_handled + self.third_party_recipe.get_apis_handled()
apis_handled += self.third_party_recipe.get_apis_handled()
return apis_handled

async def handle_api_request(
Expand All @@ -241,8 +235,7 @@ async def handle_api_request(
request_id, tenant_id, request, path, method, response, user_context
)
if (
self.third_party_recipe is not None
and await self.third_party_recipe.return_api_id_if_can_handle_request(
await self.third_party_recipe.return_api_id_if_can_handle_request(
path, method, user_context
)
is not None
Expand All @@ -257,17 +250,13 @@ async def handle_error(
) -> BaseResponse:
if self.passwordless_recipe.is_error_from_this_recipe_based_on_instance(err):
return await self.passwordless_recipe.handle_error(request, err, response)
if (
self.third_party_recipe is not None
and self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err)
):
if self.third_party_recipe.is_error_from_this_recipe_based_on_instance(err):
return await self.third_party_recipe.handle_error(request, err, response)
raise err

def get_all_cors_headers(self) -> List[str]:
cors_headers = self.passwordless_recipe.get_all_cors_headers()
if self.third_party_recipe is not None:
cors_headers = cors_headers + self.third_party_recipe.get_all_cors_headers()
cors_headers += self.third_party_recipe.get_all_cors_headers()
return cors_headers

@staticmethod
Expand Down
110 changes: 110 additions & 0 deletions tests/thirdparty/test_authorisation_url_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json

from fastapi import FastAPI
from pytest import mark, fixture

from supertokens_python.framework.fastapi import get_middleware
from fastapi.testclient import TestClient
from supertokens_python.recipe import session, thirdparty
from supertokens_python import init
from supertokens_python.recipe.multitenancy.asyncio import (
create_or_update_third_party_config,
)
from supertokens_python.recipe.thirdparty.provider import (
ProviderConfig,
ProviderClientConfig,
)

from tests.utils import get_st_init_args
from tests.utils import (
setup_function,
teardown_function,
start_st,
)


_ = setup_function
_ = teardown_function

pytestmark = mark.asyncio


@fixture(scope="function")
async def app():
app = FastAPI()
app.add_middleware(get_middleware())

return TestClient(app)


async def test_calling_authorisation_url_api_with_empty_init(app: TestClient):
args = get_st_init_args(
[
session.init(
get_token_transfer_method=lambda _, __, ___: "cookie",
anti_csrf="VIA_TOKEN",
),
thirdparty.init(),
]
)
init(**args) # type: ignore
start_st()

res = app.get(
"/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect"
)
assert res.status_code == 400
assert res.text == "the provider google could not be found in the configuration"


async def test_calling_authorisation_url_api_with_empty_init_with_dynamic_thirdparty_provider(
app: TestClient,
):
args = get_st_init_args(
[
session.init(
get_token_transfer_method=lambda _, __, ___: "cookie",
anti_csrf="VIA_TOKEN",
),
thirdparty.init(),
]
)
init(**args) # type: ignore
start_st()

await create_or_update_third_party_config(
"public",
ProviderConfig(
third_party_id="google",
name="Google",
clients=[
ProviderClientConfig(
client_id="google-client-id",
client_secret="google-client-secret",
)
],
),
)

res = app.get(
"/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect"
)
body = json.loads(res.text)
assert body["status"] == "OK"
assert (
body["urlWithQueryParams"]
== "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&redirect_uri=redirect&response_type=code&scope=openid+email&included_grant_scopes=true&access_type=offline"
)
110 changes: 110 additions & 0 deletions tests/thirdpartyemailpassword/test_authorisation_url_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json

from fastapi import FastAPI
from pytest import mark, fixture

from supertokens_python.framework.fastapi import get_middleware
from fastapi.testclient import TestClient
from supertokens_python.recipe import session, thirdpartyemailpassword
from supertokens_python import init
from supertokens_python.recipe.multitenancy.asyncio import (
create_or_update_third_party_config,
)
from supertokens_python.recipe.thirdparty.provider import (
ProviderConfig,
ProviderClientConfig,
)

from tests.utils import get_st_init_args
from tests.utils import (
setup_function,
teardown_function,
start_st,
)


_ = setup_function
_ = teardown_function

pytestmark = mark.asyncio


@fixture(scope="function")
async def app():
app = FastAPI()
app.add_middleware(get_middleware())

return TestClient(app)


async def test_calling_authorisation_url_api_with_empty_init(app: TestClient):
args = get_st_init_args(
[
session.init(
get_token_transfer_method=lambda _, __, ___: "cookie",
anti_csrf="VIA_TOKEN",
),
thirdpartyemailpassword.init(),
]
)
init(**args) # type: ignore
start_st()

res = app.get(
"/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect"
)
assert res.status_code == 400
assert res.text == "the provider google could not be found in the configuration"


async def test_calling_authorisation_url_api_with_empty_init_with_dynamic_thirdparty_provider(
app: TestClient,
):
args = get_st_init_args(
[
session.init(
get_token_transfer_method=lambda _, __, ___: "cookie",
anti_csrf="VIA_TOKEN",
),
thirdpartyemailpassword.init(),
]
)
init(**args) # type: ignore
start_st()

await create_or_update_third_party_config(
"public",
ProviderConfig(
third_party_id="google",
name="Google",
clients=[
ProviderClientConfig(
client_id="google-client-id",
client_secret="google-client-secret",
)
],
),
)

res = app.get(
"/auth/authorisationurl?thirdPartyId=google&redirectURIOnProviderDashboard=redirect"
)
body = json.loads(res.text)
assert body["status"] == "OK"
assert (
body["urlWithQueryParams"]
== "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&redirect_uri=redirect&response_type=code&scope=openid+email&included_grant_scopes=true&access_type=offline"
)
Loading

0 comments on commit 89fb919

Please sign in to comment.