From 76f383152cf356ba86ca4ace01ee20d4e8901606 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 12 Nov 2021 15:03:35 -0500 Subject: [PATCH 1/8] Add missing type hints. --- synapse/appservice/__init__.py | 30 ++++++++++++++++-------------- synapse/config/appservice.py | 3 +-- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 6504c6bd3f59..8dfc29c13737 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,7 +13,9 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Iterable, List, Match, Optional +from typing import TYPE_CHECKING, Iterable, List, Match, Optional, Pattern + +from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase @@ -49,17 +51,17 @@ class ApplicationService: def __init__( self, - token, - hostname, - id, - sender, - url=None, - namespaces=None, - hs_token=None, - protocols=None, - rate_limited=True, - ip_range_whitelist=None, - supports_ephemeral=False, + token: str, + hostname: str, + id: str, + sender: str, + url: Optional[str] = None, + namespaces: Optional[JsonDict] = None, + hs_token: Optional[str] = None, + protocols: Iterable[str] = None, + rate_limited: bool = True, + ip_range_whitelist: Optional[IPSet] = None, + supports_ephemeral: bool = False, ): self.token = token self.url = ( @@ -284,7 +286,7 @@ def is_exclusive_alias(self, alias: str) -> bool: def is_exclusive_room(self, room_id: str) -> bool: return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exclusive_user_regexes(self): + def get_exclusive_user_regexes(self) -> List[Pattern]: """Get the list of regexes used to determine if a user is exclusively registered by the AS """ @@ -312,7 +314,7 @@ def get_groups_for_user(self, user_id: str) -> Iterable[str]: def is_rate_limited(self) -> bool: return self.rate_limited - def __str__(self): + def __str__(self) -> str: # copy dictionary and redact token fields so they don't get logged dict_copy = self.__dict__.copy() dict_copy["token"] = "" diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 1ebea88db2a2..2d6e7585911f 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -142,8 +142,7 @@ def _load_appservice(hostname, as_info, config_filename): # protocols check protocols = as_info.get("protocols") if protocols: - # Because strings are lists in python - if isinstance(protocols, str) or not isinstance(protocols, list): + if not isinstance(protocols, list): raise KeyError("Optional 'protocols' must be a list if present.") for p in protocols: if not isinstance(p, str): From 2f0912cbeae068c46bcf0680bcd412460714cd36 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 12 Nov 2021 15:05:57 -0500 Subject: [PATCH 2/8] Swap ordering of parameters. --- synapse/appservice/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 8dfc29c13737..12edcb80f19f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -58,7 +58,7 @@ def __init__( url: Optional[str] = None, namespaces: Optional[JsonDict] = None, hs_token: Optional[str] = None, - protocols: Iterable[str] = None, + protocols: Optional[Iterable[str]] = None, rate_limited: bool = True, ip_range_whitelist: Optional[IPSet] = None, supports_ephemeral: bool = False, @@ -133,14 +133,14 @@ def _check_namespaces(self, namespaces): raise ValueError("Expected string for 'regex' in ns '%s'" % ns) return namespaces - def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]: + def _matches_regex(self, namespace_key: str, test_string: str) -> Optional[Match]: for regex_obj in self.namespaces[namespace_key]: if regex_obj["regex"].match(test_string): return regex_obj return None - def _is_exclusive(self, ns_key: str, test_string: str) -> bool: - regex_obj = self._matches_regex(test_string, ns_key) + def _is_exclusive(self, namespace_key: str, test_string: str) -> bool: + regex_obj = self._matches_regex(namespace_key, test_string) if regex_obj: return regex_obj["exclusive"] return False @@ -261,15 +261,15 @@ async def is_interested_in_presence( def is_interested_in_user(self, user_id: str) -> bool: return ( - bool(self._matches_regex(user_id, ApplicationService.NS_USERS)) + bool(self._matches_regex(ApplicationService.NS_USERS), user_id) or user_id == self.sender ) def is_interested_in_alias(self, alias: str) -> bool: - return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) + return bool(self._matches_regex(ApplicationService.NS_ALIASES), alias) def is_interested_in_room(self, room_id: str) -> bool: - return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) + return bool(self._matches_regex(ApplicationService.NS_ROOMS), room_id) def is_exclusive_user(self, user_id: str) -> bool: return ( From 9e463851cb66a4db214b1d1abd8018097cd1e773 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 12 Nov 2021 15:16:42 -0500 Subject: [PATCH 3/8] Use attrs for appservice namespaces. --- synapse/appservice/__init__.py | 83 ++++++++++++++++++----------- tests/appservice/test_appservice.py | 51 +++++++----------- tests/storage/test_appservice.py | 8 +-- 3 files changed, 76 insertions(+), 66 deletions(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 12edcb80f19f..c53207221a30 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,8 +13,9 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Iterable, List, Match, Optional, Pattern +from typing import TYPE_CHECKING, Iterable, List, Optional, Pattern +import attr from netaddr import IPSet from synapse.api.constants import EventTypes @@ -34,6 +35,20 @@ class ApplicationServiceState: UP = "up" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Namespace: + exclusive: bool + group_id: Optional[str] + regex: Pattern + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Namespaces: + users: List[Namespace] + aliases: List[Namespace] + rooms: List[Namespace] + + class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. @@ -86,27 +101,30 @@ def __init__( self.rate_limited = rate_limited - def _check_namespaces(self, namespaces): + def _check_namespaces(self, namespaces: Optional[JsonDict]) -> Namespaces: # Sanity check that it is of the form: # { # users: [ {regex: "[A-z]+.*", exclusive: true}, ...], # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...], # } + result = Namespaces([], [], []) if not namespaces: - namespaces = {} + return result for ns in ApplicationService.NS_LIST: if ns not in namespaces: - namespaces[ns] = [] continue - if type(namespaces[ns]) != list: + namespace: List[Namespace] = getattr(result, ns) + + if not isinstance(namespaces[ns], list): raise ValueError("Bad namespace value for '%s'" % ns) for regex_obj in namespaces[ns]: if not isinstance(regex_obj, dict): raise ValueError("Expected dict regex for ns '%s'" % ns) - if not isinstance(regex_obj.get("exclusive"), bool): + exclusive = regex_obj.get("exclusive") + if not isinstance(exclusive, bool): raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) group_id = regex_obj.get("group_id") if group_id: @@ -127,22 +145,26 @@ def _check_namespaces(self, namespaces): ) regex = regex_obj.get("regex") - if isinstance(regex, str): - regex_obj["regex"] = re.compile(regex) # Pre-compile regex - else: + if not isinstance(regex, str): raise ValueError("Expected string for 'regex' in ns '%s'" % ns) - return namespaces - def _matches_regex(self, namespace_key: str, test_string: str) -> Optional[Match]: - for regex_obj in self.namespaces[namespace_key]: - if regex_obj["regex"].match(test_string): - return regex_obj + # Pre-compile regex. + namespace.append(Namespace(exclusive, group_id, re.compile(regex))) + + return result + + def _matches_regex( + self, namespaces: List[Namespace], test_string: str + ) -> Optional[Namespace]: + for namespace in namespaces: + if namespace.regex.match(test_string): + return namespace return None - def _is_exclusive(self, namespace_key: str, test_string: str) -> bool: - regex_obj = self._matches_regex(namespace_key, test_string) - if regex_obj: - return regex_obj["exclusive"] + def _is_exclusive(self, namespaces: List[Namespace], test_string: str) -> bool: + namespace = self._matches_regex(namespaces, test_string) + if namespace: + return namespace.exclusive return False async def _matches_user( @@ -261,39 +283,38 @@ async def is_interested_in_presence( def is_interested_in_user(self, user_id: str) -> bool: return ( - bool(self._matches_regex(ApplicationService.NS_USERS), user_id) + bool(self._matches_regex(self.namespaces.users, user_id)) or user_id == self.sender ) def is_interested_in_alias(self, alias: str) -> bool: - return bool(self._matches_regex(ApplicationService.NS_ALIASES), alias) + return bool(self._matches_regex(self.namespaces.aliases, alias)) def is_interested_in_room(self, room_id: str) -> bool: - return bool(self._matches_regex(ApplicationService.NS_ROOMS), room_id) + return bool(self._matches_regex(self.namespaces.rooms, room_id)) def is_exclusive_user(self, user_id: str) -> bool: return ( - self._is_exclusive(ApplicationService.NS_USERS, user_id) - or user_id == self.sender + self._is_exclusive(self.namespaces.users, user_id) or user_id == self.sender ) def is_interested_in_protocol(self, protocol: str) -> bool: return protocol in self.protocols def is_exclusive_alias(self, alias: str) -> bool: - return self._is_exclusive(ApplicationService.NS_ALIASES, alias) + return self._is_exclusive(self.namespaces.aliases, alias) def is_exclusive_room(self, room_id: str) -> bool: - return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) + return self._is_exclusive(self.namespaces.rooms, room_id) def get_exclusive_user_regexes(self) -> List[Pattern]: """Get the list of regexes used to determine if a user is exclusively registered by the AS """ return [ - regex_obj["regex"] - for regex_obj in self.namespaces[ApplicationService.NS_USERS] - if regex_obj["exclusive"] + namespace.regex + for namespace in self.namespaces.users + if namespace.exclusive ] def get_groups_for_user(self, user_id: str) -> Iterable[str]: @@ -306,9 +327,9 @@ def get_groups_for_user(self, user_id: str) -> Iterable[str]: An iterable that yields group_id strings. """ return ( - regex_obj["group_id"] - for regex_obj in self.namespaces[ApplicationService.NS_USERS] - if "group_id" in regex_obj and regex_obj["regex"].match(user_id) + namespace.group_id + for namespace in self.namespaces.users + if namespace.group_id and namespace.regex.match(user_id) ) def is_rate_limited(self) -> bool: diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index f386b5e128bf..e13ad6f9019c 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -16,13 +16,13 @@ from twisted.internet import defer -from synapse.appservice import ApplicationService +from synapse.appservice import ApplicationService, Namespace from tests import unittest -def _regex(regex, exclusive=True): - return {"regex": re.compile(regex), "exclusive": exclusive} +def _regex(regex: str, exclusive: bool = True) -> Namespace: + return Namespace(exclusive, None, re.compile(regex)) class ApplicationServiceTestCase(unittest.TestCase): @@ -33,11 +33,6 @@ def setUp(self): url="some_url", token="some_token", hostname="matrix.org", # only used by get_groups_for_user - namespaces={ - ApplicationService.NS_USERS: [], - ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [], - }, ) self.event = Mock( type="m.something", room_id="!foo:bar", sender="@someone:somewhere" @@ -47,7 +42,7 @@ def setUp(self): @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): - self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) + self.service.namespaces.users.append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue( (yield defer.ensureDeferred(self.service.is_interested(self.event))) @@ -55,7 +50,7 @@ def test_regex_user_id_prefix_match(self): @defer.inlineCallbacks def test_regex_user_id_prefix_no_match(self): - self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) + self.service.namespaces.users.append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse( (yield defer.ensureDeferred(self.service.is_interested(self.event))) @@ -63,7 +58,7 @@ def test_regex_user_id_prefix_no_match(self): @defer.inlineCallbacks def test_regex_room_member_is_checked(self): - self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) + self.service.namespaces.users.append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" @@ -73,7 +68,7 @@ def test_regex_room_member_is_checked(self): @defer.inlineCallbacks def test_regex_room_id_match(self): - self.service.namespaces[ApplicationService.NS_ROOMS].append( + self.service.namespaces.rooms.append( _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" @@ -83,7 +78,7 @@ def test_regex_room_id_match(self): @defer.inlineCallbacks def test_regex_room_id_no_match(self): - self.service.namespaces[ApplicationService.NS_ROOMS].append( + self.service.namespaces.rooms.append( _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" @@ -93,9 +88,7 @@ def test_regex_room_id_no_match(self): @defer.inlineCallbacks def test_regex_alias_match(self): - self.service.namespaces[ApplicationService.NS_ALIASES].append( - _regex("#irc_.*:matrix.org") - ) + self.service.namespaces.aliases.append(_regex("#irc_.*:matrix.org")) self.store.get_aliases_for_room.return_value = defer.succeed( ["#irc_foobar:matrix.org", "#athing:matrix.org"] ) @@ -109,46 +102,44 @@ def test_regex_alias_match(self): ) def test_non_exclusive_alias(self): - self.service.namespaces[ApplicationService.NS_ALIASES].append( + self.service.namespaces.aliases.append( _regex("#irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) def test_non_exclusive_room(self): - self.service.namespaces[ApplicationService.NS_ROOMS].append( + self.service.namespaces.rooms.append( _regex("!irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org")) def test_non_exclusive_user(self): - self.service.namespaces[ApplicationService.NS_USERS].append( + self.service.namespaces.users.append( _regex("@irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org")) def test_exclusive_alias(self): - self.service.namespaces[ApplicationService.NS_ALIASES].append( + self.service.namespaces.aliases.append( _regex("#irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) def test_exclusive_user(self): - self.service.namespaces[ApplicationService.NS_USERS].append( + self.service.namespaces.users.append( _regex("@irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org")) def test_exclusive_room(self): - self.service.namespaces[ApplicationService.NS_ROOMS].append( + self.service.namespaces.rooms.append( _regex("!irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org")) @defer.inlineCallbacks def test_regex_alias_no_match(self): - self.service.namespaces[ApplicationService.NS_ALIASES].append( - _regex("#irc_.*:matrix.org") - ) + self.service.namespaces.aliases.append(_regex("#irc_.*:matrix.org")) self.store.get_aliases_for_room.return_value = defer.succeed( ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) @@ -163,10 +154,8 @@ def test_regex_alias_no_match(self): @defer.inlineCallbacks def test_regex_multiple_matches(self): - self.service.namespaces[ApplicationService.NS_ALIASES].append( - _regex("#irc_.*:matrix.org") - ) - self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) + self.service.namespaces.aliases.append(_regex("#irc_.*:matrix.org")) + self.service.namespaces.users.append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.store.get_aliases_for_room.return_value = defer.succeed( ["#irc_barfoo:matrix.org"] @@ -184,7 +173,7 @@ def test_regex_multiple_matches(self): def test_interested_in_self(self): # make sure invites get through self.service.sender = "@appservice:name" - self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) + self.service.namespaces.users.append(_regex("@irc_.*")) self.event.type = "m.room.member" self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender @@ -194,7 +183,7 @@ def test_interested_in_self(self): @defer.inlineCallbacks def test_member_list_match(self): - self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) + self.service.namespaces.users.append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. self.store.get_users_in_room.return_value = defer.succeed( ["@alice:here", "@irc_fo:here", "@bob:here"] diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index f26d5acf9c29..3d52ffa24d9d 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -20,7 +20,7 @@ from twisted.internet import defer -from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice import ApplicationServiceState from synapse.config._base import ConfigError from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.appservice import ( @@ -89,9 +89,9 @@ def test_retrieval_of_service(self): self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.url, self.as_url) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) + self.assertEquals(stored_service.namespaces.aliases, []) + self.assertEquals(stored_service.namespaces.rooms, []) + self.assertEquals(stored_service.namespaces.users, []) def test_retrieval_of_all_services(self): services = self.store.get_app_services() From bf838f5e7413add5ec161f1476d55c2adbdaa7ea Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 12 Nov 2021 15:27:52 -0500 Subject: [PATCH 4/8] Add type hints to scheduler. --- synapse/appservice/scheduler.py | 74 ++++++++++++-------- synapse/storage/databases/main/appservice.py | 6 +- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 6a2ce99b55dc..185e3a527815 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,13 +48,19 @@ components. """ import logging -from typing import List, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main import DataStore from synapse.types import JsonDict +from synapse.util import Clock + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -72,7 +78,7 @@ class ApplicationServiceScheduler: case is a simple array. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() self.as_api = hs.get_application_service_api() @@ -80,7 +86,7 @@ def __init__(self, hs): self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) - async def start(self): + async def start(self) -> None: logger.info("Starting appservice scheduler") # check for any DOWN ASes and start recoverers for them. @@ -91,12 +97,14 @@ async def start(self): for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as(self, service: ApplicationService, event: EventBase): + def submit_event_for_as( + self, service: ApplicationService, event: EventBase + ) -> None: self.queuer.enqueue_event(service, event) def submit_ephemeral_events_for_as( self, service: ApplicationService, events: List[JsonDict] - ): + ) -> None: self.queuer.enqueue_ephemeral(service, events) @@ -108,16 +116,18 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__(self, txn_ctrl, clock): - self.queued_events = {} # dict of {service_id: [events]} - self.queued_ephemeral = {} # dict of {service_id: [events]} + def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): + # dict of {service_id: [events]} + self.queued_events: Dict[str, List[EventBase]] = {} + # dict of {service_id: [events]} + self.queued_ephemeral: Dict[str, List[JsonDict]] = {} # the appservices which currently have a transaction in flight - self.requests_in_flight = set() + self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock - def _start_background_request(self, service): + def _start_background_request(self, service: ApplicationService) -> None: # start a sender for this appservice if we don't already have one if service.id in self.requests_in_flight: return @@ -126,15 +136,17 @@ def _start_background_request(self, service): "as-sender-%s" % (service.id,), self._send_request, service ) - def enqueue_event(self, service: ApplicationService, event: EventBase): + def enqueue_event(self, service: ApplicationService, event: EventBase) -> None: self.queued_events.setdefault(service.id, []).append(event) self._start_background_request(service) - def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]): + def enqueue_ephemeral( + self, service: ApplicationService, events: List[JsonDict] + ) -> None: self.queued_ephemeral.setdefault(service.id, []).extend(events) self._start_background_request(service) - async def _send_request(self, service: ApplicationService): + async def _send_request(self, service: ApplicationService) -> None: # sanity-check: we shouldn't get here if this service already has a sender # running. assert service.id not in self.requests_in_flight @@ -168,20 +180,15 @@ class _TransactionController: if a transaction fails. (Note we have only have one of these in the homeserver.) - - Args: - clock (synapse.util.Clock): - store (synapse.storage.DataStore): - as_api (synapse.appservice.api.ApplicationServiceApi): """ - def __init__(self, clock, store, as_api): + def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi): self.clock = clock self.store = store self.as_api = as_api # map from service id to recoverer instance - self.recoverers = {} + self.recoverers: Dict[str, "_Recoverer"] = {} # for UTs self.RECOVERER_CLASS = _Recoverer @@ -191,7 +198,7 @@ async def send( service: ApplicationService, events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, - ): + ) -> None: try: txn = await self.store.create_appservice_txn( service=service, events=events, ephemeral=ephemeral or [] @@ -207,7 +214,7 @@ async def send( logger.exception("Error creating appservice transaction") run_in_background(self._on_txn_fail, service) - async def on_recovered(self, recoverer): + async def on_recovered(self, recoverer: "_Recoverer") -> None: logger.info( "Successfully recovered application service AS ID %s", recoverer.service.id ) @@ -217,18 +224,18 @@ async def on_recovered(self, recoverer): recoverer.service, ApplicationServiceState.UP ) - async def _on_txn_fail(self, service): + async def _on_txn_fail(self, service: ApplicationService) -> None: try: await self.store.set_appservice_state(service, ApplicationServiceState.DOWN) self.start_recoverer(service) except Exception: logger.exception("Error starting AS recoverer") - def start_recoverer(self, service): + def start_recoverer(self, service: ApplicationService) -> None: """Start a Recoverer for the given service Args: - service (synapse.appservice.ApplicationService): + service: """ logger.info("Starting recoverer for AS ID %s", service.id) assert service.id not in self.recoverers @@ -257,7 +264,14 @@ class _Recoverer: callback (callable[_Recoverer]): called once the service recovers. """ - def __init__(self, clock, store, as_api, service, callback): + def __init__( + self, + clock: Clock, + store: DataStore, + as_api: ApplicationServiceApi, + service: ApplicationService, + callback: Callable[["_Recoverer"], Awaitable[None]], + ): self.clock = clock self.store = store self.as_api = as_api @@ -265,8 +279,8 @@ def __init__(self, clock, store, as_api, service, callback): self.callback = callback self.backoff_counter = 1 - def recover(self): - def _retry(): + def recover(self) -> None: + def _retry() -> None: run_as_background_process( "as-recoverer-%s" % (self.service.id,), self.retry ) @@ -275,13 +289,13 @@ def _retry(): logger.info("Scheduling retries on %s in %fs", self.service.id, delay) self.clock.call_later(delay, _retry) - def _backoff(self): + def _backoff(self) -> None: # cap the backoff to be around 8.5min => (2^9) = 512 secs if self.backoff_counter < 9: self.backoff_counter += 1 self.recover() - async def retry(self): + async def retry(self) -> None: logger.info("Starting retries on %s", self.service.id) try: while True: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index baec35ee27b2..b8e3faeb91ca 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -132,9 +132,7 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - async def get_appservices_by_state( - self, state: ApplicationServiceState - ) -> List[ApplicationService]: + async def get_appservices_by_state(self, state: str) -> List[ApplicationService]: """Get a list of application services based on their state. Args: @@ -177,7 +175,7 @@ async def get_appservice_state( return None async def set_appservice_state( - self, service: ApplicationService, state: ApplicationServiceState + self, service: ApplicationService, state: str ) -> None: """Set the application service state. From 2a1449f20055fffe63355afebad1f6dc0095890e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 16 Nov 2021 10:58:42 -0500 Subject: [PATCH 5/8] Add type hints to appservice api. --- mypy.ini | 3 +++ synapse/appservice/api.py | 48 ++++++++++++++++++++++++++++----------- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/mypy.ini b/mypy.ini index f32c6c41a347..a4883b8395c5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -151,6 +151,9 @@ disallow_untyped_defs = True [mypy-synapse.app.*] disallow_untyped_defs = True +[mypy-synapse.appservice.*] +disallow_untyped_defs = True + [mypy-synapse.crypto.*] disallow_untyped_defs = True diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d08f6bbd7f2e..e91ba60747a0 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib -from typing import TYPE_CHECKING, List, Optional, Tuple +import urllib.parse +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from prometheus_client import Counter @@ -53,7 +53,7 @@ APP_SERVICE_PREFIX = "/_matrix/app/unstable" -def _is_valid_3pe_metadata(info): +def _is_valid_3pe_metadata(info: JsonDict) -> bool: if "instances" not in info: return False if not isinstance(info["instances"], list): @@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info): return True -def _is_valid_3pe_result(r, field): +def _is_valid_3pe_result(r: JsonDict, field: str) -> bool: if not isinstance(r, dict): return False @@ -93,9 +93,13 @@ def __init__(self, hs: "HomeServer"): hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) - async def query_user(self, service, user_id): + async def query_user(self, service: "ApplicationService", user_id: str) -> bool: if service.url is None: return False + + # This is required by the configuration. + assert service.hs_token is not None + uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) try: response = await self.get_json(uri, {"access_token": service.hs_token}) @@ -109,9 +113,13 @@ async def query_user(self, service, user_id): logger.warning("query_user to %s threw exception %s", uri, ex) return False - async def query_alias(self, service, alias): + async def query_alias(self, service: "ApplicationService", alias: str) -> bool: if service.url is None: return False + + # This is required by the configuration. + assert service.hs_token is not None + uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) try: response = await self.get_json(uri, {"access_token": service.hs_token}) @@ -125,7 +133,13 @@ async def query_alias(self, service, alias): logger.warning("query_alias to %s threw exception %s", uri, ex) return False - async def query_3pe(self, service, kind, protocol, fields): + async def query_3pe( + self, + service: "ApplicationService", + kind: str, + protocol: str, + fields: Dict[bytes, List[bytes]], + ) -> List[JsonDict]: if kind == ThirdPartyEntityKind.USER: required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: @@ -205,11 +219,14 @@ async def push_bulk( events: List[EventBase], ephemeral: List[JsonDict], txn_id: Optional[int] = None, - ): + ) -> bool: if service.url is None: return True - events = self._serialize(service, events) + # This is required by the configuration. + assert service.hs_token is not None + + serialized_events = self._serialize(service, events) if txn_id is None: logger.warning( @@ -221,9 +238,12 @@ async def push_bulk( # Never send ephemeral events to appservices that do not support it if service.supports_ephemeral: - body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral} + body = { + "events": serialized_events, + "de.sorunome.msc2409.ephemeral": ephemeral, + } else: - body = {"events": events} + body = {"events": serialized_events} try: await self.put_json( @@ -232,7 +252,7 @@ async def push_bulk( args={"access_token": service.hs_token}, ) sent_transactions_counter.labels(service.id).inc() - sent_events_counter.labels(service.id).inc(len(events)) + sent_events_counter.labels(service.id).inc(len(serialized_events)) return True except CodeMessageException as e: logger.warning("push_bulk to %s received %s", uri, e.code) @@ -241,7 +261,9 @@ async def push_bulk( failed_transactions_counter.labels(service.id).inc() return False - def _serialize(self, service, events): + def _serialize( + self, service: "ApplicationService", events: Iterable[EventBase] + ) -> List[JsonDict]: time_now = self.clock.time_msec() return [ serialize_event( From 60abe63d21a5696e65d9195ca5b79f2cc7f3f191 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 16 Nov 2021 13:12:02 -0500 Subject: [PATCH 6/8] Newsfragment --- changelog.d/11360.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/11360.misc diff --git a/changelog.d/11360.misc b/changelog.d/11360.misc new file mode 100644 index 000000000000..43e25720c5e2 --- /dev/null +++ b/changelog.d/11360.misc @@ -0,0 +1 @@ +Add type hints to `synapse.appservice`. From 17b65fa664b696635884c1753d28f240ceffb827 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 29 Nov 2021 13:23:16 -0500 Subject: [PATCH 7/8] Partially rollback 9e463851cb66a4db214b1d1abd8018097cd1e773 to remove Namespaces class. --- synapse/appservice/__init__.py | 51 ++++++++++++++--------------- tests/appservice/test_appservice.py | 40 ++++++++++++---------- tests/storage/test_appservice.py | 8 ++--- 3 files changed, 51 insertions(+), 48 deletions(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index c53207221a30..75733a846f95 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Iterable, List, Optional, Pattern +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern import attr from netaddr import IPSet @@ -42,13 +42,6 @@ class Namespace: regex: Pattern -@attr.s(slots=True, frozen=True, auto_attribs=True) -class Namespaces: - users: List[Namespace] - aliases: List[Namespace] - rooms: List[Namespace] - - class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. @@ -101,23 +94,26 @@ def __init__( self.rate_limited = rate_limited - def _check_namespaces(self, namespaces: Optional[JsonDict]) -> Namespaces: + def _check_namespaces( + self, namespaces: Optional[JsonDict] + ) -> Dict[str, List[Namespace]]: # Sanity check that it is of the form: # { # users: [ {regex: "[A-z]+.*", exclusive: true}, ...], # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...], # } - result = Namespaces([], [], []) - if not namespaces: - return result + if namespaces is None: + namespaces = {} + + result: Dict[str, List[Namespace]] = {} for ns in ApplicationService.NS_LIST: + result[ns] = [] + if ns not in namespaces: continue - namespace: List[Namespace] = getattr(result, ns) - if not isinstance(namespaces[ns], list): raise ValueError("Bad namespace value for '%s'" % ns) for regex_obj in namespaces[ns]: @@ -149,20 +145,20 @@ def _check_namespaces(self, namespaces: Optional[JsonDict]) -> Namespaces: raise ValueError("Expected string for 'regex' in ns '%s'" % ns) # Pre-compile regex. - namespace.append(Namespace(exclusive, group_id, re.compile(regex))) + result[ns].append(Namespace(exclusive, group_id, re.compile(regex))) return result def _matches_regex( - self, namespaces: List[Namespace], test_string: str + self, namespace_key: str, test_string: str ) -> Optional[Namespace]: - for namespace in namespaces: + for namespace in self.namespaces[namespace_key]: if namespace.regex.match(test_string): return namespace return None - def _is_exclusive(self, namespaces: List[Namespace], test_string: str) -> bool: - namespace = self._matches_regex(namespaces, test_string) + def _is_exclusive(self, namespace_key: str, test_string: str) -> bool: + namespace = self._matches_regex(namespace_key, test_string) if namespace: return namespace.exclusive return False @@ -283,29 +279,30 @@ async def is_interested_in_presence( def is_interested_in_user(self, user_id: str) -> bool: return ( - bool(self._matches_regex(self.namespaces.users, user_id)) + bool(self._matches_regex(ApplicationService.NS_USERS, user_id)) or user_id == self.sender ) def is_interested_in_alias(self, alias: str) -> bool: - return bool(self._matches_regex(self.namespaces.aliases, alias)) + return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias)) def is_interested_in_room(self, room_id: str) -> bool: - return bool(self._matches_regex(self.namespaces.rooms, room_id)) + return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id)) def is_exclusive_user(self, user_id: str) -> bool: return ( - self._is_exclusive(self.namespaces.users, user_id) or user_id == self.sender + self._is_exclusive(ApplicationService.NS_USERS, user_id) + or user_id == self.sender ) def is_interested_in_protocol(self, protocol: str) -> bool: return protocol in self.protocols def is_exclusive_alias(self, alias: str) -> bool: - return self._is_exclusive(self.namespaces.aliases, alias) + return self._is_exclusive(ApplicationService.NS_ALIASES, alias) def is_exclusive_room(self, room_id: str) -> bool: - return self._is_exclusive(self.namespaces.rooms, room_id) + return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) def get_exclusive_user_regexes(self) -> List[Pattern]: """Get the list of regexes used to determine if a user is exclusively @@ -313,7 +310,7 @@ def get_exclusive_user_regexes(self) -> List[Pattern]: """ return [ namespace.regex - for namespace in self.namespaces.users + for namespace in self.namespaces[ApplicationService.NS_USERS] if namespace.exclusive ] @@ -328,7 +325,7 @@ def get_groups_for_user(self, user_id: str) -> Iterable[str]: """ return ( namespace.group_id - for namespace in self.namespaces.users + for namespace in self.namespaces[ApplicationService.NS_USERS] if namespace.group_id and namespace.regex.match(user_id) ) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index e13ad6f9019c..ba2a2bfd64ad 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -42,7 +42,7 @@ def setUp(self): @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): - self.service.namespaces.users.append(_regex("@irc_.*")) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue( (yield defer.ensureDeferred(self.service.is_interested(self.event))) @@ -50,7 +50,7 @@ def test_regex_user_id_prefix_match(self): @defer.inlineCallbacks def test_regex_user_id_prefix_no_match(self): - self.service.namespaces.users.append(_regex("@irc_.*")) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse( (yield defer.ensureDeferred(self.service.is_interested(self.event))) @@ -58,7 +58,7 @@ def test_regex_user_id_prefix_no_match(self): @defer.inlineCallbacks def test_regex_room_member_is_checked(self): - self.service.namespaces.users.append(_regex("@irc_.*")) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" @@ -68,7 +68,7 @@ def test_regex_room_member_is_checked(self): @defer.inlineCallbacks def test_regex_room_id_match(self): - self.service.namespaces.rooms.append( + self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" @@ -78,7 +78,7 @@ def test_regex_room_id_match(self): @defer.inlineCallbacks def test_regex_room_id_no_match(self): - self.service.namespaces.rooms.append( + self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" @@ -88,7 +88,9 @@ def test_regex_room_id_no_match(self): @defer.inlineCallbacks def test_regex_alias_match(self): - self.service.namespaces.aliases.append(_regex("#irc_.*:matrix.org")) + self.service.namespaces[ApplicationService.NS_ALIASES].append( + _regex("#irc_.*:matrix.org") + ) self.store.get_aliases_for_room.return_value = defer.succeed( ["#irc_foobar:matrix.org", "#athing:matrix.org"] ) @@ -102,44 +104,46 @@ def test_regex_alias_match(self): ) def test_non_exclusive_alias(self): - self.service.namespaces.aliases.append( + self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) def test_non_exclusive_room(self): - self.service.namespaces.rooms.append( + self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org")) def test_non_exclusive_user(self): - self.service.namespaces.users.append( + self.service.namespaces[ApplicationService.NS_USERS].append( _regex("@irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org")) def test_exclusive_alias(self): - self.service.namespaces.aliases.append( + self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) def test_exclusive_user(self): - self.service.namespaces.users.append( + self.service.namespaces[ApplicationService.NS_USERS].append( _regex("@irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org")) def test_exclusive_room(self): - self.service.namespaces.rooms.append( + self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org")) @defer.inlineCallbacks def test_regex_alias_no_match(self): - self.service.namespaces.aliases.append(_regex("#irc_.*:matrix.org")) + self.service.namespaces[ApplicationService.NS_ALIASES].append( + _regex("#irc_.*:matrix.org") + ) self.store.get_aliases_for_room.return_value = defer.succeed( ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) @@ -154,8 +158,10 @@ def test_regex_alias_no_match(self): @defer.inlineCallbacks def test_regex_multiple_matches(self): - self.service.namespaces.aliases.append(_regex("#irc_.*:matrix.org")) - self.service.namespaces.users.append(_regex("@irc_.*")) + self.service.namespaces[ApplicationService.NS_ALIASES].append( + _regex("#irc_.*:matrix.org") + ) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.store.get_aliases_for_room.return_value = defer.succeed( ["#irc_barfoo:matrix.org"] @@ -173,7 +179,7 @@ def test_regex_multiple_matches(self): def test_interested_in_self(self): # make sure invites get through self.service.sender = "@appservice:name" - self.service.namespaces.users.append(_regex("@irc_.*")) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.type = "m.room.member" self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender @@ -183,7 +189,7 @@ def test_interested_in_self(self): @defer.inlineCallbacks def test_member_list_match(self): - self.service.namespaces.users.append(_regex("@irc_.*")) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. self.store.get_users_in_room.return_value = defer.succeed( ["@alice:here", "@irc_fo:here", "@bob:here"] diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 3d52ffa24d9d..f26d5acf9c29 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -20,7 +20,7 @@ from twisted.internet import defer -from synapse.appservice import ApplicationServiceState +from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.appservice import ( @@ -89,9 +89,9 @@ def test_retrieval_of_service(self): self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.url, self.as_url) - self.assertEquals(stored_service.namespaces.aliases, []) - self.assertEquals(stored_service.namespaces.rooms, []) - self.assertEquals(stored_service.namespaces.users, []) + self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], []) + self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) + self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) def test_retrieval_of_all_services(self): services = self.store.get_app_services() From d1a36c70c8f26b6f770a01853a54b1b0c0b9fe39 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Dec 2021 11:34:29 -0500 Subject: [PATCH 8/8] Declare type of Pattern. --- synapse/appservice/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 664665284a1c..8c9ff93b2c13 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -41,7 +41,7 @@ class ApplicationServiceState(Enum): class Namespace: exclusive: bool group_id: Optional[str] - regex: Pattern + regex: Pattern[str] class ApplicationService: @@ -306,7 +306,7 @@ def is_exclusive_alias(self, alias: str) -> bool: def is_exclusive_room(self, room_id: str) -> bool: return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exclusive_user_regexes(self) -> List[Pattern]: + def get_exclusive_user_regexes(self) -> List[Pattern[str]]: """Get the list of regexes used to determine if a user is exclusively registered by the AS """