Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Refactor to ensure we call check_consistency (#9470)
Browse files Browse the repository at this point in the history
The idea here is to stop people forgetting to call `check_consistency`. Folks can still just pass in `None` to the new args in `build_sequence_generator`, but hopefully they won't.
  • Loading branch information
erikjohnston authored Feb 24, 2021
1 parent 713145d commit 0b5c967
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 28 deletions.
1 change: 1 addition & 0 deletions changelog.d/9470.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix missing startup checks for the consistency of certain PostgreSQL sequences.
16 changes: 4 additions & 12 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection

# python 3 does not have a maximum int value
Expand Down Expand Up @@ -381,7 +380,10 @@ class DatabasePool:
_TXN_ID = 0

def __init__(
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
self,
hs,
database_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
):
self.hs = hs
self._clock = hs.get_clock()
Expand Down Expand Up @@ -420,16 +422,6 @@ def __init__(
self._check_safe_to_upsert,
)

# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]

self.event_chain_id_gen = build_sequence_generator(
engine, get_chain_id_txn, "event_auth_chain_id"
)

def is_running(self) -> bool:
"""Is the database pool currently running"""
return self._db_pool.running
Expand Down
13 changes: 6 additions & 7 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
Expand Down Expand Up @@ -114,12 +115,6 @@ def __init__(
) # type: MultiWriterIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator

# The consistency of this cannot be checked when the ID generator is
# created since the database might not yet be up-to-date.
self.db_pool.event_chain_id_gen.check_consistency(
db_conn, "event_auth_chains", "chain_id" # type: ignore
)

# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
Expand Down Expand Up @@ -485,6 +480,7 @@ def _persist_event_auth_chain_txn(
self._add_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
Expand All @@ -495,6 +491,7 @@ def _add_chain_cover_index(
cls,
txn,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
Expand Down Expand Up @@ -641,6 +638,7 @@ def _add_chain_cover_index(
new_chain_tuples = cls._allocate_chain_ids(
txn,
db_pool,
event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
Expand Down Expand Up @@ -779,6 +777,7 @@ def _add_chain_cover_index(
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
Expand Down Expand Up @@ -891,7 +890,7 @@ def _allocate_chain_ids(
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]

# Generate new chain IDs for all unallocated chain IDs.
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)

Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/events_bg_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ def _calculate_chain_cover_txn(
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
self.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
Expand Down
16 changes: 16 additions & 0 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
Expand Down Expand Up @@ -156,6 +157,21 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._event_fetch_list = []
self._event_fetch_ongoing = 0

# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventsStore.
def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]

self.event_chain_id_gen = build_sequence_generator(
db_conn,
database.engine,
get_chain_id_txn,
"event_auth_chain_id",
table="event_auth_chains",
id_column="chain_id",
)

def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
Expand Down
19 changes: 16 additions & 3 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
Expand Down Expand Up @@ -70,7 +70,12 @@ def _default_token_owner(self):


class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.config = hs.config
Expand All @@ -79,9 +84,12 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"
# call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
db_conn,
database.engine,
find_max_generated_user_id_localpart,
"user_id_seq",
table=None,
id_column=None,
)

self._account_validity = hs.config.account_validity
Expand Down Expand Up @@ -1036,7 +1044,12 @@ async def update_access_token_last_validated(self, token_id: int) -> None:


class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._clock = hs.get_clock()
Expand Down
10 changes: 6 additions & 4 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,12 @@ def get_max_state_group_txn(txn: Cursor):
return txn.fetchone()[0]

self._state_group_seq_gen = build_sequence_generator(
self.database_engine, get_max_state_group_txn, "state_group_id_seq"
)
self._state_group_seq_gen.check_consistency(
db_conn, table="state_groups", id_column="id"
db_conn,
self.database_engine,
get_max_state_group_txn,
"state_group_id_seq",
table="state_groups",
id_column="id",
)

@cached(max_entries=10000, iterable=True)
Expand Down
24 changes: 22 additions & 2 deletions synapse/storage/util/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,14 @@ def check_consistency(


def build_sequence_generator(
db_conn: "LoggingDatabaseConnection",
database_engine: BaseDatabaseEngine,
get_first_callback: GetFirstCallbackType,
sequence_name: str,
table: Optional[str],
id_column: Optional[str],
stream_name: Optional[str] = None,
positive: bool = True,
) -> SequenceGenerator:
"""Get the best impl of SequenceGenerator available
Expand All @@ -265,8 +270,23 @@ def build_sequence_generator(
get_first_callback: a callback which gets the next sequence ID. Used if
we're on sqlite.
sequence_name: the name of a postgres sequence to use.
table, id_column, stream_name, positive: If `table` is set then
`check_consistency` is called on the created sequence with these
arguments (`id_column` must then also be set). See the docstring of
`check_consistency` for details.
"""
if isinstance(database_engine, PostgresEngine):
return PostgresSequenceGenerator(sequence_name)
seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator
else:
return LocalSequenceGenerator(get_first_callback)
seq = LocalSequenceGenerator(get_first_callback)

if table:
assert id_column
seq.check_consistency(
db_conn=db_conn,
table=table,
id_column=id_column,
stream_name=stream_name,
positive=positive,
)

return seq

0 comments on commit 0b5c967

Please sign in to comment.