diff --git a/changelog.d/10548.feature b/changelog.d/10548.feature new file mode 100644 index 000000000000..263a811faf16 --- /dev/null +++ b/changelog.d/10548.feature @@ -0,0 +1 @@ +Port the Password Auth Providers module interface to the new generic interface. \ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index bdb44543b83d..35412ea92c0c 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -43,6 +43,7 @@ - [Third-party rules callbacks](modules/third_party_rules_callbacks.md) - [Presence router callbacks](modules/presence_router_callbacks.md) - [Account validity callbacks](modules/account_validity_callbacks.md) + - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md) - [Porting a legacy module to the new interface](modules/porting_legacy_module.md) - [Workers](workers.md) - [Using `synctl` with Workers](synctl_workers.md) diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md new file mode 100644 index 000000000000..36417dd39e20 --- /dev/null +++ b/docs/modules/password_auth_provider_callbacks.md @@ -0,0 +1,153 @@ +# Password auth provider callbacks + +Password auth providers offer a way for server administrators to integrate +their Synapse installation with an external authentication system. The callbacks can be +registered by using the Module API's `register_password_auth_provider_callbacks` method. + +## Callbacks + +### `auth_checkers` + +``` + auth_checkers: Dict[Tuple[str,Tuple], Callable] +``` + +A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a +tuple of field names (such as `("password", "secret_thing")`) to authentication checking +callbacks, which should be of the following form: + +```python +async def check_auth( + user: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", +) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] + ] +] +``` + +The login type and field names should be provided by the user in the +request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types) +defines some types, however user defined ones are also allowed. + +The callback is passed the `user` field provided by the client (which might not be in +`@username:server` form), the login type, and a dictionary of login secrets passed by +the client. + +If the authentication is successful, the module must return the user's Matrix ID (e.g. +`@alice:example.com`) and optionally a callback to be called with the response to the +`/login` request. If the module doesn't wish to return a callback, it must return `None` +instead. + +If the authentication is unsuccessful, the module must return `None`. + +### `check_3pid_auth` + +```python +async def check_3pid_auth( + medium: str, + address: str, + password: str, +) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] + ] +] +``` + +Called when a user attempts to register or log in with a third party identifier, +such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`) +and the user's password. + +If the authentication is successful, the module must return the user's Matrix ID (e.g. +`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request. +If the module doesn't wish to return a callback, it must return None instead. + +If the authentication is unsuccessful, the module must return None. + +### `on_logged_out` + +```python +async def on_logged_out( + user_id: str, + device_id: Optional[str], + access_token: str +) -> None +``` +Called during a logout request for a user. It is passed the qualified user ID, the ID of the +deactivated device (if any: access tokens are occasionally created without an associated +device ID), and the (now deactivated) access token. + +## Example + +The example module below implements authentication checkers for two different login types: +- `my.login.type` + - Expects a `my_field` field to be sent to `/login` + - Is checked by the method: `self.check_my_login` +- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) + - Expects a `password` field to be sent to `/login` + - Is checked by the method: `self.check_pass` + + +```python +from typing import Awaitable, Callable, Optional, Tuple + +import synapse +from synapse import module_api + + +class MyAuthProvider: + def __init__(self, config: dict, api: module_api): + + self.api = api + + self.credentials = { + "bob": "building", + "@scoop:matrix.org": "digging", + } + + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("my.login_type", ("my_field",)): self.check_my_login, + ("m.login.password", ("password",)): self.check_pass, + }, + ) + + async def check_my_login( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + if login_type != "my.login_type": + return None + + if self.credentials.get(username) == login_dict.get("my_field"): + return self.api.get_qualified_user_id(username) + + async def check_pass( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + if login_type != "m.login.password": + return None + + if self.credentials.get(username) == login_dict.get("password"): + return self.api.get_qualified_user_id(username) +``` diff --git a/docs/modules/porting_legacy_module.md b/docs/modules/porting_legacy_module.md index a7a251e53580..89084eb7b32b 100644 --- a/docs/modules/porting_legacy_module.md +++ b/docs/modules/porting_legacy_module.md @@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for more info). +There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any +changes to the database should now be made by the module using the module API class. + The module's author should also update any example in the module's configuration to only use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules) for more info). diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index d2cdb9b2f4a3..d7beacfff3e9 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -1,3 +1,9 @@ +

