Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to tests.handlers. #14680

Merged
merged 8 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,7 +2031,7 @@ def __init__(self) -> None:
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []

# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
self._supported_login_types: Dict[str, Tuple[str, ...]] = {}

# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
Expand Down
54 changes: 29 additions & 25 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
"""Tests for the password_auth_provider interface"""

from http import HTTPStatus
from typing import Any, Type, Union
from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock

import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.handlers.account import AccountHandler
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
Expand All @@ -44,46 +45,46 @@ class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`."""

@staticmethod
def parse_config(self) -> None:
def parse_config() -> None:
pass

def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
pass

def check_password(self, *args):
def check_password(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)


class LegacyCustomAuthProvider:
"""A legacy password_provider which implements a custom login type."""

@staticmethod
def parse_config(self) -> None:
def parse_config() -> None:
pass

def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
pass

def get_supported_login_types(self):
def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"test.login_type": ["test_field"]}

def check_auth(self, *args):
def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)


class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type."""

@staticmethod
def parse_config(self) -> None:
def parse_config() -> None:
pass

def __init__(self, config, api: ModuleApi):
def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
)

def check_auth(self, *args):
def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)


Expand All @@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
as a custom type."""

@staticmethod
def parse_config(self) -> None:
def parse_config() -> None:
pass

def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
pass

def get_supported_login_types(self):
def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}

def check_auth(self, *args):
def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)


Expand All @@ -110,21 +111,21 @@ class PasswordCustomAuthProvider:
as well as a password login"""

@staticmethod
def parse_config(self) -> None:
def parse_config() -> None:
pass

def __init__(self, config, api: ModuleApi):
def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
("m.login.password", ("password",)): self.check_auth,
}
)

def check_auth(self, *args):
def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)

def check_pass(self, *args):
def check_pass(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)


Expand Down Expand Up @@ -720,7 +721,9 @@ def test_on_logged_out(self) -> None:

self.called = False

async def on_logged_out(user_id, device_id, access_token):
async def on_logged_out(
user_id: str, device_id: Optional[str], access_token: str
) -> None:
self.called = True

on_logged_out = Mock(side_effect=on_logged_out)
Expand Down Expand Up @@ -841,7 +844,7 @@ def test_displayname_uia(self) -> None:
# Check that the callback has been called.
m.assert_called_once()

def _test_3pid_allowed(self, username: str, registration: bool):
def _test_3pid_allowed(self, username: str, registration: bool) -> None:
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.

Expand Down Expand Up @@ -907,7 +910,7 @@ def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
client is trying to register.
"""

async def callback(uia_results, params):
async def callback(uia_results: JsonDict, params: JsonDict) -> str:
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"
Expand Down Expand Up @@ -950,12 +953,13 @@ def _get_login_flows(self) -> JsonDict:
def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password)

def _send_login(self, type, user, **params) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
params.update(extra_params)
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel

def _start_delete_device_session(self, access_token, device_id) -> str:
def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
"""Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401)
Expand Down