Skip to content

Commit

Permalink
fix: Types and linter errors
Browse files Browse the repository at this point in the history
  • Loading branch information
KShivendu committed Jul 4, 2023
1 parent 29e5023 commit 4e0d6a0
Show file tree
Hide file tree
Showing 20 changed files with 90 additions and 71 deletions.
7 changes: 5 additions & 2 deletions supertokens_python/always_initialised_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Any, Callable, Union
from typing import Callable, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from supertokens_python.recipe_module import RecipeModule
from supertokens_python import AppInfo

DEFAULT_MULTITENANCY_RECIPE: Union[Callable[[Any], Any], None] = None
DEFAULT_MULTITENANCY_RECIPE: Optional[Callable[[AppInfo], RecipeModule]] = None
4 changes: 3 additions & 1 deletion supertokens_python/recipe/multitenancy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def __init__(


class ThirdPartyProvider:
def __init__(self, id: str, name: Optional[str]):
def __init__(
self, id: str, name: Optional[str]
): # pylint: disable=redefined-builtin
self.id = id
self.name = name

Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/multitenancy/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ async def login_methods_get(

class AllowedDomainsClaimClass(PrimitiveArrayClaim[List[str]]):
def __init__(self):
async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]:
async def fetch_value(_user_id: str, user_context: Dict[str, Any]) -> List[str]:
recipe = MultitenancyRecipe.get_instance()
tenant_id = (
None # TODO fetch value will be passed with tenant_id as well later
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/thirdparty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


from . import exceptions as ex
from . import utils, provider
from . import utils, provider, providers
from .recipe import ThirdPartyRecipe

InputOverrideConfig = utils.InputOverrideConfig
Expand Down
18 changes: 9 additions & 9 deletions supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,43 @@ def __init__(


class Provider:
def __init__(self, id: str):
def __init__(self, id: str): # pylint: disable=redefined-builtin
self.id = id
self.config = ProviderConfigForClientType("temp")

async def get_config_for_client_type(
async def get_config_for_client_type( # pylint: disable=no-self-use
self, client_type: Optional[str], user_context: Dict[str, Any]
) -> ProviderConfigForClientType:
_ = client_type
__ = user_context
raise NotImplementedError
raise NotImplementedError()

async def get_authorisation_redirect_url(
async def get_authorisation_redirect_url( # pylint: disable=no-self-use
self,
redirect_uri_on_provider_dashboard: str,
user_context: Dict[str, Any],
) -> AuthorisationRedirect:
_ = redirect_uri_on_provider_dashboard
__ = user_context
raise NotImplementedError
raise NotImplementedError()

async def exchange_auth_code_for_oauth_tokens(
async def exchange_auth_code_for_oauth_tokens( # pylint: disable=no-self-use
self,
redirect_uri_info: RedirectUriInfo,
user_context: Dict[str, Any],
) -> Dict[str, Any]:
_ = redirect_uri_info
__ = user_context
raise NotImplementedError
raise NotImplementedError()

async def get_user_info(
async def get_user_info( # pylint: disable=no-self-use
self,
oauth_tokens: Dict[str, Any],
user_context: Dict[str, Any],
) -> UserInfo:
_ = oauth_tokens
__ = user_context
raise NotImplementedError
raise NotImplementedError()


class ProviderClientConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ async def get_config_for_client_type(
return config


def ActiveDirectory(input: ProviderInput) -> Provider:
def ActiveDirectory(
input: ProviderInput, # pylint: disable=redefined-builtin
) -> Provider:
if input.config.name is None:
input.config.name = "Active Directory"

Expand Down
16 changes: 9 additions & 7 deletions supertokens_python/recipe/thirdparty/providers/bitbucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
] = None,
is_default: bool = False,
):
super().__init__("bitbucket", is_default)
super().__init__("bitbucket") # FIXME: Where should is_default go?
self.client_id = client_id
self.client_secret = client_secret
self.scopes = ["account", "email"] if scope is None else list(set(scope))
Expand All @@ -50,8 +50,8 @@ def __init__(
if authorisation_redirect is not None:
self.authorisation_redirect_params = authorisation_redirect

async def get_profile_info(
self, auth_code_response: Dict[str, Any], user_context: Dict[str, Any]
async def get_profile_info( # pylint: disable=no-self-use
self, auth_code_response: Dict[str, Any], _user_context: Dict[str, Any]
) -> UserInfo:
access_token: str = auth_code_response["access_token"]
headers = {"Authorization": f"Bearer {access_token}"}
Expand Down Expand Up @@ -80,7 +80,7 @@ async def get_profile_info(
return UserInfo(user_id, UserInfoEmail(email, is_verified))

def get_authorisation_redirect_api_info(
self, user_context: Dict[str, Any]
self, _user_context: Dict[str, Any]
) -> AuthorisationRedirectAPI:
params = {
"scope": " ".join(self.scopes),
Expand All @@ -95,7 +95,7 @@ def get_access_token_api_info(
self,
redirect_uri: str,
auth_code_from_request: str,
user_context: Dict[str, Any],
_user_context: Dict[str, Any],
) -> AccessTokenAPI:
params = {
"client_id": self.client_id,
Expand All @@ -106,8 +106,10 @@ def get_access_token_api_info(
}
return AccessTokenAPI(self.access_token_api_url, params)

def get_redirect_uri(self, user_context: Dict[str, Any]) -> Union[None, str]:
def get_redirect_uri( # pylint: disable=no-self-use
self, _user_context: Dict[str, Any]
) -> Union[None, str]:
return None

def get_client_id(self, user_context: Dict[str, Any]) -> str:
def get_client_id(self, _user_context: Dict[str, Any]) -> str:
return self.client_id
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def get_config_for_client_type(
return config


def BoxySAML(input: ProviderInput) -> Provider:
def BoxySAML(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Boxy SAML"

Expand Down
35 changes: 16 additions & 19 deletions supertokens_python/recipe/thirdparty/providers/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,18 @@ async def verify_id_token_from_jwks_endpoint_and_get_payload(
raise err


def merge_into_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]:
res = dest.copy()
for k, v in src.items():
if v is None:
if k in res:
del res[k]
else:
res[k] = v

return res


class GenericProvider(Provider):
def __init__(self, config: ProviderConfig):
super().__init__(config.third_party_id)
Expand Down Expand Up @@ -288,12 +300,7 @@ async def exchange_auth_code_for_oauth_tokens(
access_token_params["code_verifier"] = redirect_uri_info.pkce_code_verifier

if self.config.token_endpoint_body_params is not None:
for k, v in self.config.token_endpoint_body_params:
if v is None:
if k in access_token_params:
del access_token_params[k]
else:
access_token_params[k] = v
access_token_params = merge_into_dict(self.config.token_endpoint_body_params, access_token_params)

# Transformation needed for dev keys BEGIN
if is_using_oauth_development_client_id(self.config.client_id):
Expand Down Expand Up @@ -336,20 +343,10 @@ async def get_user_info(

if self.config.user_info_endpoint is not None:
if self.config.user_info_endpoint_headers is not None:
for k, v in self.config.user_info_endpoint_headers.items():
if v is None:
if k in headers:
del headers[k]
else:
headers[k] = v
headers = merge_into_dict(self.config.user_info_endpoint_headers, headers)

if self.config.user_info_endpoint_query_params is not None:
for k, v in self.config.user_info_endpoint_query_params.items():
if v is None:
if k in query_params:
del query_params[k]
else:
query_params[k] = v
query_params = merge_into_dict(self.config.user_info_endpoint_query_params, query_params)

raw_user_info_from_provider.from_user_info_api = await do_get_request(
self.config.user_info_endpoint, query_params, headers
Expand All @@ -367,7 +364,7 @@ async def get_user_info(


def NewProvider(
input: ProviderInput,
input: ProviderInput, # pylint: disable=redefined-builtin
base_class: Callable[[ProviderConfig], Provider] = GenericProvider,
) -> Provider:
provider_instance = base_class(input.config)
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/thirdparty/providers/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def get_config_for_client_type(
return config


def Discord(input: ProviderInput) -> Provider:
def Discord(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Discord"

Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/thirdparty/providers/facebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def get_user_info(
return await super().get_user_info(oauth_tokens, user_context)


def Facebook(input: ProviderInput) -> Provider:
def Facebook(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Facebook"

Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/thirdparty/providers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def get_user_info(
return result


def Github(input: ProviderInput) -> Provider:
def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Github"

Expand Down
14 changes: 8 additions & 6 deletions supertokens_python/recipe/thirdparty/providers/gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
gitlab_base_url: str = "https://gitlab.com",
is_default: bool = False,
):
super().__init__("gitlab", is_default)
super().__init__("gitlab") # FIXME: Where should is_default go?
default_scopes = ["read_user"]
if scope is None:
scope = default_scopes
Expand All @@ -59,7 +59,7 @@ def __init__(
self.authorisation_redirect_params = authorisation_redirect

async def get_profile_info(
self, auth_code_response: Dict[str, Any], user_context: Dict[str, Any]
self, auth_code_response: Dict[str, Any], _user_context: Dict[str, Any]
) -> UserInfo:
access_token: str = auth_code_response["access_token"]
headers = {"Authorization": f"Bearer {access_token}"}
Expand All @@ -74,7 +74,7 @@ async def get_profile_info(
return UserInfo(user_id, UserInfoEmail(email, is_email_verified))

def get_authorisation_redirect_api_info(
self, user_context: Dict[str, Any]
self, _user_context: Dict[str, Any]
) -> AuthorisationRedirectAPI:
params = {
"scope": " ".join(self.scopes),
Expand All @@ -88,7 +88,7 @@ def get_access_token_api_info(
self,
redirect_uri: str,
auth_code_from_request: str,
user_context: Dict[str, Any],
_user_context: Dict[str, Any],
) -> AccessTokenAPI:
params = {
"client_id": self.client_id,
Expand All @@ -99,8 +99,10 @@ def get_access_token_api_info(
}
return AccessTokenAPI(self.access_token_api_url, params)

def get_redirect_uri(self, user_context: Dict[str, Any]) -> Union[None, str]:
def get_redirect_uri( # pylint: disable=no-self-use
self, _user_context: Dict[str, Any]
) -> Union[None, str]:
return None

def get_client_id(self, user_context: Dict[str, Any]) -> str:
def get_client_id(self, _user_context: Dict[str, Any]) -> str:
return self.client_id
2 changes: 1 addition & 1 deletion supertokens_python/recipe/thirdparty/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def get_config_for_client_type(


def Google(
input: ProviderInput,
input: ProviderInput, # pylint: disable=redefined-builtin
base_class: Callable[[ProviderConfig], GoogleImpl] = GoogleImpl,
) -> Provider:
if input.config.name is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ async def get_config_for_client_type(
return config


def GoogleWorkspaces(input: ProviderInput) -> Provider:
def GoogleWorkspaces(
input: ProviderInput, # pylint: disable=redefined-builtin
) -> Provider:
if input.config.name is None:
input.config.name = "Google Workspaces"

Expand All @@ -49,7 +51,7 @@ def GoogleWorkspaces(input: ProviderInput) -> Provider:
async def default_validate_id_token_payload(
id_token_payload: Dict[str, Any],
config: ProviderConfigForClientType,
user_context: Dict[str, Any],
_user_context: Dict[str, Any],
):
if (config.additional_config or {}).get("hd", "*") != "*":
if (config.additional_config or {}).get("hd") != id_token_payload.get(
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/thirdparty/providers/linkedin.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async def get_user_info(
)


def Linkedin(input: ProviderInput) -> Provider:
def Linkedin(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Linkedin"

Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/thirdparty/providers/okta.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def get_config_for_client_type(
return config


def Okta(input: ProviderInput) -> Provider:
def Okta(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
if input.config.name is None:
input.config.name = "Okta"

Expand Down
Loading

0 comments on commit 4e0d6a0

Please sign in to comment.