+This page of the Synapse documentation is now deprecated. For up to date +documentation on setting up or writing a password auth provider module, please see +this page. +

+ # Password auth provider modules Password auth providers offer a way for server administrators to diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 166cec38d3f2..7bfaed483b61 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2260,34 +2260,6 @@ email: #email_validation: "[%(server_name)s] Validate your email" -# Password providers allow homeserver administrators to integrate -# their Synapse installation with existing authentication methods -# ex. LDAP, external tokens, etc. -# -# For more information and known implementations, please see -# https://matrix-org.github.io/synapse/latest/password_auth_providers.html -# -# Note: instances wishing to use SAML or CAS authentication should -# instead use the `saml2_config` or `cas_config` options, -# respectively. -# -password_providers: -# # Example config for an LDAP auth provider -# - module: "ldap_auth_provider.LdapAuthProvider" -# config: -# enabled: true -# uri: "ldap://ldap.example.com:389" -# start_tls: true -# base: "ou=users,dc=example,dc=com" -# attributes: -# uid: "cn" -# mail: "email" -# name: "givenName" -# #bind_dn: -# #bind_password: -# #filter: "(objectClass=posixAccount)" - - ## Push ## diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 4a204a582373..bb4d53d77891 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -42,6 +42,7 @@ from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules +from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -379,6 +380,7 @@ def run_sighup(*args, **kwargs): load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) load_legacy_presence_router(hs) + load_legacy_password_auth_providers(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 83994df798bd..f980102b45e2 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config): section = "authproviders" def read_config(self, config, **kwargs): + """Parses the old password auth providers config. The config format looks like this: + + password_providers: + # Example config for an LDAP auth provider + - module: "ldap_auth_provider.LdapAuthProvider" + config: + enabled: true + uri: "ldap://ldap.example.com:389" + start_tls: true + base: "ou=users,dc=example,dc=com" + attributes: + uid: "cn" + mail: "email" + name: "givenName" + #bind_dn: + #bind_password: + #filter: "(objectClass=posixAccount)" + + We expect admins to use modules for this feature (which is why it doesn't appear + in the sample config file), but we want to keep support for it around for a bit + for backwards compatibility. + """ + self.password_providers: List[Tuple[Type, Any]] = [] providers = [] @@ -49,33 +72,3 @@ def read_config(self, config, **kwargs): ) self.password_providers.append((provider_class, provider_config)) - - def generate_config_section(self, **kwargs): - return """\ - # Password providers allow homeserver administrators to integrate - # their Synapse installation with existing authentication methods - # ex. LDAP, external tokens, etc. - # - # For more information and known implementations, please see - # https://matrix-org.github.io/synapse/latest/password_auth_providers.html - # - # Note: instances wishing to use SAML or CAS authentication should - # instead use the `saml2_config` or `cas_config` options, - # respectively. - # - password_providers: - # # Example config for an LDAP auth provider - # - module: "ldap_auth_provider.LdapAuthProvider" - # config: - # enabled: true - # uri: "ldap://ldap.example.com:389" - # start_tls: true - # base: "ou=users,dc=example,dc=com" - # attributes: - # uid: "cn" - # mail: "email" - # name: "givenName" - # #bind_dn: - # #bind_password: - # #filter: "(objectClass=posixAccount)" - """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f4612a5b9223..ebe75a9e9b22 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -200,46 +200,13 @@ def __init__(self, hs: "HomeServer"): self.bcrypt_rounds = hs.config.registration.bcrypt_rounds - # we can't use hs.get_module_api() here, because to do so will create an - # import loop. - # - # TODO: refactor this class to separate the lower-level stuff that - # ModuleApi can use from the higher-level stuff that uses ModuleApi, as - # better way to break the loop - account_handler = ModuleApi(hs, self) - - self.password_providers = [ - PasswordProvider.load(module, config, account_handler) - for module, config in hs.config.authproviders.password_providers - ] - - logger.info("Extra password_providers: %s", self.password_providers) + self.password_auth_provider = hs.get_password_auth_provider() self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.auth.password_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled - # start out by assuming PASSWORD is enabled; we will remove it later if not. - login_types = set() - if self._password_localdb_enabled: - login_types.add(LoginType.PASSWORD) - - for provider in self.password_providers: - login_types.update(provider.get_supported_login_types().keys()) - - if not self._password_enabled: - login_types.discard(LoginType.PASSWORD) - - # Some clients just pick the first type in the list. In this case, we want - # them to use PASSWORD (rather than token or whatever), so we want to make sure - # that comes first, where it's present. - self._supported_login_types = [] - if LoginType.PASSWORD in login_types: - self._supported_login_types.append(LoginType.PASSWORD) - login_types.remove(LoginType.PASSWORD) - self._supported_login_types.extend(login_types) - # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( @@ -427,11 +394,10 @@ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: ui_auth_types.add(LoginType.PASSWORD) # also allow auth from password providers - for provider in self.password_providers: - for t in provider.get_supported_login_types().keys(): - if t == LoginType.PASSWORD and not self._password_enabled: - continue - ui_auth_types.add(t) + for t in self.password_auth_provider.get_supported_login_types().keys(): + if t == LoginType.PASSWORD and not self._password_enabled: + continue + ui_auth_types.add(t) # if sso is enabled, allow the user to log in via SSO iff they have a mapping # from sso to mxid. @@ -1038,7 +1004,25 @@ def get_supported_login_types(self) -> Iterable[str]: Returns: login types """ - return self._supported_login_types + # Load any login types registered by modules + # This is stored in the password_auth_provider so this doesn't trigger + # any callbacks + types = list(self.password_auth_provider.get_supported_login_types().keys()) + + # This list should include PASSWORD if (either _password_localdb_enabled is + # true or if one of the modules registered it) AND _password_enabled is true + # Also: + # Some clients just pick the first type in the list. In this case, we want + # them to use PASSWORD (rather than token or whatever), so we want to make sure + # that comes first, where it's present. + if LoginType.PASSWORD in types: + types.remove(LoginType.PASSWORD) + if self._password_enabled: + types.insert(0, LoginType.PASSWORD) + elif self._password_localdb_enabled and self._password_enabled: + types.insert(0, LoginType.PASSWORD) + + return types async def validate_login( self, @@ -1217,15 +1201,20 @@ async def _validate_userid_login( known_login_type = False - for provider in self.password_providers: - supported_login_types = provider.get_supported_login_types() - if login_type not in supported_login_types: - # this password provider doesn't understand this login type - continue - + # Check if login_type matches a type registered by one of the modules + # We don't need to remove LoginType.PASSWORD from the list if password login is + # disabled, since if that were the case then by this point we know that the + # login_type is not LoginType.PASSWORD + supported_login_types = self.password_auth_provider.get_supported_login_types() + # check if the login type being used is supported by a module + if login_type in supported_login_types: + # Make a note that this login type is supported by the server known_login_type = True + # Get all the fields expected for this login types login_fields = supported_login_types[login_type] + # go through the login submission and keep track of which required fields are + # provided/not provided missing_fields = [] login_dict = {} for f in login_fields: @@ -1233,6 +1222,7 @@ async def _validate_userid_login( missing_fields.append(f) else: login_dict[f] = login_submission[f] + # raise an error if any of the expected fields for that login type weren't provided if missing_fields: raise SynapseError( 400, @@ -1240,10 +1230,15 @@ async def _validate_userid_login( % (login_type, missing_fields), ) - result = await provider.check_auth(username, login_type, login_dict) + # call all of the check_auth hooks for that login_type + # it will return a result once the first success is found (or None otherwise) + result = await self.password_auth_provider.check_auth( + username, login_type, login_dict + ) if result: return result + # if no module managed to authenticate the user, then fallback to built in password based auth if login_type == LoginType.PASSWORD and self._password_localdb_enabled: known_login_type = True @@ -1282,11 +1277,16 @@ async def check_password_provider_3pid( completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. """ - for provider in self.password_providers: - result = await provider.check_3pid_auth(medium, address, password) - if result: - return result + # call all of the check_3pid_auth callbacks + # Result will be from the first callback that returns something other than None + # If all the callbacks return None, then result is also set to None + result = await self.password_auth_provider.check_3pid_auth( + medium, address, password + ) + if result: + return result + # if result is None then return (None, None) return None, None async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: @@ -1365,13 +1365,12 @@ async def delete_access_token(self, access_token: str) -> None: user_info = await self.auth.get_user_by_access_token(access_token) await self.store.delete_access_token(access_token) - # see if any of our auth providers want to know about this - for provider in self.password_providers: - await provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, - access_token=access_token, - ) + # see if any modules want to know about this + await self.password_auth_provider.on_logged_out( + user_id=user_info.user_id, + device_id=user_info.device_id, + access_token=access_token, + ) # delete pushers associated with this access token if user_info.token_id is not None: @@ -1398,12 +1397,11 @@ async def delete_access_tokens_for_user( user_id, except_token_id=except_token_id, device_id=device_id ) - # see if any of our auth providers want to know about this - for provider in self.password_providers: - for token, _, device_id in tokens_and_devices: - await provider.on_logged_out( - user_id=user_id, device_id=device_id, access_token=token - ) + # see if any modules want to know about this + for token, _, device_id in tokens_and_devices: + await self.password_auth_provider.on_logged_out( + user_id=user_id, device_id=device_id, access_token=token + ) # delete pushers associated with the access tokens await self.hs.get_pusherpool().remove_pushers_by_access_token( @@ -1811,40 +1809,228 @@ def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon: return macaroon -class PasswordProvider: - """Wrapper for a password auth provider module +def load_legacy_password_auth_providers(hs: "HomeServer") -> None: + module_api = hs.get_module_api() + for module, config in hs.config.authproviders.password_providers: + load_single_legacy_password_auth_provider( + module=module, config=config, api=module_api + ) - This class abstracts out all of the backwards-compatibility hacks for - password providers, to provide a consistent interface. - """ - @classmethod - def load( - cls, module: Type, config: JsonDict, module_api: ModuleApi - ) -> "PasswordProvider": - try: - pp = module(config=config, account_handler=module_api) - except Exception as e: - logger.error("Error while initializing %r: %s", module, e) - raise - return cls(pp, module_api) +def load_single_legacy_password_auth_provider( + module: Type, config: JsonDict, api: ModuleApi +) -> None: + try: + provider = module(config=config, account_handler=api) + except Exception as e: + logger.error("Error while initializing %r: %s", module, e) + raise + + # The known hooks. If a module implements a method who's name appears in this set + # we'll want to register it + password_auth_provider_methods = { + "check_3pid_auth", + "on_logged_out", + } + + # All methods that the module provides should be async, but this wasn't enforced + # in the old module system, so we wrap them if needed + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + # We need to wrap check_password because its old form would return a boolean + # but we now want it to behave just like check_auth() and return the matrix id of + # the user if authentication succeeded or None otherwise + if f.__name__ == "check_password": + + async def wrapped_check_password( + username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + matrix_user_id = api.get_qualified_user_id(username) + password = login_dict["password"] + + is_valid = await f(matrix_user_id, password) + + if is_valid: + return matrix_user_id, None + + return None - def __init__(self, pp: "PasswordProvider", module_api: ModuleApi): - self._pp = pp - self._module_api = module_api + return wrapped_check_password + + # We need to wrap check_auth as in the old form it could return + # just a str, but now it must return Optional[Tuple[str, Optional[Callable]] + if f.__name__ == "check_auth": + + async def wrapped_check_auth( + username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + result = await f(username, login_type, login_dict) + + if isinstance(result, str): + return result, None + + return result + + return wrapped_check_auth + + # We need to wrap check_3pid_auth as in the old form it could return + # just a str, but now it must return Optional[Tuple[str, Optional[Callable]] + if f.__name__ == "check_3pid_auth": + + async def wrapped_check_3pid_auth( + medium: str, address: str, password: str + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + result = await f(medium, address, password) + + if isinstance(result, str): + return result, None + + return result - self._supported_login_types = {} + return wrapped_check_3pid_auth - # grandfather in check_password support - if hasattr(self._pp, "check_password"): - self._supported_login_types[LoginType.PASSWORD] = ("password",) + def run(*args: Tuple, **kwargs: Dict) -> Awaitable: + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None - g = getattr(self._pp, "get_supported_login_types", None) - if g: - self._supported_login_types.update(g()) + return maybe_awaitable(f(*args, **kwargs)) - def __str__(self) -> str: - return str(self._pp) + return run + + # populate hooks with the implemented methods, wrapped with async_wrapper + hooks = { + hook: async_wrapper(getattr(provider, hook, None)) + for hook in password_auth_provider_methods + } + + supported_login_types = {} + # call get_supported_login_types and add that to the dict + g = getattr(provider, "get_supported_login_types", None) + if g is not None: + # Note the old module style also called get_supported_login_types at loading time + # and it is synchronous + supported_login_types.update(g()) + + auth_checkers = {} + # Legacy modules have a check_auth method which expects to be called with one of + # the keys returned by get_supported_login_types. New style modules register a + # dictionary of login_type->check_auth_method mappings + check_auth = async_wrapper(getattr(provider, "check_auth", None)) + if check_auth is not None: + for login_type, fields in supported_login_types.items(): + # need tuple(fields) since fields can be any Iterable type (so may not be hashable) + auth_checkers[(login_type, tuple(fields))] = check_auth + + # if it has a "check_password" method then it should handle all auth checks + # with login type of LoginType.PASSWORD + check_password = async_wrapper(getattr(provider, "check_password", None)) + if check_password is not None: + # need to use a tuple here for ("password",) not a list since lists aren't hashable + auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password + + api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers) + + +CHECK_3PID_AUTH_CALLBACK = Callable[ + [str, str, str], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] +ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] +CHECK_AUTH_CALLBACK = Callable[ + [str, str, JsonDict], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] + + +class PasswordAuthProvider: + """ + A class that the AuthHandler calls when authenticating users + It allows modules to provide alternative methods for authentication + """ + + def __init__(self) -> None: + # lists of callbacks + self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] + self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] + + # Mapping from login type to login parameters + self._supported_login_types: Dict[str, Iterable[str]] = {} + + # Mapping from login type to auth checker callbacks + self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} + + def register_password_auth_provider_callbacks( + self, + check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, + on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, + auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None, + ) -> None: + # Register check_3pid_auth callback + if check_3pid_auth is not None: + self.check_3pid_auth_callbacks.append(check_3pid_auth) + + # register on_logged_out callback + if on_logged_out is not None: + self.on_logged_out_callbacks.append(on_logged_out) + + if auth_checkers is not None: + # register a new supported login_type + # Iterate through all of the types being registered + for (login_type, fields), callback in auth_checkers.items(): + # Note: fields may be empty here. This would allow a modules auth checker to + # be called with just 'login_type' and no password or other secrets + + # Need to check that all the field names are strings or may get nasty errors later + for f in fields: + if not isinstance(f, str): + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but all parameter names must be strings" + % (login_type, fields) + ) + + # 2 modules supporting the same login type must expect the same fields + # e.g. 1 can't expect "pass" if the other expects "password" + # so throw an exception if that happens + if login_type not in self._supported_login_types.get(login_type, []): + self._supported_login_types[login_type] = fields + else: + fields_currently_supported = self._supported_login_types.get( + login_type + ) + if fields_currently_supported != fields: + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but another module had already registered support for that type with parameters %s" + % (login_type, fields, fields_currently_supported) + ) + + # Add the new method to the list of auth_checker_callbacks for this login type + self.auth_checker_callbacks.setdefault(login_type, []).append(callback) def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -1852,20 +2038,15 @@ def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: Returns a map from a login type identifier (such as m.login.password) to an iterable giving the fields which must be provided by the user in the submission to the /login API. - - This wrapper adds m.login.password to the list if the underlying password - provider supports the check_password() api. """ + return self._supported_login_types async def check_auth( self, username: str, login_type: str, login_dict: JsonDict - ) -> Optional[Tuple[str, Optional[Callable]]]: + ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: """Check if the user has presented valid login credentials - This wrapper also calls check_password() if the underlying password provider - supports the check_password() api and the login type is m.login.password. - Args: username: user id presented by the client. Either an MXID or an unqualified username. @@ -1879,63 +2060,130 @@ async def check_auth( user, and `callback` is an optional callback which will be called with the result from the /login call (including access_token, device_id, etc.) """ - # first grandfather in a call to check_password - if login_type == LoginType.PASSWORD: - check_password = getattr(self._pp, "check_password", None) - if check_password: - qualified_user_id = self._module_api.get_qualified_user_id(username) - is_valid = await check_password( - qualified_user_id, login_dict["password"] - ) - if is_valid: - return qualified_user_id, None - check_auth = getattr(self._pp, "check_auth", None) - if not check_auth: - return None - result = await check_auth(username, login_type, login_dict) + # Go through all callbacks for the login type until one returns with a value + # other than None (i.e. until a callback returns a success) + for callback in self.auth_checker_callbacks[login_type]: + try: + result = await callback(username, login_type, login_dict) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - return result, None + if result is not None: + # Check that the callback returned a Tuple[str, Optional[Callable]] + # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks + # result is always the right type, but as it is 3rd party code it might not be + + if not isinstance(result, tuple) or len(result) != 2: + logger.warning( + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue - return result + # pull out the two parts of the tuple so we can do type checking + str_result, callback_result = result + + # the 1st item in the tuple should be a str + if not isinstance(str_result, str): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # the second should be Optional[Callable] + if callback_result is not None: + if not callable(callback_result): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # The result is a (str, Optional[callback]) tuple so return the successful result + return result + + # If this point has been reached then none of the callbacks successfully authenticated + # the user so return None + return None async def check_3pid_auth( self, medium: str, address: str, password: str - ) -> Optional[Tuple[str, Optional[Callable]]]: - g = getattr(self._pp, "check_3pid_auth", None) - if not g: - return None - + ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: # This function is able to return a deferred that either # resolves None, meaning authentication failure, or upon # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = await g(medium, address, password) - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - return result, None + for callback in self.check_3pid_auth_callbacks: + try: + result = await callback(medium, address, password) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue - return result + if result is not None: + # Check that the callback returned a Tuple[str, Optional[Callable]] + # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks + # result is always the right type, but as it is 3rd party code it might not be + + if not isinstance(result, tuple) or len(result) != 2: + logger.warning( + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # pull out the two parts of the tuple so we can do type checking + str_result, callback_result = result + + # the 1st item in the tuple should be a str + if not isinstance(str_result, str): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # the second should be Optional[Callable] + if callback_result is not None: + if not callable(callback_result): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # The result is a (str, Optional[callback]) tuple so return the successful result + return result + + # If this point has been reached then none of the callbacks successfully authenticated + # the user so return None + return None async def on_logged_out( self, user_id: str, device_id: Optional[str], access_token: str ) -> None: - g = getattr(self._pp, "on_logged_out", None) - if not g: - return - # This might return an awaitable, if it does block the log out - # until it completes. - await maybe_awaitable( - g( - user_id=user_id, - device_id=device_id, - access_token=access_token, - ) - ) + # call all of the on_logged_out callbacks + for callback in self.on_logged_out_callbacks: + try: + callback(user_id, device_id, access_token) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b2a228c23178..ab7ef8f950bd 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,6 +45,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.rest.client.login import LoginResponse from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -83,6 +84,8 @@ "DirectServeJsonResource", "ModuleApi", "PRESENCE_ALL_USERS", + "LoginResponse", + "JsonDict", ] logger = logging.getLogger(__name__) @@ -139,6 +142,7 @@ def __init__(self, hs: "HomeServer", auth_handler): self._spam_checker = hs.get_spam_checker() self._account_validity_handler = hs.get_account_validity_handler() self._third_party_event_rules = hs.get_third_party_event_rules() + self._password_auth_provider = hs.get_password_auth_provider() self._presence_router = hs.get_presence_router() ################################################################################# @@ -164,6 +168,11 @@ def register_presence_router_callbacks(self): """Registers callbacks for presence router capabilities.""" return self._presence_router.register_presence_router_callbacks + @property + def register_password_auth_provider_callbacks(self): + """Registers callbacks for password auth provider capabilities.""" + return self._password_auth_provider.register_password_auth_provider_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. diff --git a/synapse/server.py b/synapse/server.py index 5bc045d615b4..a64c846d1c49 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -65,7 +65,7 @@ from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.admin import AdminHandler from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler, MacaroonGenerator +from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider from synapse.handlers.cas import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler @@ -687,6 +687,10 @@ def get_spam_checker(self) -> SpamChecker: def get_third_party_event_rules(self) -> ThirdPartyEventRules: return ThirdPartyEventRules(self) + @cache_in_self + def get_password_auth_provider(self) -> PasswordAuthProvider: + return PasswordAuthProvider() + @cache_in_self def get_room_member_handler(self) -> RoomMemberHandler: if self.config.worker.worker_app: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 11ca47ea2825..1629d2a53c2c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -549,6 +549,8 @@ def _apply_module_schemas( database_engine: config: application config """ + # This is the old way for password_auth_provider modules to make changes + # to the database. This should instead be done using the module API for (mod, _config) in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 38e6d9f5363a..7dd4a5a36764 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,6 +20,8 @@ from twisted.internet import defer import synapse +from synapse.handlers.auth import load_legacy_password_auth_providers +from synapse.module_api import ModuleApi from synapse.rest.client import devices, login from synapse.types import JsonDict @@ -36,8 +38,8 @@ mock_password_provider = Mock() -class PasswordOnlyAuthProvider: - """A password_provider which only implements `check_password`.""" +class LegacyPasswordOnlyAuthProvider: + """A legacy password_provider which only implements `check_password`.""" @staticmethod def parse_config(self): @@ -50,8 +52,8 @@ def check_password(self, *args): return mock_password_provider.check_password(*args) -class CustomAuthProvider: - """A password_provider which implements a custom login type.""" +class LegacyCustomAuthProvider: + """A legacy password_provider which implements a custom login type.""" @staticmethod def parse_config(self): @@ -67,7 +69,23 @@ def check_auth(self, *args): return mock_password_provider.check_auth(*args) -class PasswordCustomAuthProvider: +class CustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + ) + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +class LegacyPasswordCustomAuthProvider: """A password_provider which implements password login via `check_auth`, as well as a custom type.""" @@ -85,8 +103,32 @@ def check_auth(self, *args): return mock_password_provider.check_auth(*args) -def providers_config(*providers: Type[Any]) -> dict: - """Returns a config dict that will enable the given password auth providers""" +class PasswordCustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type. + as well as a password login""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("test.login_type", ("test_field",)): self.check_auth, + ("m.login.password", ("password",)): self.check_auth, + }, + ) + pass + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + def check_pass(self, *args): + return mock_password_provider.check_password(*args) + + +def legacy_providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given legacy password auth providers""" return { "password_providers": [ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} @@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict: } +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given modules""" + return { + "modules": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + class PasswordAuthProviderTests(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -107,8 +159,21 @@ def setUp(self): mock_password_provider.reset_mock() super().setUp() - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_login(self): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + load_legacy_password_auth_providers(hs) + + return hs + + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_progiver_login_legacy(self): + self.password_only_auth_provider_login_test_body() + + def password_only_auth_provider_login_test_body(self): # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -138,8 +203,11 @@ def test_password_only_auth_provider_login(self): "@ USER🙂NAME :test", " pASS😢word " ) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth_legacy(self): + self.password_only_auth_provider_ui_auth_test_body() + + def password_only_auth_provider_ui_auth_test_body(self): """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -172,8 +240,11 @@ def test_password_only_auth_provider_ui_auth(self): self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_login(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_login_legacy(self): + self.local_user_fallback_login_test_body() + + def local_user_fallback_login_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -186,8 +257,11 @@ def test_local_user_fallback_login(self): self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth_legacy(self): + self.local_user_fallback_ui_auth_test_body() + + def local_user_fallback_ui_auth_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -223,11 +297,14 @@ def test_local_user_fallback_ui_auth(self): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login(self): + def test_no_local_user_fallback_login_legacy(self): + self.no_local_user_fallback_login_test_body() + + def no_local_user_fallback_login_test_body(self): """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -242,11 +319,14 @@ def test_no_local_user_fallback_login(self): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth(self): + def test_no_local_user_fallback_ui_auth_legacy(self): + self.no_local_user_fallback_ui_auth_test_body() + + def no_local_user_fallback_ui_auth_test_body(self): """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -280,11 +360,14 @@ def test_no_local_user_fallback_ui_auth(self): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"enabled": False}, } ) - def test_password_auth_disabled(self): + def test_password_auth_disabled_legacy(self): + self.password_auth_disabled_test_body() + + def password_auth_disabled_test_body(self): """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -295,8 +378,15 @@ def test_password_auth_disabled(self): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_password.assert_not_called() + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_login_legacy(self): + self.custom_auth_provider_login_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_login(self): + self.custom_auth_provider_login_test_body() + + def custom_auth_provider_login_test_body(self): # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -312,7 +402,9 @@ def test_custom_auth_provider_login(self): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) @@ -325,7 +417,7 @@ def test_custom_auth_provider_login(self): # in these cases, but at least we can guard against the API changing # unexpectedly mock_password_provider.check_auth.return_value = defer.succeed( - "@ MALFORMED! :bz" + ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") self.assertEqual(channel.code, 200, channel.result) @@ -334,8 +426,15 @@ def test_custom_auth_provider_login(self): " USER🙂NAME ", "test.login_type", {"test_field": " abc "} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_ui_auth_legacy(self): + self.custom_auth_provider_ui_auth_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_ui_auth(self): + self.custom_auth_provider_ui_auth_test_body() + + def custom_auth_provider_ui_auth_test_body(self): # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -367,7 +466,9 @@ def test_custom_auth_provider_ui_auth(self): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -379,7 +480,7 @@ def test_custom_auth_provider_ui_auth(self): # and finally, succeed mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -387,8 +488,15 @@ def test_custom_auth_provider_ui_auth(self): "localuser", "test.login_type", {"test_field": "foo"} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_callback_legacy(self): + self.custom_auth_provider_callback_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_callback(self): + self.custom_auth_provider_callback_test_body() + + def custom_auth_provider_callback_test_body(self): callback = Mock(return_value=defer.succeed(None)) mock_password_provider.check_auth.return_value = defer.succeed( @@ -410,10 +518,22 @@ def test_custom_auth_provider_callback(self): for p in ["user_id", "access_token", "device_id", "home_server"]: self.assertIn(p, call_args[0]) + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_custom_auth_password_disabled_legacy(self): + self.custom_auth_password_disabled_test_body() + @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) def test_custom_auth_password_disabled(self): + self.custom_auth_password_disabled_test_body() + + def custom_auth_password_disabled_test_body(self): """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -425,6 +545,15 @@ def test_custom_auth_password_disabled(self): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + @override_config( { **providers_config(CustomAuthProvider), @@ -432,6 +561,9 @@ def test_custom_auth_password_disabled(self): } ) def test_custom_auth_password_disabled_localdb_enabled(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + + def custom_auth_password_disabled_localdb_enabled_test_body(self): """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -448,6 +580,15 @@ def test_custom_auth_password_disabled_localdb_enabled(self): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login_legacy(self): + self.password_custom_auth_password_disabled_login_test_body() + @override_config( { **providers_config(PasswordCustomAuthProvider), @@ -455,6 +596,9 @@ def test_custom_auth_password_disabled_localdb_enabled(self): } ) def test_password_custom_auth_password_disabled_login(self): + self.password_custom_auth_password_disabled_login_test_body() + + def password_custom_auth_password_disabled_login_test_body(self): """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -466,6 +610,16 @@ def test_password_custom_auth_password_disabled_login(self): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( { @@ -474,12 +628,15 @@ def test_password_custom_auth_password_disabled_login(self): } ) def test_password_custom_auth_password_disabled_ui_auth(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() + + def password_custom_auth_password_disabled_ui_auth_test_body(self): """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, 200, channel.result) @@ -516,6 +673,7 @@ def test_password_custom_auth_password_disabled_ui_auth(self): "Password login has been disabled.", channel.json_body["error"] ) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() mock_password_provider.reset_mock() # successful auth @@ -526,6 +684,16 @@ def test_password_custom_auth_password_disabled_ui_auth(self): mock_password_provider.check_auth.assert_called_once_with( "localuser", "test.login_type", {"test_field": "x"} ) + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback_legacy(self): + self.custom_auth_no_local_user_fallback_test_body() @override_config( { @@ -534,6 +702,9 @@ def test_password_custom_auth_password_disabled_ui_auth(self): } ) def test_custom_auth_no_local_user_fallback(self): + self.custom_auth_no_local_user_fallback_test_body() + + def custom_auth_no_local_user_fallback_test_body(self): """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass")