Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Router changes to handle tenant id #364

Merged
merged 10 commits into from
Jul 12, 2023
23 changes: 16 additions & 7 deletions supertokens_python/recipe/dashboard/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,18 +256,27 @@ def return_api_id_if_can_handle_request(
base_path_str = self.app_info.api_base_path.get_as_string_dangerous()
path_str = path.get_as_string_dangerous()
regex = rf"^{base_path_str}(?:/([a-zA-Z0-9-]+))?(/.*)$"
# some examples against for above regex:
# books => match = None
# public/books => match = None
# /books => match.group(1) = None, match.group(2) = /dashboard
# /public/books => match.group(1) = 'public', match.group(2) = '/books'
# /public/book/1 => match.group(1) = 'public', match.group(2) = '/book/1'

match = re.match(regex, path_str)
match_group_1 = match.group(1) if match is not None else None
match_group_2 = match.group(2) if match is not None else None

tenant_id: str = DEFAULT_TENANT_ID
remaining_path: Optional[NormalisedURLPath] = None

if match is not None:
# TODO: Do something better than assert here
assert match.group(1) is not None
assert match.group(2) is not None

tenant_id = match.group(1)
remaining_path = NormalisedURLPath(match.group(2))
if (
match is not None
and isinstance(match_group_1, str)
and isinstance(match_group_2, str)
):
tenant_id = match_group_1
remaining_path = NormalisedURLPath(match_group_2)

if is_api_path(path, self.app_info.api_base_path) or (
remaining_path is not None
Expand Down
10 changes: 1 addition & 9 deletions supertokens_python/recipe/thirdparty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,14 @@


from . import exceptions as ex
from . import utils, provider, providers
from . import utils, provider
from .recipe import ThirdPartyRecipe

InputOverrideConfig = utils.InputOverrideConfig
SignInAndUpFeature = utils.SignInAndUpFeature
ProviderInput = provider.ProviderInput
ProviderConfig = provider.ProviderConfig
ProviderClientConfig = provider.ProviderClientConfig
Apple = providers.Apple
Discord = providers.Discord
Facebook = providers.Facebook
Github = providers.Github
Google = providers.Google
GoogleWorkspaces = providers.GoogleWorkspaces
Bitbucket = providers.Bitbucket
GitLab = providers.GitLab
exceptions = ex

if TYPE_CHECKING:
Expand Down
5 changes: 1 addition & 4 deletions supertokens_python/recipe/thirdparty/providers/apple.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
# under the License.
from __future__ import annotations

from typing import Any, Dict, Optional
from re import sub
from typing import Any, Dict, Optional
from jwt import encode # type: ignore
from time import time

from jwt import encode

from .custom import GenericProvider, NewProvider
from .utils import get_actual_client_id_from_development_client_id
from ..provider import Provider, ProviderConfigForClientType, ProviderInput
from .utils import get_actual_client_id_from_development_client_id


class AppleImpl(GenericProvider):
Expand Down
90 changes: 15 additions & 75 deletions supertokens_python/recipe/thirdparty/providers/bitbucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,88 +19,28 @@
from ..provider import Provider, ProviderInput


class Bitbucket(Provider):
def __init__(
self,
client_id: str,
client_secret: str,
scope: Union[None, List[str]] = None,
authorisation_redirect: Union[
None, Dict[str, Union[str, Callable[[BaseRequest], str]]]
] = None,
_is_default: bool = False,
):
super().__init__("bitbucket") # FIXME: Where should is_default go?
self.client_id = client_id
self.client_secret = client_secret
self.scopes = ["account", "email"] if scope is None else list(set(scope))
self.access_token_api_url = "https://bitbucket.org/site/oauth2/access_token"
self.authorisation_redirect_url = "https://bitbucket.org/site/oauth2/authorize"
self.authorisation_redirect_params = {}
if authorisation_redirect is not None:
self.authorisation_redirect_params = authorisation_redirect
# TODO Implement when it's done in Node PR
class BitbucketImpl(GenericProvider):
pass

async def get_profile_info( # pylint: disable=no-self-use
self, auth_code_response: Dict[str, Any], _user_context: Dict[str, Any]
) -> UserInfo:
access_token: str = auth_code_response["access_token"]
headers = {"Authorization": f"Bearer {access_token}"}
async with AsyncClient() as client:
response = await client.get( # type: ignore
url="https://api.bitbucket.org/2.0/user",
headers=headers,
)
user_info = response.json()
user_id = user_info["uuid"]
email_res = await client.get( # type: ignore
url="https://api.bitbucket.org/2.0/user/emails",
headers=headers,
)
email_data = email_res.json()
email = None
is_verified = False
for email_info in email_data["values"]:
if email_info.get("is_primary"):
email = email_info["email"]
is_verified = email_info["is_confirmed"]
break

# TODO Finish when it's done in Node PR
def Bitbucket(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Bitbucket"

def get_authorisation_redirect_api_info(
self, _user_context: Dict[str, Any]
) -> AuthorisationRedirectAPI:
params = {
"scope": " ".join(self.scopes),
"response_type": "code",
"client_id": self.client_id,
"access_type": "offline",
**self.authorisation_redirect_params,
}
return AuthorisationRedirectAPI(self.authorisation_redirect_url, params)
if input.config.authorization_endpoint is None:
input.config.authorization_endpoint = (
"https://bitbucket.org/site/oauth2/authorize"
)

def get_access_token_api_info(
self,
redirect_uri: str,
auth_code_from_request: str,
_user_context: Dict[str, Any],
) -> AccessTokenAPI:
params = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "authorization_code",
"code": auth_code_from_request,
"redirect_uri": redirect_uri,
}
return AccessTokenAPI(self.access_token_api_url, params)
if input.config.token_endpoint is None:
input.config.token_endpoint = "https://bitbucket.org/site/oauth2/access_token"

def get_redirect_uri( # pylint: disable=no-self-use
self, _user_context: Dict[str, Any]
) -> Union[None, str]:
return None
if input.config.user_info_endpoint is None:
input.config.user_info_endpoint = "https://api.bitbucket.org/2.0/user"

def get_client_id(self, _user_context: Dict[str, Any]) -> str:
return self.client_id
# TODO overrides and working of this
# once done in Node PR

return NewProvider(input, BitbucketImpl)
79 changes: 6 additions & 73 deletions supertokens_python/recipe/thirdparty/providers/gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,78 +19,11 @@
from ..provider import Provider, ProviderInput


class GitLab(Provider):
def __init__(
self,
client_id: str,
client_secret: str,
scope: Union[None, List[str]] = None,
authorisation_redirect: Union[
None, Dict[str, Union[str, Callable[[BaseRequest], str]]]
] = None,
gitlab_base_url: str = "https://gitlab.com",
_is_default: bool = False,
):
super().__init__("gitlab") # FIXME: Where should is_default go?
default_scopes = ["read_user"]
if scope is None:
scope = default_scopes
self.client_id = client_id
self.client_secret = client_secret
self.scopes = list(set(scope))
gitlab_base_url = NormalisedURLDomain(gitlab_base_url).get_as_string_dangerous()
self.gitlab_base_url = gitlab_base_url
self.access_token_api_url = f"{gitlab_base_url}/oauth/token"
self.authorisation_redirect_url = f"{gitlab_base_url}/oauth/authorize"
self.authorisation_redirect_params = {}
if authorisation_redirect is not None:
self.authorisation_redirect_params = authorisation_redirect
# TODO Implement when it's done in Node PR
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove all todos from code and transfer them to #276

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

class GitlabImpl(GenericProvider):
pass

async def get_profile_info(
self, auth_code_response: Dict[str, Any], _user_context: Dict[str, Any]
) -> UserInfo:
access_token: str = auth_code_response["access_token"]
headers = {"Authorization": f"Bearer {access_token}"}
async with AsyncClient() as client:
response = await client.get(f"{self.gitlab_base_url}/api/v4/user", headers=headers) # type: ignore
user_info = response.json()
user_id = str(user_info["id"])
email = user_info.get("email")
if email is None:
return UserInfo(user_id)
is_email_verified = user_info.get("confirmed_at") is not None
return UserInfo(user_id, UserInfoEmail(email, is_email_verified))

def get_authorisation_redirect_api_info(
self, _user_context: Dict[str, Any]
) -> AuthorisationRedirectAPI:
params = {
"scope": " ".join(self.scopes),
"response_type": "code",
"client_id": self.client_id,
**self.authorisation_redirect_params,
}
return AuthorisationRedirectAPI(self.authorisation_redirect_url, params)

def get_access_token_api_info(
self,
redirect_uri: str,
auth_code_from_request: str,
_user_context: Dict[str, Any],
) -> AccessTokenAPI:
params = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "authorization_code",
"code": auth_code_from_request,
"redirect_uri": redirect_uri,
}
return AccessTokenAPI(self.access_token_api_url, params)

def get_redirect_uri( # pylint: disable=no-self-use
self, _user_context: Dict[str, Any]
) -> Union[None, str]:
return None

def get_client_id(self, _user_context: Dict[str, Any]) -> str:
return self.client_id
# TODO Implement when it's done in Node PR
def Gitlab(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
return NewProvider(input, GitlabImpl)
8 changes: 0 additions & 8 deletions supertokens_python/recipe/thirdpartyemailpassword/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,6 @@
ProviderConfig = provider.ProviderConfig
ProviderClientConfig = provider.ProviderClientConfig
ProviderConfigForClientType = provider.ProviderConfigForClientType
Apple = thirdparty.Apple
Discord = thirdparty.Discord
Facebook = thirdparty.Facebook
Github = thirdparty.Github
Google = thirdparty.Google
GoogleWorkspaces = thirdparty.GoogleWorkspaces
Bitbucket = thirdparty.Bitbucket
GitLab = thirdparty.GitLab
SMTPService = emaildelivery_services.SMTPService

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from supertokens_python.recipe.thirdparty.interfaces import GetProviderOkResult
from supertokens_python.recipe.thirdparty.provider import ProviderInput
from supertokens_python.recipe.thirdparty.types import RawUserInfoFromProvider
from supertokens_python.recipe.emailpassword.utils import EmailPasswordConfig

if TYPE_CHECKING:
from supertokens_python.querier import Querier
Expand Down
17 changes: 10 additions & 7 deletions supertokens_python/recipe_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,19 @@ def return_api_id_if_can_handle_request(
regex = rf"^{base_path_str}(?:/([a-zA-Z0-9-]+))?(/.*)$"

match = re.match(regex, path_str)
match_group_1 = match.group(1) if match is not None else None
match_group_2 = match.group(2) if match is not None else None

tenant_id: str = DEFAULT_TENANT_ID
remaining_path: Optional[NormalisedURLPath] = None

if match is not None:
# TODO: Do something better than assert here
# assert match.group(1) is not None
# assert match.group(2) is not None

tenant_id = match.group(1)
remaining_path = NormalisedURLPath(match.group(2))
if (
match is not None
and isinstance(match_group_1, str)
and isinstance(match_group_2, str)
):
tenant_id = match_group_1
remaining_path = NormalisedURLPath(match_group_2)

for current_api in apis_handled:
if not current_api.disabled and current_api.method == method:
Expand Down
5 changes: 4 additions & 1 deletion supertokens_python/supertokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule:
self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list))

if callable(DEFAULT_MULTITENANCY_RECIPE) and not multitenancy_found[0]:
self.recipe_modules.append(DEFAULT_MULTITENANCY_RECIPE(self.app_info))
recipe = DEFAULT_MULTITENANCY_RECIPE( # pylint: disable=not-callable
self.app_info
)
self.recipe_modules.append(recipe)

self.telemetry = (
telemetry
Expand Down
28 changes: 8 additions & 20 deletions tests/multitenancy/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,13 @@ async def client():
app = FastAPI()
app.add_middleware(get_middleware())

# @app.get("/login")
# async def login(_request: Request): # type: ignore
# user_id = "userId"
# # await create_new_session(request, user_id, {}, {})
# return {"userId": user_id}

return TestClient(app)


async def test_emailpassword_router(client: TestClient):
args = get_st_init_args(
[
session.init(get_token_transfer_method=lambda *_: "cookie"),
session.init(get_token_transfer_method=lambda *_: "cookie"), # type: ignore
emailpassword.init(),
]
)
Expand Down Expand Up @@ -82,7 +76,7 @@ async def test_emailpassword_router(client: TestClient):
async def test_dashboard_apis_router(client: TestClient):
args = get_st_init_args(
[
session.init(get_token_transfer_method=lambda *_: "cookie"),
session.init(get_token_transfer_method=lambda *_: "cookie"), # type: ignore
emailpassword.init(),
dashboard.init(),
]
Expand All @@ -94,26 +88,20 @@ async def test_dashboard_apis_router(client: TestClient):
"/auth/public/dashboard/api/signin",
headers={"Content-Type": "application/json"},
json={
"formFields": [
{"id": "password", "value": "password1"},
{"id": "email", "value": "test1@example.com"},
]
"email": "test1@example.com",
"password": "password1",
},
)

assert res.status_code == 200 # FIXME: failing test
assert res.json()["status"] == "OK"
assert res.status_code == 200

res = client.post(
"/auth/dashboard/api/signin",
headers={"Content-Type": "application/json"},
json={
"formFields": [
{"id": "password", "value": "password1"},
{"id": "email", "value": "test1@example.com"},
]
"email": "test1@example.com",
"password": "password1",
},
)

assert res.status_code == 200 # FIXME: failing test
assert res.json()["status"] == "OK"
assert res.status_code == 200
Loading
Loading