diff --git a/changelog.d/17333.misc b/changelog.d/17333.misc deleted file mode 100644 index d3ef0b3777..0000000000 --- a/changelog.d/17333.misc +++ /dev/null @@ -1 +0,0 @@ -Handle device lists notifications for large accounts more efficiently in worker mode. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3dddbb70b4..2d6d49eed7 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -114,19 +114,13 @@ async def on_rdata( """ all_room_ids: Set[str] = set() if stream_name == DeviceListsStream.NAME: - if any(not row.is_signature and not row.hosts_calculated for row in rows): + if any(row.entity.startswith("@") and not row.is_signature for row in rows): prev_token = self.store.get_device_stream_token() all_room_ids = await self.store.get_all_device_list_changes( prev_token, token ) self.store.device_lists_in_rooms_have_changed(all_room_ids, token) - # If we're sending federation we need to update the device lists - # outbound pokes stream change cache with updated hosts. - if self.send_handler and any(row.hosts_calculated for row in rows): - hosts = await self.store.get_destinations_for_device(token) - self.store.device_lists_outbound_pokes_have_changed(hosts, token) - self.store.process_replication_rows(stream_name, instance_name, token, rows) # NOTE: this must be called after process_replication_rows to ensure any # cache invalidations are first handled before any stream ID advances. @@ -439,11 +433,12 @@ async def process_replication_rows( # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. - if any(row.hosts_calculated for row in rows): - hosts = await self.store.get_destinations_for_device(token) - await self.federation_sender.send_device_messages( - hosts, immediate=False - ) + hosts = { + row.entity + for row in rows + if not row.entity.startswith("@") and not row.is_signature + } + await self.federation_sender.send_device_messages(hosts, immediate=False) elif stream_name == ToDeviceStream.NAME: # The to_device stream includes stuff to be pushed to both local diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index d021904de7..661206c841 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -549,14 +549,10 @@ class DeviceListsStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: - user_id: str + entity: str # Indicates that a user has signed their own device with their user-signing key is_signature: bool - # Indicates if this is a notification that we've calculated the hosts we - # need to send the update to. - hosts_calculated: bool - NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -598,13 +594,13 @@ async def _update_function( upper_limit_token = min(upper_limit_token, signatures_to_token) device_updates = [ - (stream_id, (entity, False, hosts)) - for stream_id, (entity, hosts) in device_updates + (stream_id, (entity, False)) + for stream_id, (entity,) in device_updates if stream_id <= upper_limit_token ] signatures_updates = [ - (stream_id, (entity, True, False)) + (stream_id, (entity, True)) for stream_id, (entity,) in signatures_updates if stream_id <= upper_limit_token ] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 5eeca6165d..40187496e2 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -164,24 +164,22 @@ def __init__( prefilled_cache=user_signature_stream_prefill, ) - self._device_list_federation_stream_cache = None - if hs.should_send_federation(): - ( - device_list_federation_prefill, - device_list_federation_list_id, - ) = self.db_pool.get_cache_dict( - db_conn, - "device_lists_outbound_pokes", - entity_column="destination", - stream_column="stream_id", - max_value=device_list_max, - limit=10000, - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", - device_list_federation_list_id, - prefilled_cache=device_list_federation_prefill, - ) + ( + device_list_federation_prefill, + device_list_federation_list_id, + ) = self.db_pool.get_cache_dict( + db_conn, + "device_lists_outbound_pokes", + entity_column="destination", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", + device_list_federation_list_id, + prefilled_cache=device_list_federation_prefill, + ) if hs.config.worker.run_background_tasks: self._clock.looping_call( @@ -209,29 +207,22 @@ def _invalidate_caches_for_devices( ) -> None: for row in rows: if row.is_signature: - self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + self._user_signature_stream_cache.entity_has_changed(row.entity, token) continue # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. - if not row.hosts_calculated: - self._device_list_stream_cache.entity_has_changed(row.user_id, token) - self.get_cached_devices_for_user.invalidate((row.user_id,)) - self._get_cached_user_device.invalidate((row.user_id,)) - self.get_device_list_last_stream_id_for_remote.invalidate( - (row.user_id,) - ) + if row.entity.startswith("@"): + self._device_list_stream_cache.entity_has_changed(row.entity, token) + self.get_cached_devices_for_user.invalidate((row.entity,)) + self._get_cached_user_device.invalidate((row.entity,)) + self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) - def device_lists_outbound_pokes_have_changed( - self, destinations: StrCollection, token: int - ) -> None: - assert self._device_list_federation_stream_cache is not None - - for destination in destinations: - self._device_list_federation_stream_cache.entity_has_changed( - destination, token - ) + else: + self._device_list_federation_stream_cache.entity_has_changed( + row.entity, token + ) def device_lists_in_rooms_have_changed( self, room_ids: StrCollection, token: int @@ -372,11 +363,6 @@ async def get_device_updates_by_remote( EDU contents. """ now_stream_id = self.get_device_stream_token() - if from_stream_id == now_stream_id: - return now_stream_id, [] - - if self._device_list_federation_stream_cache is None: - raise Exception("Func can only be used on federation senders") has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -1032,10 +1018,10 @@ def _get_all_device_list_changes_for_remotes( # This query Does The Right Thing where it'll correctly apply the # bounds to the inner queries. sql = """ - SELECT stream_id, user_id, hosts FROM ( - SELECT stream_id, user_id, false AS hosts FROM device_lists_stream + SELECT stream_id, entity FROM ( + SELECT stream_id, user_id AS entity FROM device_lists_stream UNION ALL - SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes + SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes ) AS e WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC @@ -1591,14 +1577,6 @@ def get_device_list_changes_in_room_txn( get_device_list_changes_in_room_txn, ) - async def get_destinations_for_device(self, stream_id: int) -> StrCollection: - return await self.db_pool.simple_select_onecol( - table="device_lists_outbound_pokes", - keyvalues={"stream_id": stream_id}, - retcol="destination", - desc="get_destinations_for_device", - ) - class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__( @@ -2134,13 +2112,12 @@ def _add_device_outbound_poke_to_stream_txn( stream_ids: List[int], context: Optional[Dict[str, str]], ) -> None: - if self._device_list_federation_stream_cache: - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_ids[-1], - ) + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) now = self._clock.time_msec() stream_id_iterator = iter(stream_ids) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 9e6c9561ae..38d8785faa 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -123,9 +123,9 @@ def process_replication_rows( if stream_name == DeviceListsStream.NAME: for row in rows: assert isinstance(row, DeviceListsStream.DeviceListsStreamRow) - if not row.hosts_calculated: + if row.entity.startswith("@"): self._get_e2e_device_keys_for_federation_query_inner.invalidate( - (row.user_id,) + (row.entity,) ) super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index ba01b038ab..7f975d04ff 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -36,14 +36,6 @@ class DeviceStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def default_config(self) -> JsonDict: - config = super().default_config() - - # We 'enable' federation otherwise `get_device_updates_by_remote` will - # throw an exception. - config["federation_sender_instances"] = ["master"] - return config - def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None: """Add a device list change for the given device to `device_lists_outbound_pokes` table.