diff --git a/changelog.d/12394.misc b/changelog.d/12394.misc new file mode 100644 index 000000000000..69109fcc37d3 --- /dev/null +++ b/changelog.d/12394.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: start a background process to resynchronise the room state after a room join. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 78d149905f52..e702a1644585 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -466,6 +466,8 @@ async def do_invite_join( ) if ret.partial_state: + # TODO(faster_joins): roll this back if we don't manage to start the + # background resync (eg process_remote_join fails) await self.store.store_partial_state_room(room_id, ret.servers_in_room) max_stream_id = await self._federation_event_handler.process_remote_join( @@ -478,6 +480,18 @@ async def do_invite_join( partial_state=ret.partial_state, ) + if ret.partial_state: + # kick off the process of asynchronously fixing up the state for this + # room + # + # TODO(faster_joins): pick this up again on restart + run_as_background_process( + desc="sync_partial_state_room", + func=self._sync_partial_state_room, + destination=origin, + room_id=room_id, + ) + # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. await self._replication.wait_for_stream_position( @@ -1370,3 +1384,61 @@ async def get_room_complexity( # We fell off the bottom, couldn't get the complexity from anyone. Oh # well. return None + + async def _sync_partial_state_room( + self, + destination: str, + room_id: str, + ) -> None: + """Background process to resync the state of a partial-state room + + Args: + destination: homeserver to pull the state from + room_id: room to be resynced + """ + + # TODO(faster_joins): do we need to lock to avoid races? What happens if other + # worker processes kick off a resync in parallel? Perhaps we should just elect + # a single worker to do the resync. + # + # TODO(faster_joins): what happens if we leave the room during a resync? if we + # really leave, that might mean we have difficulty getting the room state over + # federation. + # + # TODO(faster_joins): try other destinations if the one we have fails + + logger.info("Syncing state for room %s via %s", room_id, destination) + + # we work through the queue in order of increasing stream ordering. + while True: + batch = await self.store.get_partial_state_events_batch(room_id) + if not batch: + # all the events are updated, so we can update current state and + # clear the lazy-loading flag. + logger.info("Updating current state for %s", room_id) + assert ( + self.storage.persistence is not None + ), "TODO(faster_joins): support for workers" + await self.storage.persistence.update_current_state(room_id) + + logger.info("Clearing partial-state flag for %s", room_id) + success = await self.store.clear_partial_state_room(room_id) + if success: + logger.info("State resync complete for %s", room_id) + + # TODO(faster_joins) update room stats and user directory? + return + + # we raced against more events arriving with partial state. Go round + # the loop again. We've already logged a warning, so no need for more. + continue + + events = await self.store.get_events_as_list( + batch, + redact_behaviour=EventRedactBehaviour.AS_IS, + allow_rejected=True, + ) + for event in events: + await self._federation_event_handler.update_state_for_partial_state_event( + destination, event + ) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 03c1197c997f..32bf02818c54 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -477,6 +477,45 @@ async def process_remote_join( return await self.persist_events_and_notify(room_id, [(event, context)]) + async def update_state_for_partial_state_event( + self, destination: str, event: EventBase + ) -> None: + """Recalculate the state at an event as part of a de-partial-stating process + + Args: + destination: server to request full state from + event: partial-state event to be de-partial-stated + """ + logger.info("Updating state for %s", event.event_id) + with nested_logging_context(suffix=event.event_id): + # if we have all the event's prev_events, then we can work out the + # state based on their states. Otherwise, we request it from the destination + # server. + # + # This is the same operation as we do when we receive a regular event + # over federation. + state = await self._resolve_state_at_missing_prevs(destination, event) + + # build a new state group for it if need be + context = await self._state_handler.compute_event_context( + event, + old_state=state, + ) + if context.partial_state: + # this can happen if some or all of the event's prev_events still have + # partial state - ie, an event has an earlier stream_ordering than one + # or more of its prev_events, so we de-partial-state it before its + # prev_events. + # + # TODO(faster_joins): we probably need to be more intelligent, and + # exclude partial-state prev_events from consideration + logger.warning( + "%s still has partial state: can't de-partial-state it yet", + event.event_id, + ) + return + await self._store.update_state_for_partial_state_event(event, context) + async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 57489c30f142..1e7c644f7051 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -963,6 +963,21 @@ def _persist_transaction_ids_txn( values=to_insert, ) + async def update_current_state( + self, + room_id: str, + state_delta: DeltaState, + stream_id: int, + ) -> None: + """Update the current state stored in the datatabase for the given room""" + + await self.db_pool.runInteraction( + "update_current_state", + self._update_current_state_txn, + state_delta_by_room={room_id: state_delta}, + stream_id=stream_id, + ) + def _update_current_state_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a60e3f4fdde0..9fd7697152c1 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1979,3 +1979,27 @@ async def is_partial_state_event(self, event_id: str) -> bool: desc="is_partial_state_event", ) return result is not None + + async def get_partial_state_events_batch(self, room_id: str) -> List[str]: + """Get a list of events in the given room that have partial state""" + return await self.db_pool.runInteraction( + "get_partial_state_events_batch", + self._get_partial_state_events_batch_txn, + room_id, + ) + + @staticmethod + def _get_partial_state_events_batch_txn( + txn: LoggingTransaction, room_id: str + ) -> List[str]: + txn.execute( + """\ + SELECT event_id FROM partial_state_events AS pse + JOIN events USING (event_id) + WHERE pse.room_id = ? + ORDER BY events.stream_ordering + LIMIT 100 + """, + (room_id,), + ) + return [row[0] for row in txn] diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 18b1acd9e113..526dceed79e8 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1077,6 +1077,37 @@ def get_rooms_for_retention_period_in_range_txn( get_rooms_for_retention_period_in_range_txn, ) + async def clear_partial_state_room(self, room_id: str) -> bool: + # this can race with incoming events, so we watch out for FK errors. + # todo: this still doesn't completely fix the race, since the persist process + # is not atomic. I fear we need an application-level lock. + try: + await self.db_pool.runInteraction( + "clear_partial_state_room", self._clear_partial_state_room_txn, room_id + ) + return True + except self.db_pool.engine.module.DatabaseError as e: + # todo: how do we distinguish between FK errors and other errors? + logger.warning( + "Exception while clearing lazy partial-state-room %s, retrying: %s", + room_id, + e, + ) + return False + + @staticmethod + def _clear_partial_state_room_txn(txn: LoggingTransaction, room_id: str) -> None: + DatabasePool.simple_delete_txn( + txn, + table="partial_state_rooms_servers", + keyvalues={"room_id": room_id}, + ) + DatabasePool.simple_delete_one_txn( + txn, + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + ) + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 4a461a0abb1f..5bedc8df9e8b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -21,6 +21,7 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -354,6 +355,53 @@ async def get_referenced_state_groups( return {row["state_group"] for row in rows} + async def update_state_for_partial_state_event( + self, + event: EventBase, + context: EventContext, + ) -> None: + """Update the state group for a partial state event""" + await self.db_pool.runInteraction( + "update_state_for_partial_state_event", + self._update_state_for_partial_state_event_txn, + event, + context, + ) + + def _update_state_for_partial_state_event_txn( + self, + txn, + event: EventBase, + context: EventContext, + ): + # we shouldn't have any outliers here + assert not event.internal_metadata.is_outlier() + + # anything that was rejected should have the same state as its + # predecessor. + if context.rejected: + assert context.state_group == context.state_group_before_event + + self.db_pool.simple_update_txn( + txn, + table="event_to_state_groups", + keyvalues={"event_id": event.event_id}, + updatevalues={"state_group": context.state_group}, + ) + + self.db_pool.simple_delete_one_txn( + txn, + table="partial_state_events", + keyvalues={"event_id": event.event_id}, + ) + + # TODO: need to do something about workers here + txn.call_after( + self._get_state_group_for_event.prefill, + (event.event_id,), + context.state_group, + ) + class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index b40292281767..e496ba7bed6e 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -376,6 +376,62 @@ async def persist_event( pos = PersistedEventPosition(self._instance_name, event_stream_id) return event, pos, self.main_store.get_room_max_token() + async def update_current_state(self, room_id: str) -> None: + """Recalculate the current state for a room, and persist it""" + state = await self._calculate_current_state(room_id) + delta = await self._calculate_state_delta(room_id, state) + + # TODO(faster_joins): get a real stream ordering, to make this work correctly + # across workers. + # + # TODO(faster_joins): this can race against event persistence, in which case we + # will end up with incorrect state. Perhaps we should make this a job we + # farm out to the event persister, somehow. + stream_id = self.main_store.get_room_max_stream_ordering() + await self.persist_events_store.update_current_state(room_id, delta, stream_id) + + async def _calculate_current_state(self, room_id: str) -> StateMap[str]: + """Calculate the current state of a room, based on the forward extremities + + Args: + room_id: room for which to calculate current state + + Returns: + map from (type, state_key) to event id for the current state in the room + """ + latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id) + state_groups = set( + ( + await self.main_store._get_state_group_for_events(latest_event_ids) + ).values() + ) + + state_maps_by_state_group = await self.state_store._get_state_for_groups( + state_groups + ) + + if len(state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + return state_maps_by_state_group[state_groups.pop()] + + # Ok, we need to defer to the state handler to resolve our state sets. + logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + + room_version = await self.main_store.get_room_version_id(room_id) + res = await self._state_resolution_handler.resolve_state_groups( + room_id, + room_version, + state_maps_by_state_group, + event_map=None, + state_res_store=StateResolutionStore(self.main_store), + ) + + return res.state + async def _persist_event_batch( self, events_and_contexts: List[Tuple[EventBase, EventContext]],