Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Create a PasswordProvider wrapper object (#8849)
Browse files Browse the repository at this point in the history
The idea here is to abstract out all the conditional code which tests which
methods a given password provider has, to provide a consistent interface.
  • Loading branch information
richvdh authored Dec 2, 2020
1 parent edb3d3f commit d3ed935
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 57 deletions.
1 change: 1 addition & 0 deletions changelog.d/8849.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `password_auth_provider` support code.
203 changes: 148 additions & 55 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -25,6 +26,7 @@
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
Expand Down Expand Up @@ -181,17 +183,12 @@ def __init__(self, hs: "HomeServer"):
# better way to break the loop
account_handler = ModuleApi(hs, self)

self.password_providers = []
for module, config in hs.config.password_providers:
try:
self.password_providers.append(
module(config=config, account_handler=account_handler)
)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.password_providers
]

logger.info("Extra password_providers: %r", self.password_providers)
logger.info("Extra password_providers: %s", self.password_providers)

self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
Expand Down Expand Up @@ -853,6 +850,8 @@ async def validate_login(
LoginError if there was an authentication problem.
"""
login_type = login_submission.get("type")
if not isinstance(login_type, str):
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)

# ideally, we wouldn't be checking the identifier unless we know we have a login
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
Expand Down Expand Up @@ -998,24 +997,12 @@ async def _validate_userid_login(
qualified_user_id = UserID(username, self.hs.hostname).to_string()

login_type = login_submission.get("type")
# we already checked that we have a valid login type
assert isinstance(login_type, str)

known_login_type = False

for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
# we've already checked that there is a (valid) password field
is_valid = await provider.check_password(
qualified_user_id, login_submission["password"]
)
if is_valid:
return qualified_user_id, None

if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
):
# this password provider doesn't understand custom login types
continue

supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
Expand All @@ -1040,8 +1027,6 @@ async def _validate_userid_login(

result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
return result

if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
Expand Down Expand Up @@ -1083,19 +1068,9 @@ async def check_password_provider_3pid(
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
if hasattr(provider, "check_3pid_auth"):
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
return result
result = await provider.check_3pid_auth(medium, address, password)
if result:
return result

return None, None

Expand Down Expand Up @@ -1153,16 +1128,11 @@ async def delete_access_token(self, access_token: str):

# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
await provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)

# delete pushers associated with this access token
if user_info.token_id is not None:
Expand Down Expand Up @@ -1191,11 +1161,10 @@ async def delete_access_tokens_for_user(

# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)

# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
Expand Down Expand Up @@ -1519,3 +1488,127 @@ def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon


class PasswordProvider:
"""Wrapper for a password auth provider module
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
"""

@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)

def __init__(self, pp, module_api: ModuleApi):
self._pp = pp
self._module_api = module_api

self._supported_login_types = {}

# grandfather in check_password support
if hasattr(self._pp, "check_password"):
self._supported_login_types[LoginType.PASSWORD] = ("password",)

g = getattr(self._pp, "get_supported_login_types", None)
if g:
self._supported_login_types.update(g())

def __str__(self):
return str(self._pp)

def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
This wrapper adds m.login.password to the list if the underlying password
provider supports the check_password() api.
"""
return self._supported_login_types

async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
"""Check if the user has presented valid login credentials
This wrapper also calls check_password() if the underlying password provider
supports the check_password() api and the login type is m.login.password.
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
login_type: the login type being attempted - one of the types returned by
get_supported_login_types()
login_dict: the dictionary of login secrets passed by the client.
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None

g = getattr(self._pp, "check_auth", None)
if not g:
return None
result = await g(username, login_type, login_dict)

# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None

return result

async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None

# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await g(medium, address, password)

# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None

return result

async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return

# This might return an awaitable, if it does block the log out
# until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
if inspect.isawaitable(result):
await result
5 changes: 3 additions & 2 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,9 @@ def test_no_local_user_fallback_ui_auth(self):
# first delete should give a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# there are no valid flows here!
self.assertEqual(channel.json_body["flows"], [])
# m.login.password UIA is permitted because the auth provider allows it,
# even though the localdb does not.
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
session = channel.json_body["session"]
mock_password_provider.check_password.assert_not_called()

Expand Down

0 comments on commit d3ed935

Please sign in to comment.