From 99b350eeacabd399b4604cbcfb5c257db79730fc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 08:42:46 +0100 Subject: [PATCH 01/13] Change compute_event_context to take state_ids --- synapse/handlers/federation_event.py | 15 ++++++++++++--- synapse/handlers/message.py | 7 ++++++- synapse/state/__init__.py | 22 ++++++++++------------ 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index ca82df8a6d9e..da843dc64f20 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -463,7 +463,9 @@ async def process_remote_join( with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in state + }, partial_state=partial_state, ) @@ -513,11 +515,14 @@ async def update_state_for_partial_state_event( # This is the same operation as we do when we receive a regular event # over federation. state = await self._resolve_state_at_missing_prevs(destination, event) + state_ids = None + if state: + state_ids = {(e.type, e.state_key): e.event_id for e in state} # build a new state group for it if need be context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event=state_ids, ) if context.partial_state: # this can happen if some or all of the event's prev_events still have @@ -1089,8 +1094,12 @@ async def _process_received_pdu( assert not event.internal_metadata.outlier try: + state_ids = None + if state: + state_ids = {(e.type, e.state_key): e.event_id for e in state} context = await self._state_handler.compute_event_context( - event, old_state=state + event, + state_ids_before_event=state_ids, ) context = await self._check_event_auth( origin, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index cb1bc4c06f1c..c57f45a9a57d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1023,7 +1023,12 @@ async def create_new_client_event( # TODO(faster_joins): figure out how this works, and make sure that the # old state is complete. old_state = await self.store.get_events_as_list(state_event_ids) - context = await self.state.compute_event_context(event, old_state=old_state) + context = await self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) else: context = await self.state.compute_event_context(event) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4b4ed42cff33..ef444cac6b1f 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -261,7 +261,7 @@ async def get_hosts_in_room_at_events( async def compute_event_context( self, event: EventBase, - old_state: Optional[Iterable[EventBase]] = None, + state_ids_before_event: Optional[StateMap[str]] = None, partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,12 +273,12 @@ async def compute_event_context( Args: event: - old_state: The state at the event if it can't be - calculated from existing events. This is normally only specified - when receiving an event from federation where we don't have the - prev events for, e.g. when backfilling. - partial_state: True if `old_state` is partial and omits non-critical - membership events + state_ids_before_event: The event ids of the state at the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events for, e.g. when backfilling. + partial_state: True if `state_ids_before_event` is partial and omits + non-critical membership events Returns: The event context. """ @@ -286,13 +286,11 @@ async def compute_event_context( assert not event.internal_metadata.is_outlier() # - # first of all, figure out the state before the event + # first of all, figure out the state before the event, unless we + # already have it. # - if old_state: + if state_ids_before_event: # if we're given the state before the event, then we use that - state_ids_before_event: StateMap[str] = { - (s.type, s.state_key): s.event_id for s in old_state - } state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None From 05474e932aef2f4ea5a372adfd02c6b81823636a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 09:05:51 +0100 Subject: [PATCH 02/13] Change _process_received_pdu --- synapse/handlers/federation_event.py | 43 ++++++++++++++++------------ 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index da843dc64f20..fe6e477ab23c 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -274,7 +274,7 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: affected=pdu.event_id, ) - await self._process_received_pdu(origin, pdu, state=None) + await self._process_received_pdu(origin, pdu, state_ids=None) async def on_send_membership_event( self, origin: str, event: EventBase @@ -775,8 +775,13 @@ async def _process_pulled_event( state = await self._resolve_state_at_missing_prevs(origin, event) # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does # not return partial state + + state_ids = None + if state: + state_ids = {(e.type, e.state_key): e.event_id for e in state} + await self._process_received_pdu( - origin, event, state=state, backfilled=backfilled + origin, event, state_ids=state_ids, backfilled=backfilled ) except FederationError as e: if e.code == 403: @@ -1061,7 +1066,7 @@ async def _process_received_pdu( self, origin: str, event: EventBase, - state: Optional[Iterable[EventBase]], + state_ids: Optional[StateMap[str]], backfilled: bool = False, ) -> None: """Called when we have a new non-outlier event. @@ -1083,7 +1088,7 @@ async def _process_received_pdu( event: event to be persisted - state: Normally None, but if we are handling a gap in the graph + state_ids: Normally None, but if we are handling a gap in the graph (ie, we are missing one or more prev_events), the resolved state at the event @@ -1094,9 +1099,6 @@ async def _process_received_pdu( assert not event.internal_metadata.outlier try: - state_ids = None - if state: - state_ids = {(e.type, e.state_key): e.event_id for e in state} context = await self._state_handler.compute_event_context( event, state_ids_before_event=state_ids, @@ -1116,7 +1118,7 @@ async def _process_received_pdu( # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. - await self._check_for_soft_fail(event, state, origin=origin) + await self._check_for_soft_fail(event, state_ids, origin=origin) await self._run_push_actions_and_persist_event(event, context, backfilled) @@ -1598,7 +1600,7 @@ async def _maybe_kick_guest_users(self, event: EventBase) -> None: async def _check_for_soft_fail( self, event: EventBase, - state: Optional[Iterable[EventBase]], + state_ids: Optional[StateMap[str]], origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1606,7 +1608,7 @@ async def _check_for_soft_fail( Args: event - state: The state at the event if we don't have all the event's prev events + state_ids: The state at the event if we don't have all the event's prev events origin: The host the event originates from. """ extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) @@ -1622,7 +1624,7 @@ async def _check_for_soft_fail( room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". - if state is not None: + if state_ids is not None: # If we're explicitly given the state then we won't have all the # prev events, and so we have a gap in the graph. In this case # we want to be a little careful as we might have been down for @@ -1635,17 +1637,20 @@ async def _check_for_soft_fail( # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_store.get_state_groups( + state_sets_d = await self._state_store.get_state_groups_ids( event.room_id, extrem_ids ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) - state_sets.append(state) - current_states = await self._state_handler.resolve_events( - room_version, state_sets, event + state_sets: List[StateMap[str]] = list(state_sets_d.values()) + state_sets.append(state_ids) + current_state_ids = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_sets, + event_map={}, + state_res_store=StateResolutionStore(self._store), + ) ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } else: current_state_ids = await self._state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids From 35a461476b98437f84b894c6427188d246c5b964 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 09:22:15 +0100 Subject: [PATCH 03/13] Change _resolve_state_at_missing_prevs --- synapse/handlers/federation_event.py | 32 ++++++---------------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index fe6e477ab23c..d30ae366d6d5 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -514,10 +514,7 @@ async def update_state_for_partial_state_event( # # This is the same operation as we do when we receive a regular event # over federation. - state = await self._resolve_state_at_missing_prevs(destination, event) - state_ids = None - if state: - state_ids = {(e.type, e.state_key): e.event_id for e in state} + state_ids = await self._resolve_state_at_missing_prevs(destination, event) # build a new state group for it if need be context = await self._state_handler.compute_event_context( @@ -772,14 +769,10 @@ async def _process_pulled_event( return try: - state = await self._resolve_state_at_missing_prevs(origin, event) + state_ids = await self._resolve_state_at_missing_prevs(origin, event) # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does # not return partial state - state_ids = None - if state: - state_ids = {(e.type, e.state_key): e.event_id for e in state} - await self._process_received_pdu( origin, event, state_ids=state_ids, backfilled=backfilled ) @@ -791,7 +784,7 @@ async def _process_pulled_event( async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: + ) -> Optional[StateMap[str]]: """Calculate the state at an event with missing prev_events. This is used when we have pulled a batch of events from a remote server, and @@ -818,8 +811,8 @@ async def _resolve_state_at_missing_prevs( event: an event to check for missing prevs. Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. + if we already had all the prev events, `None`. Otherwise, returns + the event ids of the state at `event`. """ room_id = event.room_id event_id = event.event_id @@ -880,19 +873,6 @@ async def _resolve_state_at_missing_prevs( state_res_store=StateResolutionStore(self._store), ) - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self._store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.as_is, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "Error attempting to resolve state at missing prev_events", @@ -904,7 +884,7 @@ async def _resolve_state_at_missing_prevs( "We can't get valid state history.", affected=event_id, ) - return state + return state_map async def _get_state_after_missing_prev_event( self, From ea9063d2385487fd4cfad868691de33936766559 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 09:24:36 +0100 Subject: [PATCH 04/13] Add a get_metadata_for_events DB func --- synapse/storage/databases/main/state.py | 57 +++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 18ae8aee295d..e126020135fa 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -16,6 +16,8 @@ import logging from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple +import attr + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -26,6 +28,7 @@ DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore @@ -33,6 +36,7 @@ from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -43,6 +47,15 @@ MAX_STATE_DELTA_HOPS = 100 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: Optional[str] + + def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: @@ -133,6 +146,50 @@ def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str: return room_version + async def get_metadata_for_events( + self, event_ids: Collection[str] + ) -> Dict[str, EventMetadata]: + """Get some metadata (room_id, type, state_key) for the given events. + + This method is a faster alternative than fetching the full events from + the DB, and should be used when the full event is not needed. + + Returns metadata for rejected and redacted events. Events that have not + been persisted are omitted from the returned dict. + """ + + def get_metadata_for_events_txn( + txn: LoggingTransaction, + batch_ids: Collection[str], + ) -> Dict[str, EventMetadata]: + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", batch_ids + ) + + sql = f""" + SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e + LEFT JOIN state_events USING (event_id) + WHERE {clause} + """ + + txn.execute(sql, args) + return { + event_id: EventMetadata( + room_id=room_id, event_type=event_type, state_key=state_key + ) + for event_id, room_id, event_type, state_key in txn + } + + result_map: Dict[str, EventMetadata] = {} + for batch_ids in batch_iter(event_ids, 1000): + return await self.db_pool.runInteraction( + "get_metadata_for_events", + get_metadata_for_events_txn, + batch_ids=batch_ids, + ) + + return result_map + async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. From 7a9586ee932243bc683b3e9c10005b9bd5e5e9b4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 09:24:47 +0100 Subject: [PATCH 05/13] _get_state_after_missing_prev_event --- synapse/handlers/federation_event.py | 102 +++++++++++++-------------- 1 file changed, 49 insertions(+), 53 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index d30ae366d6d5..bbc20b156b6c 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -852,18 +852,14 @@ async def _resolve_state_at_missing_prevs( # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - remote_state = await self._get_state_after_missing_prev_event( - dest, room_id, p + remote_state_map = ( + await self._get_state_ids_after_missing_prev_event( + dest, room_id, p + ) ) - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } state_maps.append(remote_state_map) - for x in remote_state: - event_map[x.event_id] = x - room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, @@ -886,12 +882,12 @@ async def _resolve_state_at_missing_prevs( ) return state_map - async def _get_state_after_missing_prev_event( + async def _get_state_ids_after_missing_prev_event( self, destination: str, room_id: str, event_id: str, - ) -> List[EventBase]: + ) -> StateMap[str]: """Requests all of the room state at a given event from a remote homeserver. Args: @@ -900,7 +896,7 @@ async def _get_state_after_missing_prev_event( event_id: The id of the event we want the state at. Returns: - A list of events in the state, including the event itself + The event ids of the state *after* the given event. """ ( state_event_ids, @@ -919,15 +915,13 @@ async def _get_state_after_missing_prev_event( desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self._store.get_events( - desired_events, allow_rejected=True - ) + have_events = await self._store.have_seen_events(room_id, desired_events) - missing_desired_events = desired_events - fetched_events.keys() + missing_desired_events = desired_events - have_events logger.debug( "We are missing %i events (got %i)", len(missing_desired_events), - len(fetched_events), + len(have_events), ) # We probably won't need most of the auth events, so let's just check which @@ -938,7 +932,7 @@ async def _get_state_after_missing_prev_event( # already have a bunch of the state events. It would be nice if the # federation api gave us a way of finding out which we actually need. - missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events = set(auth_event_ids) - have_events missing_auth_events.difference_update( await self._store.have_seen_events(room_id, missing_auth_events) ) @@ -964,47 +958,51 @@ async def _get_state_after_missing_prev_event( destination=destination, room_id=room_id, event_ids=missing_events ) - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self._store.get_events(missing_desired_events, allow_rejected=True) - ) - - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B + # We now need to fill out the state map, which involves fetching the + # type and state key for each event ID in the state. + state_map = {} - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] + event_metadata = await self._store.get_metadata_for_events(state_event_ids) + for state_event_id, metadata in event_metadata.items(): + if metadata.room_id != room_id: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + # + # This can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + state_event_id, + metadata.room_id, + room_id, + ) + continue - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) + if metadata.state_key is None: + logger.warning( + "Remote server gave us non-state event in state: %s", state_event_id + ) + continue - del fetched_events[bad_event_id] + state_map[(metadata.event_type, metadata.state_key)] = state_event_id # if we couldn't get the prev event in question, that's a problem. - remote_event = fetched_events.get(event_id) + remote_event = await self._store.get_event( + event_id, + allow_none=True, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) if not remote_event: raise Exception("Unable to get missing prev_event %s" % (event_id,)) # missing state at that event is a warning, not a blocker # XXX: this doesn't sound right? it means that we'll end up with incomplete # state. - failed_to_fetch = desired_events - fetched_events.keys() + failed_to_fetch = desired_events - event_metadata.keys() if failed_to_fetch: logger.warning( "Failed to fetch missing state events for %s %s", @@ -1012,14 +1010,12 @@ async def _get_state_after_missing_prev_event( failed_to_fetch, ) - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) + state_map[ + (remote_event.type, remote_event.state_key) + ] = remote_event.event_id - return remote_state + return state_map async def _get_state_and_persist( self, destination: str, room_id: str, event_id: str From 37b9c33aedf7cc8fae07da12eb085c6efdc2ff16 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 09:10:52 +0100 Subject: [PATCH 06/13] Change MessageHandler --- synapse/handlers/message.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c57f45a9a57d..8a9f9de555c0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1022,12 +1022,24 @@ async def create_new_client_event( # # TODO(faster_joins): figure out how this works, and make sure that the # old state is complete. - old_state = await self.store.get_events_as_list(state_event_ids) + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map = {} + for state_id in state_event_ids: + data = metadata.get(state_id) + if data is None: + raise Exception("State event not persisted %s", state_id) + + if data.state_key is None: + raise Exception( + "Trying to set non-state event as state: %s", state_id + ) + + state_map[(data.event_type, data.state_key)] = state_id + context = await self.state.compute_event_context( event, - state_ids_before_event={ - (e.type, e.state_key): e.event_id for e in old_state - }, + state_ids_before_event=state_map, ) else: context = await self.state.compute_event_context(event) From 0fe4092765b369d385b04fa001db122cda6c68cf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 09:17:03 +0100 Subject: [PATCH 07/13] Fix tests --- tests/handlers/test_federation.py | 4 ++- tests/storage/test_events.py | 43 ++++++++++++++++++++----------- tests/test_state.py | 14 ++++++++-- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index e95dfdce2086..183a6635fa8d 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -276,7 +276,9 @@ def test_backfill_with_many_backward_extremities(self) -> None: # federation handler wanting to backfill the fake event. self.get_success( federation_event_handler._process_received_pdu( - self.OTHER_SERVER_NAME, event, state=current_state + self.OTHER_SERVER_NAME, + event, + state={(e.type, e.state_key): e.event_id for e in current_state}, ) ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index ef5e25873c22..aaa3189b16ef 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -69,7 +69,7 @@ def prepare(self, reactor, clock, homeserver): def persist_event(self, event, state=None): """Persist the event, with optional state""" context = self.get_success( - self.state.compute_event_context(event, old_state=state) + self.state.compute_event_context(event, state_ids_before_event=state) ) self.get_success(self.persistence.persist_event(event, context)) @@ -103,9 +103,11 @@ def test_prune_gap(self): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -135,13 +137,14 @@ def test_do_not_prune_gap_if_state_different(self): # setting. The state resolution across the old and new event will then # include it, and so the resolved state won't match the new state. state_before_gap = dict( - self.get_success(self.state.get_current_state(self.room_id)) + self.get_success(self.state.get_current_state_ids(self.room_id)) ) state_before_gap.pop(("m.room.history_visibility", "")) context = self.get_success( self.state.compute_event_context( - remote_event_2, old_state=state_before_gap.values() + remote_event_2, + state_ids_before_event=state_before_gap, ) ) @@ -177,9 +180,11 @@ def test_prune_gap_if_old(self): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -207,9 +212,11 @@ def test_do_not_prune_gap_if_other_server(self): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -247,9 +254,11 @@ def test_prune_gap_if_dummy_remote(self): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -289,9 +298,11 @@ def test_prune_gap_if_dummy_local(self): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id, local_message_event_id]) @@ -323,9 +334,11 @@ def test_do_not_prune_gap_if_not_dummy(self): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([local_message_event_id, remote_event_2.event_id]) diff --git a/tests/test_state.py b/tests/test_state.py index c6baea3d7604..84694d368d8b 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -442,7 +442,12 @@ def test_annotate_with_old_message(self): ] context = yield defer.ensureDeferred( - self.state.compute_event_context(event, old_state=old_state) + self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) ) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) @@ -467,7 +472,12 @@ def test_annotate_with_old_state(self): ] context = yield defer.ensureDeferred( - self.state.compute_event_context(event, old_state=old_state) + self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) ) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) From 9376b9a982d5dfb7425e34931b4a68c78d8c8169 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 09:55:40 +0100 Subject: [PATCH 08/13] Newsfile --- changelog.d/12852.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12852.misc diff --git a/changelog.d/12852.misc b/changelog.d/12852.misc new file mode 100644 index 000000000000..afca32471fb1 --- /dev/null +++ b/changelog.d/12852.misc @@ -0,0 +1 @@ +Pull out less state when handling gaps in room DAG. From 71907fd632d3a69899d8b35c977b59b3504584a3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 10:12:20 +0100 Subject: [PATCH 09/13] Fix tests --- tests/handlers/test_federation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 183a6635fa8d..c529a3680a52 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -278,7 +278,9 @@ def test_backfill_with_many_backward_extremities(self) -> None: federation_event_handler._process_received_pdu( self.OTHER_SERVER_NAME, event, - state={(e.type, e.state_key): e.event_id for e in current_state}, + state_ids={ + (e.type, e.state_key): e.event_id for e in current_state + }, ) ) From f9ba81b360b78d5f5928eadf995dc176f50173c8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 11:17:44 +0100 Subject: [PATCH 10/13] Apply suggestions from code review Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/handlers/federation_event.py | 2 +- synapse/handlers/message.py | 4 ++-- synapse/state/__init__.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index bbc20b156b6c..6979dd081a5f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1623,7 +1623,7 @@ async def _check_for_soft_fail( event.room_id, room_version, state_sets, - event_map={}, + event_map=None, state_res_store=StateResolutionStore(self._store), ) ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8a9f9de555c0..955e6a83140d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1028,11 +1028,11 @@ async def create_new_client_event( for state_id in state_event_ids: data = metadata.get(state_id) if data is None: - raise Exception("State event not persisted %s", state_id) + raise Exception(f"State event {state_id} not persisted") if data.state_key is None: raise Exception( - "Trying to set non-state event as state: %s", state_id + f"Trying to set non-state event {state_id} as state" ) state_map[(data.event_type, data.state_key)] = state_id diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index ef444cac6b1f..335d408556bc 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -273,10 +273,10 @@ async def compute_event_context( Args: event: - state_ids_before_event: The event ids of the state at the event if + state_ids_before_event: The event ids of the state before the event if it can't be calculated from existing events. This is normally only specified when receiving an event from federation where we - don't have the prev events for, e.g. when backfilling. + don't have the prev events, e.g. when backfilling. partial_state: True if `state_ids_before_event` is partial and omits non-critical membership events Returns: From edcd824bdd4374f2656c7225ac94977664cc53dd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2022 11:35:54 +0100 Subject: [PATCH 11/13] Code review comments --- synapse/handlers/federation_event.py | 6 +++--- synapse/handlers/message.py | 27 ++++++++++++++++++++----- synapse/storage/databases/main/state.py | 2 +- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 6979dd081a5f..43e79402e65a 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -832,7 +832,7 @@ async def _resolve_state_at_missing_prevs( ) # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. - event_map = {event_id: event} + try: # Get the state of the events we know about ours = await self._state_store.get_state_groups_ids(room_id, seen) @@ -865,7 +865,7 @@ async def _resolve_state_at_missing_prevs( room_id, room_version, state_maps, - event_map, + event_map={event_id: event}, state_res_store=StateResolutionStore(self._store), ) @@ -911,7 +911,7 @@ async def _get_state_ids_after_missing_prev_event( len(auth_event_ids), ) - # start by just trying to fetch the events from the store + # Start by checking events we already have in the DB desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 955e6a83140d..dd6f07dd1da5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -55,7 +55,14 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester +from synapse.types import ( + MutableStateMap, + Requester, + RoomAlias, + StreamToken, + UserID, + create_requester, +) from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache @@ -1024,22 +1031,32 @@ async def create_new_client_event( # old state is complete. metadata = await self.store.get_metadata_for_events(state_event_ids) - state_map = {} + state_map_for_event: MutableStateMap[str] = {} for state_id in state_event_ids: data = metadata.get(state_id) if data is None: - raise Exception(f"State event {state_id} not persisted") + # We're trying to persist a new historical batch of events + # with the given state, e.g. via + # `RoomBatchSendEventRestServlet`. The state can be inferred + # by Synapse or set directly by the client. + # + # Either way, we should have persisted all the state before + # getting here. + raise Exception( + f"State event {state_id} not found in DB," + " Synapse should have persisted it before using it." + ) if data.state_key is None: raise Exception( f"Trying to set non-state event {state_id} as state" ) - state_map[(data.event_type, data.state_key)] = state_id + state_map_for_event[(data.event_type, data.state_key)] = state_id context = await self.state.compute_event_context( event, - state_ids_before_event=state_map, + state_ids_before_event=state_map_for_event, ) else: context = await self.state.compute_event_context(event) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index e126020135fa..f2e79f01a77f 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -182,7 +182,7 @@ def get_metadata_for_events_txn( result_map: Dict[str, EventMetadata] = {} for batch_ids in batch_iter(event_ids, 1000): - return await self.db_pool.runInteraction( + result_map = await self.db_pool.runInteraction( "get_metadata_for_events", get_metadata_for_events_txn, batch_ids=batch_ids, From 4343d35991f3a722115b60515851c39130d3d5ad Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 May 2022 10:15:50 +0100 Subject: [PATCH 12/13] Update synapse/storage/databases/main/state.py Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/storage/databases/main/state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index f2e79f01a77f..a2267e6d2c9c 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -182,11 +182,11 @@ def get_metadata_for_events_txn( result_map: Dict[str, EventMetadata] = {} for batch_ids in batch_iter(event_ids, 1000): - result_map = await self.db_pool.runInteraction( + result_map.update(await self.db_pool.runInteraction( "get_metadata_for_events", get_metadata_for_events_txn, batch_ids=batch_ids, - ) + )) return result_map From 9e432399179b1c3111c91b27f7550ac9f6b22428 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 May 2022 10:24:06 +0100 Subject: [PATCH 13/13] Fix style --- synapse/storage/databases/main/state.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a2267e6d2c9c..ea5cbdac08eb 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -182,11 +182,13 @@ def get_metadata_for_events_txn( result_map: Dict[str, EventMetadata] = {} for batch_ids in batch_iter(event_ids, 1000): - result_map.update(await self.db_pool.runInteraction( - "get_metadata_for_events", - get_metadata_for_events_txn, - batch_ids=batch_ids, - )) + result_map.update( + await self.db_pool.runInteraction( + "get_metadata_for_events", + get_metadata_for_events_txn, + batch_ids=batch_ids, + ) + ) return result_map