From 0a3ccdc64e846d74c9eee41e8c867079749aaba4 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 May 2022 07:45:31 +0200 Subject: [PATCH 1/7] Add some type hints to `event_federation` datastore --- changelog.d/12752.misc | 1 + mypy.ini | 1 - .../databases/main/event_federation.py | 184 ++++++++++++------ 3 files changed, 121 insertions(+), 65 deletions(-) create mode 100644 changelog.d/12752.misc diff --git a/changelog.d/12752.misc b/changelog.d/12752.misc new file mode 100644 index 000000000000..e793d08e5e3f --- /dev/null +++ b/changelog.d/12752.misc @@ -0,0 +1 @@ +Add some type hints to datastore. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 9ae7ad211c54..cbdb2507821a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -27,7 +27,6 @@ exclude = (?x) |synapse/storage/databases/__init__.py |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py - |synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/roommember.py |synapse/storage/schema/ diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 471022470843..93e2aa174c9e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,7 +14,17 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + cast, +) import attr from prometheus_client import Counter, Gauge @@ -33,7 +43,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -135,7 +145,7 @@ async def get_auth_chain_ids( # Check if we have indexed the room so we can use the chain cover # algorithm. - room = await self.get_room(room_id) + room = await self.get_room(room_id) # type: ignore[attr-defined] if room["has_auth_chain_index"]: try: return await self.db_pool.runInteraction( @@ -158,7 +168,11 @@ async def get_auth_chain_ids( ) def _get_auth_chain_ids_using_cover_index_txn( - self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool + self, + txn: LoggingTransaction, + room_id: str, + event_ids: Collection[str], + include_given: bool, ) -> Set[str]: """Calculates the auth chain IDs using the chain index.""" @@ -215,9 +229,9 @@ def _get_auth_chain_ids_using_cover_index_txn( chains: Dict[int, int] = {} # Add all linked chains reachable from initial set of chains. - for batch in batch_iter(event_chains, 1000): + for batch2 in batch_iter(event_chains, 1000): clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch + txn.database_engine, "origin_chain_id", batch2 ) txn.execute(sql % (clause,), args) @@ -297,7 +311,7 @@ def _get_auth_chain_ids_txn( front = set(event_ids) while front: - new_front = set() + new_front: Set[str] = set() for chunk in batch_iter(front, 100): # Pull the auth events either from the cache or DB. 
to_fetch: List[str] = [] # Event IDs to fetch from DB @@ -316,7 +330,7 @@ def _get_auth_chain_ids_txn( # Note we need to batch up the results by event ID before # adding to the cache. - to_cache = {} + to_cache: Dict[str, List[Tuple[str, int]]] = {} for event_id, auth_event_id, auth_event_depth in txn: to_cache.setdefault(event_id, []).append( (auth_event_id, auth_event_depth) @@ -349,7 +363,7 @@ async def get_auth_chain_difference( # Check if we have indexed the room so we can use the chain cover # algorithm. - room = await self.get_room(room_id) + room = await self.get_room(room_id) # type: ignore[attr-defined] if room["has_auth_chain_index"]: try: return await self.db_pool.runInteraction( @@ -370,7 +384,7 @@ async def get_auth_chain_difference( ) def _get_auth_chain_difference_using_cover_index_txn( - self, txn: Cursor, room_id: str, state_sets: List[Set[str]] + self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]] ) -> Set[str]: """Calculates the auth chain difference using the chain index. @@ -444,9 +458,9 @@ def _get_auth_chain_difference_using_cover_index_txn( # (We need to take a copy of `seen_chains` as we want to mutate it in # the loop) - for batch in batch_iter(set(seen_chains), 1000): + for batch2 in batch_iter(set(seen_chains), 1000): clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch + txn.database_engine, "origin_chain_id", batch2 ) txn.execute(sql % (clause,), args) @@ -529,7 +543,7 @@ def _get_auth_chain_difference_using_cover_index_txn( return result def _get_auth_chain_difference_txn( - self, txn, state_sets: List[Set[str]] + self, txn: LoggingTransaction, state_sets: List[Set[str]] ) -> Set[str]: """Calculates the auth chain difference using a breadth first search. @@ -602,7 +616,7 @@ def _get_auth_chain_difference_txn( # I think building a temporary list with fetchall is more efficient than # just `search.extend(txn)`, but this is unconfirmed - search.extend(txn.fetchall()) + search.extend(cast(List[Tuple[int, str]], txn.fetchall())) # sort by depth search.sort() @@ -645,7 +659,7 @@ def _get_auth_chain_difference_txn( # We parse the results and add the to the `found` set and the # cache (note we need to batch up the results by event ID before # adding to the cache). - to_cache = {} + to_cache: Dict[str, List[Tuple[str, int]]] = {} for event_id, auth_event_id, auth_event_depth in txn: to_cache.setdefault(event_id, []).append( (auth_event_id, auth_event_depth) @@ -696,7 +710,7 @@ def _get_auth_chain_difference_txn( return {eid for eid, n in event_to_missing_sets.items() if n} async def get_oldest_event_ids_with_depth_in_room( - self, room_id + self, room_id: str ) -> List[Tuple[str, int]]: """Gets the oldest events(backwards extremities) in the room along with the aproximate depth. @@ -713,7 +727,9 @@ async def get_oldest_event_ids_with_depth_in_room( List of (event_id, depth) tuples """ - def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): + def get_oldest_event_ids_with_depth_in_room_txn( + txn: LoggingTransaction, room_id: str + ) -> List[Tuple[str, int]]: # Assemble a dictionary with event_id -> depth for the oldest events # we know of in the room. 
Backwards extremeties are the oldest # events we know of in the room but we only know of them because @@ -743,7 +759,7 @@ def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): txn.execute(sql, (room_id, False)) - return txn.fetchall() + return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_oldest_event_ids_with_depth_in_room", @@ -752,7 +768,7 @@ def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): ) async def get_insertion_event_backward_extremities_in_room( - self, room_id + self, room_id: str ) -> List[Tuple[str, int]]: """Get the insertion events we know about that we haven't backfilled yet. @@ -768,7 +784,9 @@ async def get_insertion_event_backward_extremities_in_room( List of (event_id, depth) tuples """ - def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): + def get_insertion_event_backward_extremities_in_room_txn( + txn: LoggingTransaction, room_id: str + ) -> List[Tuple[str, int]]: sql = """ SELECT b.event_id, MAX(e.depth) FROM insertion_events as i /* We only want insertion events that are also marked as backwards extremities */ @@ -780,7 +798,7 @@ def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): """ txn.execute(sql, (room_id,)) - return txn.fetchall() + return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_insertion_event_backward_extremities_in_room", @@ -788,7 +806,7 @@ def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): room_id, ) - async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: + async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs Args: @@ -817,7 +835,7 @@ async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: return max_depth_event_id, current_max_depth - async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: + async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the min depth from a set of event IDs Args: @@ -865,7 +883,9 @@ async def get_prev_events_for_room(self, room_id: str) -> List[str]: "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) - def _get_prev_events_for_room_txn(self, txn, room_id: str): + def _get_prev_events_for_room_txn( + self, txn: LoggingTransaction, room_id: str + ) -> List[str]: # we just use the 10 newest events. Older events will become # prev_events of future events. @@ -896,7 +916,7 @@ async def get_rooms_with_many_extremities( sorted by extremity count. """ - def _get_rooms_with_many_extremities_txn(txn): + def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]: where_clause = "1=1" if room_id_filter: where_clause = "room_id NOT IN (%s)" % ( @@ -937,7 +957,9 @@ async def get_min_depth(self, room_id: str) -> Optional[int]: "get_min_depth", self._get_min_depth_interaction, room_id ) - def _get_min_depth_interaction(self, txn, room_id): + def _get_min_depth_interaction( + self, txn: LoggingTransaction, room_id: str + ) -> Optional[int]: min_depth = self.db_pool.simple_select_one_onecol_txn( txn, table="room_depth", @@ -966,22 +988,24 @@ async def get_forward_extremities_for_room_at_stream_ordering( """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. 
- last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) + last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined] # We don't always have a full stream_to_exterm_id table, e.g. after # the upgrade that introduced it, so we make sure we never ask for a # stream_ordering from before a restart - last_change = max(self._stream_order_on_start, last_change) + last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined] # provided the last_change is recent enough, we now clamp the requested # stream_ordering to it. - if last_change > self.stream_ordering_month_ago: + if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined] stream_ordering = min(last_change, stream_ordering) return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -989,7 +1013,7 @@ async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): stream_orderings from that point. """ - if stream_ordering <= self.stream_ordering_month_ago: + if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined] raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ @@ -1002,7 +1026,7 @@ async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): WHERE room_id = ? """ - def get_forward_extremeties_for_room_txn(txn): + def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] @@ -1104,8 +1128,8 @@ def _get_connected_prev_event_backfill_results_txn( ] async def get_backfill_events( - self, room_id: str, seed_event_id_list: list, limit: int - ): + self, room_id: str, seed_event_id_list: List[str], limit: int + ) -> List[EventBase]: """Get a list of Events for a given topic that occurred before (and including) the events in seed_event_id_list. Return a list of max size `limit` @@ -1123,10 +1147,16 @@ async def get_backfill_events( ) events = await self.get_events_as_list(event_ids) return sorted( - events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) + events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) # type: ignore[operator] ) - def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): + def _get_backfill_events( + self, + txn: LoggingTransaction, + room_id: str, + seed_event_id_list: List[str], + limit: int, + ) -> Set[str]: """ We want to make sure that we do a breadth-first, "depth" ordered search. We also handle navigating historical branches of history connected by @@ -1139,7 +1169,7 @@ def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): limit, ) - event_id_results = set() + event_id_results: Set[str] = set() # In a PriorityQueue, the lowest valued entries are retrieved first. # We're using depth as the priority in the queue and tie-break based on @@ -1147,7 +1177,7 @@ def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): # highest and newest-in-time message. We add events to the queue with a # negative depth so that we process the newest-in-time messages first # going backwards in time. 
stream_ordering follows the same pattern. - queue = PriorityQueue() + queue = PriorityQueue() # type: ignore[var-annotated] for seed_event_id in seed_event_id_list: event_lookup_result = self.db_pool.simple_select_one_txn( @@ -1253,7 +1283,13 @@ def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): return event_id_results - async def get_missing_events(self, room_id, earliest_events, latest_events, limit): + async def get_missing_events( + self, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[EventBase]: ids = await self.db_pool.runInteraction( "get_missing_events", self._get_missing_events, @@ -1264,11 +1300,18 @@ async def get_missing_events(self, room_id, earliest_events, latest_events, limi ) return await self.get_events_as_list(ids) - def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): + def _get_missing_events( + self, + txn: LoggingTransaction, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[str]: seen_events = set(earliest_events) front = set(latest_events) - seen_events - event_results = [] + event_results: List[str] = [] query = ( "SELECT prev_event_id FROM event_edges " @@ -1311,7 +1354,7 @@ async def get_successor_events(self, event_id: str) -> List[str]: @wrap_as_background_process("delete_old_forward_extrem_cache") async def _delete_old_forward_extrem_cache(self) -> None: - def _delete_old_forward_extrem_cache_txn(txn): + def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None: # Delete entries older than a month, while making sure we don't delete # the only entries for a room. sql = """ @@ -1324,7 +1367,7 @@ def _delete_old_forward_extrem_cache_txn(txn): ) AND stream_ordering < ? """ txn.execute( - sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) + sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined] ) await self.db_pool.runInteraction( @@ -1382,7 +1425,9 @@ async def remove_received_event_from_staging( """ if self.db_pool.engine.supports_returning: - def _remove_received_event_from_staging_txn(txn): + def _remove_received_event_from_staging_txn( + txn: LoggingTransaction, + ) -> Optional[int]: sql = """ DELETE FROM federation_inbound_events_staging WHERE origin = ? AND event_id = ? 
@@ -1390,21 +1435,24 @@ def _remove_received_event_from_staging_txn(txn): """ txn.execute(sql, (origin, event_id)) - return txn.fetchone() + row = cast(Optional[Tuple[int]], txn.fetchone()) - row = await self.db_pool.runInteraction( + if row is None: + return None + + return row[0] + + return await self.db_pool.runInteraction( "remove_received_event_from_staging", _remove_received_event_from_staging_txn, db_autocommit=True, ) - if row is None: - return None - - return row[0] else: - def _remove_received_event_from_staging_txn(txn): + def _remove_received_event_from_staging_txn( + txn: LoggingTransaction, + ) -> Optional[int]: received_ts = self.db_pool.simple_select_one_onecol_txn( txn, table="federation_inbound_events_staging", @@ -1437,7 +1485,9 @@ async def get_next_staged_event_id_for_room( ) -> Optional[Tuple[str, str]]: """Get the next event ID in the staging area for the given room.""" - def _get_next_staged_event_id_for_room_txn(txn): + def _get_next_staged_event_id_for_room_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, str]]: sql = """ SELECT origin, event_id FROM federation_inbound_events_staging @@ -1448,7 +1498,7 @@ def _get_next_staged_event_id_for_room_txn(txn): txn.execute(sql, (room_id,)) - return txn.fetchone() + return cast(Optional[Tuple[str, str]], txn.fetchone()) return await self.db_pool.runInteraction( "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn @@ -1461,7 +1511,9 @@ async def get_next_staged_event_for_room( ) -> Optional[Tuple[str, EventBase]]: """Get the next event in the staging area for the given room.""" - def _get_next_staged_event_for_room_txn(txn): + def _get_next_staged_event_for_room_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, str, str]]: sql = """ SELECT event_json, internal_metadata, origin FROM federation_inbound_events_staging @@ -1471,7 +1523,7 @@ def _get_next_staged_event_for_room_txn(txn): """ txn.execute(sql, (room_id,)) - return txn.fetchone() + return cast(Optional[Tuple[str, str, str]], txn.fetchone()) row = await self.db_pool.runInteraction( "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn @@ -1599,18 +1651,20 @@ async def get_all_rooms_with_staged_incoming_events(self) -> List[str]: ) @wrap_as_background_process("_get_stats_for_federation_staging") - async def _get_stats_for_federation_staging(self): + async def _get_stats_for_federation_staging(self) -> None: """Update the prometheus metrics for the inbound federation staging area.""" - def _get_stats_for_federation_staging_txn(txn): + def _get_stats_for_federation_staging_txn( + txn: LoggingTransaction, + ) -> Tuple[int, int]: txn.execute("SELECT count(*) FROM federation_inbound_events_staging") - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) txn.execute( "SELECT min(received_ts) FROM federation_inbound_events_staging" ) - (received_ts,) = txn.fetchone() + (received_ts,) = cast(Tuple[int], txn.fetchone()) # If there is nothing in the staging area default it to 0. 
age = 0 @@ -1651,19 +1705,21 @@ def __init__( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) - async def clean_room_for_join(self, room_id): - return await self.db_pool.runInteraction( + async def clean_room_for_join(self, room_id: str) -> None: + await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) - def _clean_room_for_join_txn(self, txn, room_id): + def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None: query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - async def _background_delete_non_state_event_auth(self, progress, batch_size): - def delete_event_auth(txn): + async def _background_delete_non_state_event_auth( + self, progress: JsonDict, batch_size: int + ) -> int: + def delete_event_auth(txn: LoggingTransaction) -> bool: target_min_stream_id = progress.get("target_min_stream_id_inclusive") max_stream_id = progress.get("max_stream_id_exclusive") From c4e1338e2627b2d234e4c8fd6b2f481de3f653d7 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 May 2022 07:48:46 +0200 Subject: [PATCH 2/7] update newsfile --- changelog.d/{12752.misc => 12753.misc} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename changelog.d/{12752.misc => 12753.misc} (100%) diff --git a/changelog.d/12752.misc b/changelog.d/12753.misc similarity index 100% rename from changelog.d/12752.misc rename to changelog.d/12753.misc From 45eed216b2ed45671b0ceded62cc1905cd4a411b Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 May 2022 08:00:51 +0200 Subject: [PATCH 3/7] update `room_batch` handler --- synapse/handlers/room_batch.py | 2 ++ tests/handlers/test_federation.py | 1 + 2 files changed, 3 insertions(+) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 29de7e5bed10..fbfd7484065c 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -53,6 +53,7 @@ async def inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: # We want to use the successor event depth so they appear after `prev_event` because # it has a larger `depth` but before the successor event because the `stream_ordering` # is negative before the successor event. 
+ assert most_recent_prev_event_id is not None successor_event_ids = await self.store.get_successor_events( most_recent_prev_event_id ) @@ -139,6 +140,7 @@ async def get_most_recent_full_state_ids_from_event_id_list( _, ) = await self.store.get_max_depth_of(event_ids) # mapping from (type, state_key) -> state_event_id + assert most_recent_event_id is not None prev_state_map = await self.state_store.get_state_ids_for_event( most_recent_event_id ) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 060ba5f5174f..e95dfdce2086 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -332,6 +332,7 @@ def test_backfill_floating_outlier_membership_auth(self) -> None: most_recent_prev_event_depth, ) = self.get_success(self.store.get_max_depth_of(prev_event_ids)) # mapping from (type, state_key) -> state_event_id + assert most_recent_prev_event_id is not None prev_state_map = self.get_success( self.state_store.get_state_ids_for_event(most_recent_prev_event_id) ) From c55fb12d98d6cc60dd9ef05e19c6670e6c8ce515 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 18 May 2022 07:41:12 +0200 Subject: [PATCH 4/7] Apply suggestions from code review Co-authored-by: David Robertson --- synapse/storage/databases/main/event_federation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 93e2aa174c9e..d563654749e6 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1147,6 +1147,8 @@ async def get_backfill_events( ) events = await self.get_events_as_list(event_ids) return sorted( + # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering. + # But it's never None, because these events were previously persisted to the DB. 
events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) # type: ignore[operator] ) From 27bbf22fe03bb2916b7b5208a806a6a9c5a3eff5 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 18 May 2022 08:05:00 +0200 Subject: [PATCH 5/7] update after review --- .../storage/databases/main/event_federation.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index d563654749e6..a3070ad80cd2 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -41,7 +41,9 @@ LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict from synapse.util import json_encoder @@ -85,7 +87,13 @@ def __init__(self, room_id: str): super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) -class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore): +class EventFederationWorkerStore( + SignatureWorkerStore, + EventsWorkerStore, + SQLBaseStore, + RoomWorkerStore, + StreamWorkerStore, +): def __init__( self, database: DatabasePool, @@ -1149,7 +1157,8 @@ async def get_backfill_events( return sorted( # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering. # But it's never None, because these events were previously persisted to the DB. - events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) # type: ignore[operator] + events, + key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator] ) def _get_backfill_events( @@ -1179,7 +1188,7 @@ def _get_backfill_events( # highest and newest-in-time message. We add events to the queue with a # negative depth so that we process the newest-in-time messages first # going backwards in time. stream_ordering follows the same pattern. - queue = PriorityQueue() # type: ignore[var-annotated] + queue: "PriorityQueue[Tuple[int, int, str, str]]" for seed_event_id in seed_event_id_list: event_lookup_result = self.db_pool.simple_select_one_txn( @@ -1666,7 +1675,7 @@ def _get_stats_for_federation_staging_txn( "SELECT min(received_ts) FROM federation_inbound_events_staging" ) - (received_ts,) = cast(Tuple[int], txn.fetchone()) + (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone()) # If there is nothing in the staging area default it to 0. 
age = 0 From 7b5c278943d206d631f1bfca74e4e7cf3989452f Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 18 May 2022 09:28:43 +0200 Subject: [PATCH 6/7] Revert inheritance --- synapse/storage/databases/main/event_federation.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index a3070ad80cd2..b0232c536fc4 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -87,13 +87,7 @@ def __init__(self, room_id: str): super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) -class EventFederationWorkerStore( - SignatureWorkerStore, - EventsWorkerStore, - SQLBaseStore, - RoomWorkerStore, - StreamWorkerStore, -): +class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore): def __init__( self, database: DatabasePool, @@ -1188,7 +1182,7 @@ def _get_backfill_events( # highest and newest-in-time message. We add events to the queue with a # negative depth so that we process the newest-in-time messages first # going backwards in time. stream_ordering follows the same pattern. - queue: "PriorityQueue[Tuple[int, int, str, str]]" + queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue() for seed_event_id in seed_event_id_list: event_lookup_result = self.db_pool.simple_select_one_txn( From af46ee202c79959f58735ecb95ddd436c4714717 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 18 May 2022 09:34:21 +0200 Subject: [PATCH 7/7] Remove imports --- synapse/storage/databases/main/event_federation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index b0232c536fc4..dcfe8caf473a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -41,9 +41,7 @@ LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore -from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict from synapse.util import json_encoder
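The recurring change in patch 1 is wrapping `txn.fetchall()` / `txn.fetchone()` in `cast(...)`, because the transaction's fetch methods cannot tell mypy the shape of each row. Below is a minimal, self-contained sketch of that pattern, using sqlite3 and an invented `events` table in place of Synapse's `LoggingTransaction` and schema (the table, column and event names are illustrative only); `cast()` documents the expected row shape for the type checker and has no runtime cost.

import sqlite3
from typing import List, Optional, Tuple, cast

conn = sqlite3.connect(":memory:")
txn = conn.cursor()
txn.execute("CREATE TABLE events (event_id TEXT, depth INTEGER)")
txn.execute("INSERT INTO events VALUES ('$event_a:example.org', 3), ('$event_b:example.org', 7)")

def get_event_depths(txn: sqlite3.Cursor) -> List[Tuple[str, int]]:
    # fetchall() gives mypy no element types, so without the cast the
    # declared return type is effectively unverified.
    txn.execute("SELECT event_id, depth FROM events ORDER BY depth")
    return cast(List[Tuple[str, int]], txn.fetchall())

def get_max_depth(txn: sqlite3.Cursor) -> Optional[int]:
    # The aggregate is NULL on an empty table, hence Tuple[Optional[int]],
    # matching the Optional handling in the staging-area stats hunk above.
    txn.execute("SELECT max(depth) FROM events")
    (max_depth,) = cast(Tuple[Optional[int]], txn.fetchone())
    return max_depth

print(get_event_depths(txn))  # [('$event_a:example.org', 3), ('$event_b:example.org', 7)]
print(get_max_depth(txn))     # 7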
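Patches 5 and 6 settle on annotating the backfill queue as `queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()` rather than a bare `PriorityQueue()` with a `# type: ignore[var-annotated]`. A rough, standalone sketch of why that annotation works (the event IDs and values below are invented):

from queue import PriorityQueue
from typing import Tuple

# The annotation is quoted so it stays valid on Python versions where
# queue.PriorityQueue is not subscriptable at runtime; mypy still checks it.
queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()

# Items are (-depth, -stream_ordering, event_id, event_type). PriorityQueue
# returns the lowest tuple first, so negating depth and stream_ordering makes
# the newest-in-time event come out first when walking backwards in time.
queue.put((-10, -105, "$newer_event:example.org", "m.room.message"))
queue.put((-3, -42, "$older_event:example.org", "m.room.message"))

while not queue.empty():
    neg_depth, neg_stream_ordering, event_id, event_type = queue.get()
    print(-neg_depth, -neg_stream_ordering, event_id)
# 10 105 $newer_event:example.org
# 3 42 $older_event:example.org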
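Patch 5 added `RoomWorkerStore` and `StreamWorkerStore` to the class's bases so mypy could see attributes such as `get_room` that are defined on sibling stores, and patch 6 reverted to the original base list, keeping the per-call-site `# type: ignore[attr-defined]` comments from patch 1. A toy sketch of that trade-off, with invented class and method names that only mimic the shape of Synapse's store mixins:

class RoomStore:
    def get_room(self, room_id: str) -> dict:
        # Stand-in for the real lookup; always reports a chain cover index.
        return {"room_id": room_id, "has_auth_chain_index": True}


class FederationStore:
    # RoomStore is deliberately not a base here, mirroring the reverted
    # inheritance: mypy flags self.get_room as attr-defined, so the call
    # site carries an ignore instead.
    def has_chain_cover(self, room_id: str) -> bool:
        room = self.get_room(room_id)  # type: ignore[attr-defined]
        return bool(room["has_auth_chain_index"])


class DataStore(FederationStore, RoomStore):
    # The concrete store combines the mixins, much as Synapse's DataStore
    # does, so the attribute exists at runtime even though mypy cannot see
    # it from inside FederationStore alone.
    pass


store = DataStore()
print(store.has_chain_cover("!room:example.org"))  # True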