diff --git a/changelog.d/15751.misc b/changelog.d/15751.misc new file mode 100644 index 000000000000..e0ecea6c2fdb --- /dev/null +++ b/changelog.d/15751.misc @@ -0,0 +1 @@ +Add foreign key constraint to `event_forward_extremities`. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index a803ada8ad06..e126a2e0c573 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -61,6 +61,7 @@ from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore from synapse.storage.databases.main.event_push_actions import EventPushActionsStore from synapse.storage.databases.main.events_bg_updates import ( EventsBackgroundUpdatesStore, @@ -239,6 +240,7 @@ class Store( PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, RelationsWorkerStore, + EventFederationWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index edc97a9d6105..5dce0a01599d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging -from enum import IntEnum +from enum import Enum, IntEnum from types import TracebackType from typing import ( TYPE_CHECKING, @@ -24,12 +25,16 @@ Iterable, List, Optional, + Sequence, + Tuple, Type, ) import attr +from pydantic import BaseModel from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection, Cursor from synapse.types import JsonDict from synapse.util import Clock, json_encoder @@ -48,6 +53,78 @@ MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] +class Constraint(metaclass=abc.ABCMeta): + """Base class representing different constraints. + + Used by `register_background_validate_constraint_and_delete_rows`. + """ + + @abc.abstractmethod + def make_check_clause(self, table: str) -> str: + """Returns an SQL expression that checks the row passes the constraint.""" + pass + + @abc.abstractmethod + def make_constraint_clause_postgres(self) -> str: + """Returns an SQL clause for creating the constraint. + + Only used on Postgres DBs + """ + pass + + +@attr.s(auto_attribs=True) +class ForeignKeyConstraint(Constraint): + """A foreign key constraint. + + Attributes: + referenced_table: The "parent" table name. + columns: The list of mappings of columns from table to referenced table + """ + + referenced_table: str + columns: Sequence[Tuple[str, str]] + + def make_check_clause(self, table: str) -> str: + join_clause = " AND ".join( + f"{col1} = {table}.{col2}" for col1, col2 in self.columns + ) + return f"EXISTS (SELECT 1 FROM {self.referenced_table} WHERE {join_clause})" + + def make_constraint_clause_postgres(self) -> str: + column1_list = ", ".join(col1 for col1, col2 in self.columns) + column2_list = ", ".join(col2 for col1, col2 in self.columns) + return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list})" + + +@attr.s(auto_attribs=True) +class NotNullConstraint(Constraint): + """A NOT NULL column constraint""" + + column: str + + def make_check_clause(self, table: str) -> str: + return f"{self.column} IS NOT NULL" + + def make_constraint_clause_postgres(self) -> str: + return f"CHECK ({self.column} IS NOT NULL)" + + +class ValidateConstraintProgress(BaseModel): + """The format of the progress JSON for validate constraint background + updates. + + Used by `register_background_validate_constraint_and_delete_rows`. + """ + + class State(str, Enum): + check = "check" + validate = "validate" + + state: State = State.validate + lower_bound: Sequence[Any] = () + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _BackgroundUpdateHandler: """A handler for a given background update. @@ -740,6 +817,179 @@ def create_index_sqlite(conn: Connection) -> None: logger.info("Adding index %s to %s", index_name, table) await self.db_pool.runWithConnection(runner) + def register_background_validate_constraint_and_delete_rows( + self, + update_name: str, + table: str, + constraint_name: str, + constraint: Constraint, + unique_columns: Sequence[str], + ) -> None: + """Helper for store classes to do a background validate constraint, and + delete rows that do not pass the constraint check. + + Note: This deletes rows that don't match the constraint. This may not be + appropriate in all situations, and so the suitability of using this + method should be considered on a case-by-case basis. + + This only applies on PostgreSQL. + + For SQLite the table gets recreated as part of the schema delta and the + data is copied over synchronously (or whatever the correct way to + describe it as). + + Args: + update_name: The name of the background update. + table: The table with the invalid constraint. + constraint_name: The name of the constraint + constraint: A `Constraint` object matching the type of constraint. + unique_columns: A sequence of columns that form a unique constraint + on the table. Used to iterate over the table. + """ + + assert isinstance( + self.db_pool.engine, engines.PostgresEngine + ), "validate constraint background update registered for non-Postres database" + + async def updater(progress: JsonDict, batch_size: int) -> int: + return await self.validate_constraint_and_delete_in_background( + update_name=update_name, + table=table, + constraint_name=constraint_name, + constraint=constraint, + unique_columns=unique_columns, + progress=progress, + batch_size=batch_size, + ) + + self._background_update_handlers[update_name] = _BackgroundUpdateHandler( + updater, oneshot=True + ) + + async def validate_constraint_and_delete_in_background( + self, + update_name: str, + table: str, + constraint_name: str, + constraint: Constraint, + unique_columns: Sequence[str], + progress: JsonDict, + batch_size: int, + ) -> int: + """Validates a table constraint that has been marked as `NOT VALID`, + deleting rows that don't pass the constraint check. + + This will delete rows that do not meet the validation check. + + update_name: str, + table: str, + constraint_name: str, + constraint: Constraint, + unique_columns: Sequence[str], + """ + + # We validate the constraint by: + # 1. Trying to validate the constraint as is. If this succeeds then + # we're done. + # 2. Otherwise, we manually scan the table to remove rows that don't + # match the constraint. + # 3. We try re-validating the constraint. + + parsed_progress = ValidateConstraintProgress.parse_obj(progress) + + if parsed_progress.state == ValidateConstraintProgress.State.check: + return_columns = ", ".join(unique_columns) + order_columns = ", ".join(unique_columns) + + where_clause = "" + args: List[Any] = [] + if parsed_progress.lower_bound: + where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})""" + args.extend(parsed_progress.lower_bound) + + args.append(batch_size) + + sql = f""" + SELECT + {return_columns}, + {constraint.make_check_clause(table)} AS check + FROM {table} + {where_clause} + ORDER BY {order_columns} + LIMIT ? + """ + + def validate_constraint_in_background_check( + txn: "LoggingTransaction", + ) -> None: + txn.execute(sql, args) + rows = txn.fetchall() + + new_progress = parsed_progress.copy() + + if not rows: + new_progress.state = ValidateConstraintProgress.State.validate + self._background_update_progress_txn( + txn, update_name, new_progress.dict() + ) + return + + new_progress.lower_bound = rows[-1][:-1] + + to_delete = [row[:-1] for row in rows if not row[-1]] + + if to_delete: + logger.warning( + "Deleting %d rows that do not pass new constraint", + len(to_delete), + ) + + self.db_pool.simple_delete_many_batch_txn( + txn, table=table, keys=unique_columns, values=to_delete + ) + + self._background_update_progress_txn( + txn, update_name, new_progress.dict() + ) + + await self.db_pool.runInteraction( + "validate_constraint_in_background_check", + validate_constraint_in_background_check, + ) + + return batch_size + + elif parsed_progress.state == ValidateConstraintProgress.State.validate: + sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}" + + def validate_constraint_in_background_validate( + txn: "LoggingTransaction", + ) -> None: + txn.execute(sql) + + try: + await self.db_pool.runInteraction( + "validate_constraint_in_background_validate", + validate_constraint_in_background_validate, + ) + + await self._end_background_update(update_name) + except self.db_pool.engine.module.IntegrityError as e: + # If we get an integrity error here, then we go back and recheck the table. + logger.warning("Integrity error when validating constraint: %s", e) + await self._background_update_progress( + update_name, + ValidateConstraintProgress( + state=ValidateConstraintProgress.State.check + ).dict(), + ) + + return batch_size + else: + raise Exception( + f"Unrecognized state '{parsed_progress.state}' when trying to validate_constraint_and_delete_in_background" + ) + async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. @@ -795,3 +1045,86 @@ def _background_update_progress_txn( keyvalues={"update_name": update_name}, updatevalues={"progress_json": progress_json}, ) + + +def run_validate_constraint_and_delete_rows_schema_delta( + txn: "LoggingTransaction", + ordering: int, + update_name: str, + table: str, + constraint_name: str, + constraint: Constraint, + sqlite_table_name: str, + sqlite_table_schema: str, +) -> None: + """Runs a schema delta to add a constraint to the table. This should be run + in a schema delta file. + + For PostgreSQL the constraint is added and validated in the background. + + For SQLite the table is recreated and data copied across immediately. This + is done by the caller passing in a script to create the new table. Note that + table indexes and triggers are copied over automatically. + + There must be a corresponding call to + `register_background_validate_constraint_and_delete_rows` to register the + background update in one of the data store classes. + + Attributes: + txn ordering, update_name: For adding a row to background_updates table. + table: The table to add constraint to. constraint_name: The name of the + new constraint constraint: A `Constraint` object describing the + constraint sqlite_table_name: For SQLite the name of the empty copy of + table sqlite_table_schema: A SQL script for creating the above table. + """ + + if isinstance(txn.database_engine, PostgresEngine): + # For postgres we can just add the constraint and mark it as NOT VALID, + # and then insert a background update to go and check the validity in + # the background. + txn.execute( + f""" + ALTER TABLE {table} + ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()} + NOT VALID + """ + ) + + txn.execute( + "INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')", + (ordering, update_name), + ) + else: + # For SQLite, we: + # 1. fetch all indexes/triggers/etc related to the table + # 2. create an empty copy of the table + # 3. copy across the rows (that satisfy the check) + # 4. replace the old table with the new able. + # 5. add back all the indexes/triggers/etc + + # Fetch the indexes/triggers/etc. Note that `sql` column being null is + # due to indexes being auto created based on the class definition (e.g. + # PRIMARY KEY), and so don't need to be recreated. + txn.execute( + """ + SELECT sql FROM sqlite_master + WHERE tbl_name = ? AND type != 'table' AND sql IS NOT NULL + """, + (table,), + ) + extras = [row[0] for row in txn] + + txn.execute(sqlite_table_schema) + + sql = f""" + INSERT INTO {sqlite_table_name} SELECT * FROM {table} + WHERE {constraint.make_check_clause(table)} + """ + + txn.execute(sql) + + txn.execute(f"DROP TABLE {table}") + txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}") + + for extra in extras: + txn.execute(extra) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 7e49ae11bc0e..a1c8fb0f46a4 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2313,6 +2313,43 @@ def simple_delete_many_txn( return txn.rowcount + @staticmethod + def simple_delete_many_batch_txn( + txn: LoggingTransaction, + table: str, + keys: Collection[str], + values: Iterable[Iterable[Any]], + ) -> None: + """Executes a DELETE query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + txn: The transaction to use. + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + """ + + if isinstance(txn.database_engine, PostgresEngine): + # We use `execute_values` as it can be a lot faster than `execute_batch`, + # but it's only available on postgres. + sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % ( + table, + ", ".join(k for k in keys), + ) + + txn.execute_values(sql, values, fetch=False) + else: + sql = "DELETE FROM %s WHERE (%s) = (%s)" % ( + table, + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), + ) + + txn.execute_batch(sql, values) + def get_cache_dict( self, db_conn: LoggingDatabaseConnection, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 8b6e3c1dc734..dabe603c8cba 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -38,6 +38,7 @@ from synapse.logging.opentracing import tag_args, trace from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.background_updates import ForeignKeyConstraint from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -140,6 +141,15 @@ def __init__( self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) + if isinstance(self.database_engine, PostgresEngine): + self.db_pool.updates.register_background_validate_constraint_and_delete_rows( + update_name="event_forward_extremities_event_id_foreign_key_constraint_update", + table="event_forward_extremities", + constraint_name="event_forward_extremities_event_id", + constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]), + unique_columns=("event_id", "room_id"), + ) + async def get_auth_chain( self, room_id: str, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5c9db7554ef4..2b83a694269f 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -415,12 +415,6 @@ def _persist_events_txn( backfilled=False, ) - self._update_forward_extremities_txn( - txn, - new_forward_extremities=new_forward_extremities, - max_stream_order=max_stream_order, - ) - # Ensure that we don't have the same event twice. events_and_contexts = self._filter_events_and_contexts_for_duplicates( events_and_contexts @@ -439,6 +433,12 @@ def _persist_events_txn( self._store_event_txn(txn, events_and_contexts=events_and_contexts) + self._update_forward_extremities_txn( + txn, + new_forward_extremities=new_forward_extremities, + max_stream_order=max_stream_order, + ) + self._persist_transaction_ids_txn(txn, events_and_contexts) # Insert into event_to_state_groups. diff --git a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py new file mode 100644 index 000000000000..f12e2a8f3ee9 --- /dev/null +++ b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py @@ -0,0 +1,51 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds foreign key constraint to `event_forward_extremities` table. +""" +from synapse.storage.background_updates import ( + ForeignKeyConstraint, + run_validate_constraint_and_delete_rows_schema_delta, +) +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine + +FORWARD_EXTREMITIES_TABLE_SCHEMA = """ + CREATE TABLE event_forward_extremities2( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, room_id), + CONSTRAINT event_forward_extremities_event_id FOREIGN KEY (event_id) REFERENCES events (event_id) + ) +""" + + +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: + run_validate_constraint_and_delete_rows_schema_delta( + cur, + ordering=7803, + update_name="event_forward_extremities_event_id_foreign_key_constraint_update", + table="event_forward_extremities", + constraint_name="event_forward_extremities_event_id", + constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]), + sqlite_table_name="event_forward_extremities2", + sqlite_table_schema=FORWARD_EXTREMITIES_TABLE_SCHEMA, + ) + + # We can't add a similar constraint to `event_backward_extremities` as the + # events in there don't exist in the `events` table and `event_edges` + # doesn't have a unique constraint on `prev_event_id` (so we can't make a + # foreign key point to it). diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index fd619b64d4dd..6ca546f3f76a 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -20,7 +20,14 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.background_updates import BackgroundUpdater +from synapse.storage.background_updates import ( + BackgroundUpdater, + ForeignKeyConstraint, + NotNullConstraint, + run_validate_constraint_and_delete_rows_schema_delta, +) +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict from synapse.util import Clock @@ -404,3 +411,221 @@ def test_controller(self) -> None: self.pump() self._update_ctx_manager.__aexit__.assert_called() self.get_success(do_update_d) + + +class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): + """Tests the validate contraint and delete background handlers.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates + # the base test class should have run the real bg updates for us + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) + + self.store = self.hs.get_datastores().main + + def test_not_null_constraint(self) -> None: + # Create the initial tables, where we have some invalid data. + """Tests adding a not null constraint.""" + table_sql = """ + CREATE TABLE test_constraint( + a INT PRIMARY KEY, + b INT + ); + """ + self.get_success( + self.store.db_pool.execute( + "test_not_null_constraint", lambda _: None, table_sql + ) + ) + + # We add an index so that we can check that its correctly recreated when + # using SQLite. + index_sql = "CREATE INDEX test_index ON test_constraint(a)" + self.get_success( + self.store.db_pool.execute( + "test_not_null_constraint", lambda _: None, index_sql + ) + ) + + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1}) + ) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}) + ) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3}) + ) + + # Now lets do the migration + + table2_sqlite = """ + CREATE TABLE test_constraint2( + a INT PRIMARY KEY, + b INT, + CONSTRAINT test_constraint_name CHECK (b is NOT NULL) + ); + """ + + def delta(txn: LoggingTransaction) -> None: + run_validate_constraint_and_delete_rows_schema_delta( + txn, + ordering=1000, + update_name="test_bg_update", + table="test_constraint", + constraint_name="test_constraint_name", + constraint=NotNullConstraint("b"), + sqlite_table_name="test_constraint2", + sqlite_table_schema=table2_sqlite, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "test_not_null_constraint", + delta, + ) + ) + + if isinstance(self.store.database_engine, PostgresEngine): + # Postgres uses a background update + self.updates.register_background_validate_constraint_and_delete_rows( + "test_bg_update", + table="test_constraint", + constraint_name="test_constraint_name", + constraint=NotNullConstraint("b"), + unique_columns=["a"], + ) + + # Tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() + + # Check the correct values are in the new table. + rows = self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ) + + self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + + # And check that invalid rows get correctly rejected. + self.get_failure( + self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}), + exc=self.store.database_engine.module.IntegrityError, + ) + + # Check the index is still there for SQLite. + if isinstance(self.store.database_engine, Sqlite3Engine): + # Ensure the index exists in the schema. + self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="sqlite_master", + keyvalues={"tbl_name": "test_constraint"}, + retcol="name", + ) + ) + + def test_foreign_constraint(self) -> None: + """Tests adding a not foreign key constraint.""" + + # Create the initial tables, where we have some invalid data. + base_sql = """ + CREATE TABLE base_table( + b INT PRIMARY KEY + ); + """ + + table_sql = """ + CREATE TABLE test_constraint( + a INT PRIMARY KEY, + b INT NOT NULL + ); + """ + self.get_success( + self.store.db_pool.execute( + "test_foreign_key_constraint", lambda _: None, base_sql + ) + ) + self.get_success( + self.store.db_pool.execute( + "test_foreign_key_constraint", lambda _: None, table_sql + ) + ) + + self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 1})) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1}) + ) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}) + ) + self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 3})) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3}) + ) + + table2_sqlite = """ + CREATE TABLE test_constraint2( + a INT PRIMARY KEY, + b INT NOT NULL, + CONSTRAINT test_constraint_name FOREIGN KEY (b) REFERENCES base_table (b) + ); + """ + + def delta(txn: LoggingTransaction) -> None: + run_validate_constraint_and_delete_rows_schema_delta( + txn, + ordering=1000, + update_name="test_bg_update", + table="test_constraint", + constraint_name="test_constraint_name", + constraint=ForeignKeyConstraint("base_table", [("b", "b")]), + sqlite_table_name="test_constraint2", + sqlite_table_schema=table2_sqlite, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "test_foreign_key_constraint", + delta, + ) + ) + + if isinstance(self.store.database_engine, PostgresEngine): + # Postgres uses a background update + self.updates.register_background_validate_constraint_and_delete_rows( + "test_bg_update", + table="test_constraint", + constraint_name="test_constraint_name", + constraint=ForeignKeyConstraint("base_table", [("b", "b")]), + unique_columns=["a"], + ) + + # Tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() + + # Check the correct values are in the new table. + rows = self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ) + self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + + # And check that invalid rows get correctly rejected. + self.get_failure( + self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}), + exc=self.store.database_engine.module.IntegrityError, + ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 0f3b0744f184..9c151a5e62d1 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, @@ -98,8 +99,32 @@ def test_get_rooms_with_many_extremities(self) -> None: room2 = "#room2" room3 = "#room3" - def insert_event(txn: Cursor, i: int, room_id: str) -> None: + def insert_event(txn: LoggingTransaction, i: int, room_id: str) -> None: event_id = "$event_%i:local" % i + + # We need to insert into events table to get around the foreign key constraint. + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "instance_name": "master", + "stream_ordering": self.store._stream_id_gen.get_next_txn(txn), + "topological_ordering": 1, + "depth": 1, + "event_id": event_id, + "room_id": room_id, + "type": EventTypes.Message, + "processed": True, + "outlier": False, + "origin_server_ts": 0, + "received_ts": 0, + "sender": "@user:local", + "contains_url": False, + "state_key": None, + "rejection_reason": None, + }, + ) + txn.execute( ( "INSERT INTO event_forward_extremities (room_id, event_id) " @@ -113,10 +138,14 @@ def insert_event(txn: Cursor, i: int, room_id: str) -> None: self.store.db_pool.runInteraction("insert", insert_event, i, room1) ) self.get_success( - self.store.db_pool.runInteraction("insert", insert_event, i, room2) + self.store.db_pool.runInteraction( + "insert", insert_event, i + 100, room2 + ) ) self.get_success( - self.store.db_pool.runInteraction("insert", insert_event, i, room3) + self.store.db_pool.runInteraction( + "insert", insert_event, i + 200, room3 + ) ) # Test simple case