From 9f409cb41a65421770627c099ea6a8fb0b74273d Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 21:17:25 +0000 Subject: [PATCH 01/14] Add docstrings to `AbstractStreamIdGenerator` methods --- synapse/storage/util/id_generators.py | 59 +++++++++------------------ 1 file changed, 20 insertions(+), 39 deletions(-) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index ac56bc9a050f..011de4a41fb7 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -92,18 +92,38 @@ def _load_current_id( class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): @abc.abstractmethod def get_next(self) -> AsyncContextManager[int]: + """ + Usage: + async with stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ raise NotImplementedError() @abc.abstractmethod def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: + """ + Usage: + async with stream_id_gen.get_next(n) as stream_ids: + # ... persist events ... + """ raise NotImplementedError() @abc.abstractmethod def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + + Returns: + The maximum stream id. + """ raise NotImplementedError() @abc.abstractmethod def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to `get_current_token`. + """ raise NotImplementedError() @@ -158,11 +178,6 @@ def __init__( self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ with self._lock: self._current += self._step next_id = self._current @@ -180,11 +195,6 @@ def manager() -> Generator[int, None, None]: return _AsyncCtxManagerWrapper(manager()) def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: - """ - Usage: - async with stream_id_gen.get_next(n) as stream_ids: - # ... persist events ... - """ with self._lock: next_ids = range( self._current + self._step, @@ -208,12 +218,6 @@ def manager() -> Generator[Sequence[int], None, None]: return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. - - Returns: - The maximum stream id. - """ with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step @@ -221,11 +225,6 @@ def get_current_token(self) -> int: return self._current def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - - For streams with single writers this is equivalent to - `get_current_token`. - """ return self.get_current_token() @@ -475,12 +474,6 @@ def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]: return stream_ids def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ - # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -492,12 +485,6 @@ def get_next(self) -> AsyncContextManager[int]: return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: - """ - Usage: - async with stream_id_gen.get_next_mult(5) as stream_ids: - # ... persist events ... - """ - # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -597,15 +584,9 @@ def _mark_id_as_finished(self, next_id: int) -> None: self._add_persisted_position(next_id) def get_current_token(self) -> int: - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. - """ - return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer.""" - # If we don't have an entry for the given instance name, we assume it's a # new writer. # From 2ce1c29e6d03affcef1b11ac8da90b07a83a7160 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 21:23:36 +0000 Subject: [PATCH 02/14] Create `AbstractStreamIdTracker` base class And add a dummy `advance` method to `StreamIdGenerator`. Note that this is no worse than what would currently happen if `advance` were to be called on `StreamIdGenerator`. --- .../slave/storage/_slaved_id_tracker.py | 14 +----- synapse/storage/util/id_generators.py | 46 +++++++++++-------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 8c1bf9227ac6..6191bcede3eb 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -14,10 +14,10 @@ from typing import List, Optional, Tuple from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import _load_current_id +from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id -class SlavedIdTracker: +class SlavedIdTracker(AbstractStreamIdTracker): def __init__( self, db_conn: LoggingDatabaseConnection, @@ -36,17 +36,7 @@ def advance(self, instance_name: Optional[str], new_id: int): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self) -> int: - """ - - Returns: - int - """ return self._current def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - - For streams with single writers this is equivalent to - `get_current_token`. - """ return self.get_current_token() diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 011de4a41fb7..98a6d7120e7b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -89,22 +89,13 @@ def _load_current_id( return (max if step > 0 else min)(current_id, step) -class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): - @abc.abstractmethod - def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ - raise NotImplementedError() +class AbstractStreamIdTracker(metaclass=abc.ABCMeta): + """Tracks a stream that can have multiple writers.""" @abc.abstractmethod - def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: - """ - Usage: - async with stream_id_gen.get_next(n) as stream_ids: - # ... persist events ... + def advance(self, instance_name: str, new_id: int) -> None: + """Advance the position of the named writer to the given ID, if greater + than existing entry. """ raise NotImplementedError() @@ -127,6 +118,26 @@ def get_current_token_for_writer(self, instance_name: str) -> int: raise NotImplementedError() +class AbstractStreamIdGenerator(AbstractStreamIdTracker): + @abc.abstractmethod + def get_next(self) -> AsyncContextManager[int]: + """ + Usage: + async with stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: + """ + Usage: + async with stream_id_gen.get_next(n) as stream_ids: + # ... persist events ... + """ + raise NotImplementedError() + + class StreamIdGenerator(AbstractStreamIdGenerator): """Used to generate new stream ids when persisting events while keeping track of which transactions have been completed. @@ -177,6 +188,9 @@ def __init__( # The key and values are the same, but we never look at the values. self._unfinished_ids: OrderedDict[int, int] = OrderedDict() + def advance(self, instance_name: str, new_id: int) -> None: + raise Exception("Replication is not supported by StreamIdGenerator") + def get_next(self) -> AsyncContextManager[int]: with self._lock: self._current += self._step @@ -612,10 +626,6 @@ def get_positions(self) -> Dict[str, int]: } def advance(self, instance_name: str, new_id: int) -> None: - """Advance the position of the named writer to the given ID, if greater - than existing entry. - """ - new_id *= self._return_factor with self._lock: From 6122b3073a8dc21a5358a2b3391c753a5e9b25ab Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 21:43:47 +0000 Subject: [PATCH 03/14] Make `_push_rules_stream_id_gen` an `AbstractStreamIdTracker` --- synapse/replication/slave/storage/push_rule.py | 4 ---- synapse/storage/databases/main/push_rule.py | 11 +++++++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 4d5f86286213..7541e21de9dd 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -25,9 +24,6 @@ def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - # We assert this for the benefit of mypy - assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) - if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index fa782023d4ee..3b63267395c3 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -28,7 +28,10 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -82,9 +85,9 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: Union[ - StreamIdGenerator, SlavedIdTracker - ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id") + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, "push_rules_stream", "stream_id" + ) else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" From a1c5c4852034fbc299552f5b54b497ad54170b45 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 21:50:50 +0000 Subject: [PATCH 04/14] Remove duplicate `_EventCacheEntry` definition --- synapse/storage/databases/main/events.py | 12 ++++++------ synapse/storage/databases/main/events_worker.py | 16 ++++++++-------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 06832221adc2..d5cc446216e6 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -15,7 +15,7 @@ # limitations under the License. import itertools import logging -from collections import OrderedDict, namedtuple +from collections import OrderedDict from typing import ( TYPE_CHECKING, Any, @@ -41,6 +41,7 @@ from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -64,9 +65,6 @@ ) -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) - - @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -1553,11 +1551,13 @@ def _add_to_cache(self, txn, events_and_contexts): for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(_EventCacheEntry(event=event, redacted_event=None)) + to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) + self.store._get_event_cache.set( + (cache_entry.event.event_id,), cache_entry + ) txn.call_after(prefill) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c6bf316d5bf3..668b0d8a9c89 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -89,7 +89,7 @@ @attr.s(slots=True, auto_attribs=True) -class _EventCacheEntry: +class EventCacheEntry: event: EventBase redacted_event: Optional[EventBase] @@ -223,7 +223,7 @@ def __init__(self, database: DatabasePool, db_conn, hs): # ID to cache entry. Note that the returned dict may not have the # requested event in it if the event isn't in the DB. self._current_event_fetches: Dict[ - str, ObservableDeferred[Dict[str, _EventCacheEntry]] + str, ObservableDeferred[Dict[str, EventCacheEntry]] ] = {} self._event_fetch_lock = threading.Condition() @@ -544,7 +544,7 @@ async def get_events_as_list( async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False - ) -> Dict[str, _EventCacheEntry]: + ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -578,7 +578,7 @@ async def _get_events_from_cache_or_db( # same dict into itself N times). already_fetching_ids: Set[str] = set() already_fetching_deferreds: Set[ - ObservableDeferred[Dict[str, _EventCacheEntry]] + ObservableDeferred[Dict[str, EventCacheEntry]] ] = set() for event_id in missing_events_ids: @@ -601,7 +601,7 @@ async def _get_events_from_cache_or_db( # function returning more events than requested, but that can happen # already due to `_get_events_from_db`). fetching_deferred: ObservableDeferred[ - Dict[str, _EventCacheEntry] + Dict[str, EventCacheEntry] ] = ObservableDeferred(defer.Deferred()) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -663,7 +663,7 @@ def _invalidate_get_event_cache(self, event_id): def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True - ) -> Dict[str, _EventCacheEntry]: + ) -> Dict[str, EventCacheEntry]: """Fetch events from the caches. May return rejected events. @@ -815,7 +815,7 @@ def fire(evs, exc): async def _get_events_from_db( self, event_ids: Iterable[str] - ) -> Dict[str, _EventCacheEntry]: + ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the database. May return rejected events. @@ -958,7 +958,7 @@ async def _get_events_from_db( original_ev, redactions, event_map ) - cache_entry = _EventCacheEntry( + cache_entry = EventCacheEntry( event=original_ev, redacted_event=redacted_event ) From c463c9df71a959c149c9027d3fcb48d096178d20 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 21:55:31 +0000 Subject: [PATCH 05/14] Fix type hints for `SQLBaseStore.process_replication_rows` --- synapse/storage/_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0623da9aa196..3056e64ff570 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -21,7 +21,7 @@ from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.storage.types import Connection -from synapse.types import StreamToken, get_domain_from_id +from synapse.types import get_domain_from_id from synapse.util import json_decoder if TYPE_CHECKING: @@ -48,7 +48,7 @@ def process_replication_rows( self, stream_name: str, instance_name: str, - token: StreamToken, + token: int, rows: Iterable[Any], ) -> None: pass From ba8af6749d5c0e175d206c8bd32db0611aec0813 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 22:11:08 +0000 Subject: [PATCH 06/14] Fix tpye hint for `_EventRow.room_version_id` `room_version_id` comes from the `rooms` table, where it is a string instead of an int. --- synapse/storage/databases/main/events_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 668b0d8a9c89..6110b7dd8ca2 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -129,7 +129,7 @@ class _EventRow: json: str internal_metadata: str format_version: Optional[int] - room_version_id: Optional[int] + room_version_id: Optional[str] rejected_reason: Optional[str] redactions: List[str] outlier: bool From 6cf127bf533271aa772fef51a9ed2adba3c4366a Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 22:19:09 +0000 Subject: [PATCH 07/14] `_enqueue_events` needs a `Collection` since it uses `len()` --- synapse/storage/databases/main/events_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6110b7dd8ca2..f2d661197c54 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -814,7 +814,7 @@ def fire(evs, exc): self.hs.get_reactor().callFromThread(fire, event_list, e) async def _get_events_from_db( - self, event_ids: Iterable[str] + self, event_ids: Collection[str] ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the database. @@ -967,7 +967,7 @@ async def _get_events_from_db( return result_map - async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: + async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. From de0585a322934097f30405d6e3e581f12bf24e17 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 22:22:15 +0000 Subject: [PATCH 08/14] `get_events_as_list` needs a `Collection` --- synapse/state/__init__.py | 2 +- synapse/state/v1.py | 3 ++- synapse/storage/databases/main/events_worker.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1605411b0087..446204dbe52f 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -764,7 +764,7 @@ class StateResolutionStore: store: "DataStore" def get_events( - self, event_ids: Iterable[str], allow_rejected: bool = False + self, event_ids: Collection[str], allow_rejected: bool = False ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 6edadea550d2..499a32820185 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -17,6 +17,7 @@ from typing import ( Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -44,7 +45,7 @@ async def resolve_events_with_store( room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], + state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]], ) -> StateMap[str]: """ Args: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f2d661197c54..38a3e59d5778 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -357,7 +357,7 @@ async def get_event( async def get_events( self, - event_ids: Iterable[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, From 5de79f107c0f1ca5863b2e009ece4d110dd0dde6 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 22 Nov 2021 22:05:11 +0000 Subject: [PATCH 09/14] Add type hints to `synapse/storage/databases/main/events_worker.py` --- mypy.ini | 4 +- synapse/replication/tcp/streams/events.py | 8 +- synapse/storage/databases/main/events.py | 17 +- .../storage/databases/main/events_worker.py | 174 ++++++++++++------ .../test_sharded_event_persister.py | 6 +- 5 files changed, 142 insertions(+), 67 deletions(-) diff --git a/mypy.ini b/mypy.ini index 308cfd95d847..0a236b6ab4c1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -33,7 +33,6 @@ exclude = (?x) |synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_push_actions.py |synapse/storage/databases/main/events_bg_updates.py - |synapse/storage/databases/main/events_worker.py |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py @@ -181,6 +180,9 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.directory] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.events_worker] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.room_batch] disallow_untyped_defs = True diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index a030e9299edc..feb5cb12b516 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -157,7 +157,9 @@ async def _update_function( # now we fetch up to that many rows from the events table - event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows( + event_rows: List[ + Tuple[int, str, str, str, str, str, str, str, str] + ] = await self._store.get_all_new_forward_event_rows( instance_name, from_token, current_token, target_row_count ) @@ -191,7 +193,9 @@ async def _update_function( # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. - ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows( + ex_outliers_rows: List[ + Tuple[int, str, str, str, str, str, str, str, str] + ] = await self._store.get_ex_outlier_stream_rows( instance_name, from_token, upper_limit ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d5cc446216e6..c3440de2cbcb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -44,7 +44,7 @@ from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder @@ -106,16 +106,21 @@ def __init__( self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id - # Ideally we'd move these ID gens here, unfortunately some other ID - # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen - self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen - # This should only exist on instances that are configured to write assert ( hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" + # Since we have been configured to write, we ought to have id generators, + # rather than id trackers. + assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator) + assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator) + + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen + async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 38a3e59d5778..d6628ce2c840 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -15,14 +15,18 @@ import logging import threading from typing import ( + TYPE_CHECKING, + Any, Collection, Container, Dict, Iterable, List, + NoReturn, Optional, Set, Tuple, + cast, overload, ) @@ -38,6 +42,7 @@ from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, + RoomVersion, RoomVersions, ) from synapse.events import EventBase, make_event_from_dict @@ -56,10 +61,18 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import ( + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError @@ -69,6 +82,9 @@ from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore): # options controlling this. USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) + self._stream_id_gen: AbstractStreamIdTracker + self._backfill_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` # regardless of whether this process writes to the streams or not. @@ -214,7 +237,7 @@ def __init__(self, database: DatabasePool, db_conn, hs): 5 * 60 * 1000, ) - self._get_event_cache = LruCache( + self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -227,15 +250,17 @@ def __init__(self, database: DatabasePool, db_conn, hs): ] = {} self._event_fetch_lock = threading.Condition() - self._event_fetch_list = [] + self._event_fetch_list: List[ + Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] + ] = [] self._event_fetch_ongoing = 0 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) # We define this sequence here so that it can be referenced from both # the DataStore and PersistEventStore. - def get_chain_id_txn(txn): + def get_chain_id_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return txn.fetchone()[0] + return cast(Tuple[int], txn.fetchone())[0] self.event_chain_id_gen = build_sequence_generator( db_conn, @@ -246,7 +271,13 @@ def get_chain_id_txn(txn): id_column="chain_id", ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: @@ -294,7 +325,7 @@ async def get_event( redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - allow_none: Literal[True] = False, + allow_none: Literal[True] = True, check_room_id: Optional[str] = None, ) -> Optional[EventBase]: ... @@ -658,7 +689,7 @@ async def _get_events_from_cache_or_db( return event_entry_map - def _invalidate_get_event_cache(self, event_id): + def _invalidate_get_event_cache(self, event_id: str) -> None: self._get_event_cache.invalidate((event_id,)) def _get_events_from_cache( @@ -736,7 +767,7 @@ async def get_stripped_room_state_from_event_context( for e in state_to_include.values() ] - def _do_fetch(self, conn: Connection) -> None: + def _do_fetch(self, conn: LoggingDatabaseConnection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ @@ -767,7 +798,9 @@ def _do_fetch(self, conn: Connection) -> None: event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) def _fetch_event_list( - self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] + self, + conn: LoggingDatabaseConnection, + event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]], ) -> None: """Handle a load of requests from the _event_fetch_list queue @@ -794,7 +827,7 @@ def _fetch_event_list( ) # We only want to resolve deferreds from the main thread - def fire(): + def fire() -> None: for _, d in event_list: d.callback(row_dict) @@ -804,14 +837,14 @@ def fire(): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire(evs, exc): - for _, d in evs: + def fire_errback(exc: Exception) -> None: + for _, d in event_list: if not d.called: with PreserveLoggingContext(): d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list, e) + self.hs.get_reactor().callFromThread(fire_errback, e) async def _get_events_from_db( self, event_ids: Collection[str] @@ -831,29 +864,29 @@ async def _get_events_from_db( map from event id to result. May return extra events which weren't asked for. """ - fetched_events = {} + fetched_event_ids: Set[str] = set() + fetched_events: Dict[str, _EventRow] = {} events_to_fetch = event_ids while events_to_fetch: row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events - redaction_ids = set() + redaction_ids: Set[str] = set() for event_id in events_to_fetch: row = row_map.get(event_id) - fetched_events[event_id] = row + fetched_event_ids.add(event_id) if row: + fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_events.keys()) + events_to_fetch = redaction_ids.difference(fetched_event_ids) if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) # build a map from event_id to EventBase - event_map = {} + event_map: Dict[str, EventBase] = {} for event_id, row in fetched_events.items(): - if not row: - continue assert row.event_id == event_id rejected_reason = row.rejected_reason @@ -881,6 +914,7 @@ async def _get_events_from_db( room_version_id = row.room_version_id + room_version: Optional[RoomVersion] if not room_version_id: # this should only happen for out-of-band membership events which # arrived before #6983 landed. For all other events, we should have @@ -951,7 +985,7 @@ async def _get_events_from_db( # finally, we can decide whether each one needs redacting, and build # the cache entries. - result_map = {} + result_map: Dict[str, EventCacheEntry] = {} for event_id, original_ev in event_map.items(): redactions = fetched_events[event_id].redactions redacted_event = self._maybe_redact_event_row( @@ -980,7 +1014,7 @@ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow] that weren't requested. """ - events_d = defer.Deferred() + events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred() with self._event_fetch_lock: self._event_fetch_list.append((events, events_d)) @@ -1146,7 +1180,7 @@ def _maybe_redact_event_row( # no valid redaction found for this event return None - async def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]: """Given a list of event ids, check if we have already processed and stored them as non outliers. """ @@ -1175,7 +1209,7 @@ async def have_seen_events( event_ids: events we are looking for Returns: - set[str]: The events we have already seen. + The set of events we have already seen. """ res = await self._have_seen_events_dict( (room_id, event_id) for event_id in event_ids @@ -1198,7 +1232,9 @@ async def _have_seen_events_dict( } results = {x: True for x in cache_results} - def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): + def have_seen_events_txn( + txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] + ) -> None: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1224,12 +1260,14 @@ def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): return results @cached(max_entries=100000, tree=True) - async def have_seen_event(self, room_id: str, event_id: str): + async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: # this only exists for the benefit of the @cachedList descriptor on # _have_seen_events_dict raise NotImplementedError() - def _get_current_state_event_counts_txn(self, txn, room_id): + def _get_current_state_event_counts_txn( + self, txn: LoggingTransaction, room_id: str + ) -> int: """ See get_current_state_event_counts. """ @@ -1254,7 +1292,7 @@ async def get_current_state_event_counts(self, room_id: str) -> int: room_id, ) - async def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id: str) -> Dict[str, float]: """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -1262,10 +1300,10 @@ async def get_room_complexity(self, room_id): more resources. Args: - room_id (str) + room_id: The room ID to query. Returns: - dict[str:int] of complexity version to complexity. + dict[str:float] of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) @@ -1275,13 +1313,13 @@ async def get_room_complexity(self, room_id): return {"v1": complexity_v1} - def get_current_events_token(self): + def get_current_events_token(self) -> int: """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: """Returns new events, for the Events replication stream Args: @@ -1295,7 +1333,9 @@ async def get_all_new_forward_event_rows( EventsStreamRow. """ - def get_all_new_forward_event_rows(txn): + def get_all_new_forward_event_rows( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" @@ -1311,7 +1351,9 @@ def get_all_new_forward_event_rows(txn): " LIMIT ?" ) txn.execute(sql, (last_id, current_id, instance_name, limit)) - return txn.fetchall() + return cast( + List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + ) return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows @@ -1319,7 +1361,7 @@ def get_all_new_forward_event_rows(txn): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: """Returns de-outliered events, for the Events replication stream Args: @@ -1332,7 +1374,9 @@ async def get_ex_outlier_stream_rows( EventsStreamRow. """ - def get_ex_outlier_stream_rows_txn(txn): + def get_ex_outlier_stream_rows_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" @@ -1350,7 +1394,9 @@ def get_ex_outlier_stream_rows_txn(txn): ) txn.execute(sql, (last_id, current_id, instance_name)) - return txn.fetchall() + return cast( + List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + ) return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn @@ -1358,7 +1404,7 @@ def get_ex_outlier_stream_rows_txn(txn): async def get_all_new_backfill_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. @@ -1386,7 +1432,9 @@ async def get_all_new_backfill_event_rows( if last_id == current_id: return [], current_id, False - def get_all_new_backfill_event_rows(txn): + def get_all_new_backfill_event_rows( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id" @@ -1400,7 +1448,10 @@ def get_all_new_backfill_event_rows(txn): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates = [(row[0], row[1:]) for row in txn] + new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = [ + cast(Tuple[int, Tuple[str, str, str, str, str, str]], (row[0], row[1:])) + for row in txn + ] limited = False if len(new_event_updates) == limit: @@ -1423,7 +1474,10 @@ def get_all_new_backfill_event_rows(txn): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound, instance_name)) - new_event_updates.extend((row[0], row[1:]) for row in txn) + new_event_updates.extend( + cast(Tuple[int, Tuple[str, str, str, str, str, str]], (row[0], row[1:])) + for row in txn + ) if len(new_event_updates) >= limit: upper_bound = new_event_updates[-1][0] @@ -1437,7 +1491,7 @@ def get_all_new_backfill_event_rows(txn): async def get_all_updated_current_state_deltas( self, instance_name: str, from_token: int, to_token: int, target_row_count: int - ) -> Tuple[List[Tuple], int, bool]: + ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]: """Fetch updates from current_state_delta_stream Args: @@ -1457,7 +1511,9 @@ async def get_all_updated_current_state_deltas( * `limited` is whether there are more updates to fetch. """ - def get_all_updated_current_state_deltas_txn(txn): + def get_all_updated_current_state_deltas_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str]]: sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream @@ -1466,21 +1522,23 @@ def get_all_updated_current_state_deltas_txn(txn): ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) - def get_deltas_for_stream_id_txn(txn, stream_id): + def get_deltas_for_stream_id_txn( + txn: LoggingTransaction, stream_id: int + ) -> List[Tuple[int, str, str, str, str]]: sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE stream_id = ? """ txn.execute(sql, [stream_id]) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows: List[Tuple] = await self.db_pool.runInteraction( + rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1509,14 +1567,14 @@ def get_deltas_for_stream_id_txn(txn, stream_id): return rows, to_token, True - async def is_event_after(self, event_id1, event_id2): + async def is_event_after(self, event_id1: str, event_id2: str) -> bool: """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cached(max_entries=5000) - async def get_event_ordering(self, event_id): + async def get_event_ordering(self, event_id: str) -> Tuple[int, int]: res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], @@ -1539,7 +1597,9 @@ async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: None otherwise. """ - def get_next_event_to_expire_txn(txn): + def get_next_event_to_expire_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, int]]: txn.execute( """ SELECT event_id, expiry_ts FROM event_expiry @@ -1547,7 +1607,7 @@ def get_next_event_to_expire_txn(txn): """ ) - return txn.fetchone() + return cast(Optional[Tuple[str, int]], txn.fetchone()) return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn @@ -1611,10 +1671,10 @@ async def get_already_persisted_events( return mapping @wrap_as_background_process("_cleanup_old_transaction_ids") - async def _cleanup_old_transaction_ids(self): + async def _cleanup_old_transaction_ids(self) -> None: """Cleans out transaction id mappings older than 24hrs.""" - def _cleanup_old_transaction_ids_txn(txn): + def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: sql = """ DELETE FROM event_txn_id WHERE inserted_ts < ? diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 0a6e4795ee92..596ba5a0c9f4 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -17,6 +17,7 @@ from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request @@ -193,7 +194,10 @@ def test_vector_clock_token(self): # # Worker2's event stream position will not advance until we call # __aexit__ again. - actx = worker_hs2.get_datastore()._stream_id_gen.get_next() + worker_store2 = worker_hs2.get_datastore() + assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) + + actx = worker_store2._stream_id_gen.get_next() self.get_success(actx.__aenter__()) response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token) From 9bac7985073545a2f7b91d6fe289e3993379d6d9 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 23 Nov 2021 11:33:50 +0000 Subject: [PATCH 10/14] Add newsfile --- changelog.d/11411.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/11411.misc diff --git a/changelog.d/11411.misc b/changelog.d/11411.misc new file mode 100644 index 000000000000..86594a332db9 --- /dev/null +++ b/changelog.d/11411.misc @@ -0,0 +1 @@ +Add type hints to storage classes. From b40e5c4394de06730fe194162d91b2f0b4f503bd Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 24 Nov 2021 17:41:03 +0000 Subject: [PATCH 11/14] Address PR feedback --- synapse/replication/tcp/streams/events.py | 10 ++--- .../storage/databases/main/events_worker.py | 38 +++++++++++-------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index feb5cb12b516..a390cfcb74d5 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq from collections.abc import Iterable -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Tuple, Type import attr @@ -157,9 +157,7 @@ async def _update_function( # now we fetch up to that many rows from the events table - event_rows: List[ - Tuple[int, str, str, str, str, str, str, str, str] - ] = await self._store.get_all_new_forward_event_rows( + event_rows = await self._store.get_all_new_forward_event_rows( instance_name, from_token, current_token, target_row_count ) @@ -193,9 +191,7 @@ async def _update_function( # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. - ex_outliers_rows: List[ - Tuple[int, str, str, str, str, str, str, str, str] - ] = await self._store.get_ex_outlier_stream_rows( + ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( instance_name, from_token, upper_limit ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d6628ce2c840..b64bd4efaa9b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -311,10 +311,10 @@ async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = False, - allow_rejected: bool = False, - allow_none: Literal[False] = False, - check_room_id: Optional[str] = None, + get_prev_content: bool = ..., + allow_rejected: bool = ..., + allow_none: Literal[False] = ..., + check_room_id: Optional[str] = ..., ) -> EventBase: ... @@ -323,10 +323,10 @@ async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = False, - allow_rejected: bool = False, - allow_none: Literal[True] = True, - check_room_id: Optional[str] = None, + get_prev_content: bool = ..., + allow_rejected: bool = ..., + allow_none: Literal[True] = ..., + check_room_id: Optional[str] = ..., ) -> Optional[EventBase]: ... @@ -1448,10 +1448,15 @@ def get_all_new_backfill_event_rows( " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = [ - cast(Tuple[int, Tuple[str, str, str, str, str, str]], (row[0], row[1:])) - for row in txn - ] + new_event_updates: List[ + Tuple[int, Tuple[str, str, str, str, str, str]] + ] = [] + row: Tuple[int, str, str, str, str, str, str] + # Type safety: iterating over `txn` yields `Tuple`, i.e. + # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a + # variadic tuple to a fixed length tuple and flags it up as an error. + for row in txn: # type: ignore[assignment] + new_event_updates.append((row[0], row[1:])) limited = False if len(new_event_updates) == limit: @@ -1474,10 +1479,11 @@ def get_all_new_backfill_event_rows( " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound, instance_name)) - new_event_updates.extend( - cast(Tuple[int, Tuple[str, str, str, str, str, str]], (row[0], row[1:])) - for row in txn - ) + # Type safety: iterating over `txn` yields `Tuple`, i.e. + # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a + # variadic tuple to a fixed length tuple and flags it up as an error. + for row in txn: # type: ignore[assignment] + new_event_updates.append((row[0], row[1:])) if len(new_event_updates) >= limit: upper_bound = new_event_updates[-1][0] From 50742a0a7da783ea2e0933bd9fc432e6267a932a Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Thu, 25 Nov 2021 20:41:55 +0000 Subject: [PATCH 12/14] Improve IdGenerator docstrings --- .../slave/storage/_slaved_id_tracker.py | 8 +++++ synapse/storage/util/id_generators.py | 31 ++++++++++++++----- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 6191bcede3eb..fa132d10b414 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -18,6 +18,14 @@ class SlavedIdTracker(AbstractStreamIdTracker): + """Tracks the "current" stream ID of a stream with a single writer. + + See `AbstractStreamIdTracker` for more details. + + Note that this class does not work correctly when there are multiple + writers. + """ + def __init__( self, db_conn: LoggingDatabaseConnection, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 98a6d7120e7b..69d1005f38a9 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -90,7 +90,16 @@ def _load_current_id( class AbstractStreamIdTracker(metaclass=abc.ABCMeta): - """Tracks a stream that can have multiple writers.""" + """Tracks the "current" stream ID of a stream that may have multiple writers. + + Stream IDs are monotonically increasing or decreasing integers representing write + transactions. The "current" stream ID is the stream ID such that all transactions + with equal or smaller stream IDs have completed. Since transactions may complete out + of order, this is not the same the stream ID of the last completed transaction. + + Completed transactions include both committed transactions and transactions that + have been rolled back. + """ @abc.abstractmethod def advance(self, instance_name: str, new_id: int) -> None: @@ -119,6 +128,14 @@ def get_current_token_for_writer(self, instance_name: str) -> int: class AbstractStreamIdGenerator(AbstractStreamIdTracker): + """Generates stream IDs for a stream that may have multiple writers. + + Each stream ID represents a write transaction, whose completion is tracked + so that the "current" stream ID of the stream can be determined. + + See `AbstractStreamIdTracker` for more details. + """ + @abc.abstractmethod def get_next(self) -> AsyncContextManager[int]: """ @@ -139,12 +156,10 @@ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: class StreamIdGenerator(AbstractStreamIdGenerator): - """Used to generate new stream ids when persisting events while keeping - track of which transactions have been completed. + """Generates and tracks stream IDs for a stream with a single writer. - This allows us to get the "current" stream id, i.e. the stream id such that - all ids less than or equal to it have completed. This handles the fact that - persistence of events can complete out of order. + This class must only be used when the current Synapse process is the sole + writer for a stream. Args: db_conn(connection): A database connection to use to fetch the @@ -189,6 +204,8 @@ def __init__( self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: + # `StreamIdGenerator` should only be used when there is a single writer, + # so replication should never happen. raise Exception("Replication is not supported by StreamIdGenerator") def get_next(self) -> AsyncContextManager[int]: @@ -243,7 +260,7 @@ def get_current_token_for_writer(self, instance_name: str) -> int: class MultiWriterIdGenerator(AbstractStreamIdGenerator): - """An ID generator that tracks a stream that can have multiple writers. + """Generates and tracks stream IDs for a stream with multiple writers. Uses a Postgres sequence to coordinate ID assignment, but positions of other writers will only get updated when `advance` is called (by replication). From ad514deff99356076fff1d5b7c964fff553cb37c Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 26 Nov 2021 14:16:54 +0000 Subject: [PATCH 13/14] Fix typo in docstring --- synapse/storage/util/id_generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 69d1005f38a9..4ff3013908a7 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -95,7 +95,7 @@ class AbstractStreamIdTracker(metaclass=abc.ABCMeta): Stream IDs are monotonically increasing or decreasing integers representing write transactions. The "current" stream ID is the stream ID such that all transactions with equal or smaller stream IDs have completed. Since transactions may complete out - of order, this is not the same the stream ID of the last completed transaction. + of order, this is not the same as the stream ID of the last completed transaction. Completed transactions include both committed transactions and transactions that have been rolled back. From 87fe9c6d4cc78113a2e03f8006fe94433914661b Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 26 Nov 2021 15:21:11 +0000 Subject: [PATCH 14/14] Fix merge --- synapse/storage/databases/main/events_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 28ff9c8b6ba3..4cefc0a07ea3 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -851,7 +851,7 @@ async def _fetch_thread(self) -> None: for _, deferred in event_fetches_to_fail: deferred.errback(exc) - def _fetch_loop(self, conn: Connection) -> None: + def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """