Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add foreign key constraint to event_forward_extremities. #15751

Merged
merged 13 commits into from
Jul 5, 2023
Merged
1 change: 1 addition & 0 deletions changelog.d/15751.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add foreign key constraint to `event_forward_extremities`.
2 changes: 2 additions & 0 deletions synapse/_scripts/synapse_port_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import EventPushActionsStore
from synapse.storage.databases.main.events_bg_updates import (
EventsBackgroundUpdatesStore,
Expand Down Expand Up @@ -239,6 +240,7 @@ class Store(
PresenceBackgroundUpdateStore,
ReceiptsBackgroundUpdateStore,
RelationsWorkerStore,
EventFederationWorkerStore,
):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
Expand Down
316 changes: 315 additions & 1 deletion synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from enum import IntEnum
from enum import Enum, IntEnum
from io import StringIO
from types import TracebackType
from typing import (
TYPE_CHECKING,
Expand All @@ -24,12 +26,16 @@
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
)

import attr
from pydantic import BaseModel

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection, Cursor
from synapse.types import JsonDict
from synapse.util import Clock, json_encoder
Expand All @@ -48,6 +54,78 @@
MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]


class Constraint(metaclass=abc.ABCMeta):
"""Base class representing different constraints.

Used by `register_background_validate_constraint_and_delete_rows`.
"""

@abc.abstractmethod
def make_check_clause(self, table: str) -> str:
"""Returns an SQL expression that checks the row passes the constraint."""
pass

@abc.abstractmethod
def make_constraint_clause_postgres(self) -> str:
"""Returns an SQL clause for creating the constraint.

Only used on Postgres DBs
"""
pass


@attr.s(auto_attribs=True)
class ForeignKeyConstraint(Constraint):
"""A foreign key constraint.

Attributes:
referenced_table: The "parent" table name.
columns: The list of mappings of columns from table to referenced table
"""

referenced_table: str
columns: Sequence[Tuple[str, str]]

def make_check_clause(self, table: str) -> str:
join_clause = " AND ".join(
f"{col1} = {table}.{col2}" for col1, col2 in self.columns
)
return f"EXISTS (SELECT 1 FROM {self.referenced_table} WHERE {join_clause})"

def make_constraint_clause_postgres(self) -> str:
column1_list = ", ".join(col1 for col1, col2 in self.columns)
column2_list = ", ".join(col2 for col1, col2 in self.columns)
return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list})"


@attr.s(auto_attribs=True)
class NotNullConstraint(Constraint):
"""A NOT NULL column constraint"""

column: str

def make_check_clause(self, table: str) -> str:
return f"{self.column} IS NOT NULL"

def make_constraint_clause_postgres(self) -> str:
return f"CHECK ({self.column} IS NOT NULL)"


class ValidateConstraintProgress(BaseModel):
"""The format of the progress JSON for validate constraint background
updates.

Used by `register_background_validate_constraint_and_delete_rows`.
"""

class State(str, Enum):
check = "check"
validate = "validate"

state: State = State.validate
lower_bound: Sequence[Any] = ()
Comment on lines +124 to +125
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by: are these types being used for validation or is it "just" a dataclass? (If the former, recommend a short unit test to sanity check that the validation behaves as you expect)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ermh, it's mostly just used as a data class I think? So long as the pydantic models correctly round trip it should be fine?



@attr.s(slots=True, frozen=True, auto_attribs=True)
class _BackgroundUpdateHandler:
"""A handler for a given background update.
Expand Down Expand Up @@ -740,6 +818,169 @@ def create_index_sqlite(conn: Connection) -> None:
logger.info("Adding index %s to %s", index_name, table)
await self.db_pool.runWithConnection(runner)

def register_background_validate_constraint_and_delete_rows(
self,
update_name: str,
table: str,
constraint_name: str,
constraint: Constraint,
unique_columns: Sequence[str],
) -> None:
"""Helper for store classes to do a background validate constraint, and
delete rows that do not pass the constraint check.

MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
This only applies on PostgreSQL.
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved

Args:
update_name: The name of the background update.
table: The table with the invalid constraint.
constraint_name: The name of the constraint
constraint: A `Constraint` object matching the type of constraint.
unique_columns: A sequence of columns that form a unique constraint
on the table. Used to iterate over the table.
"""

assert isinstance(
self.db_pool.engine, engines.PostgresEngine
), "validate constraint background update registered for non-Postres database"

async def updater(progress: JsonDict, batch_size: int) -> int:
return await self.validate_constraint_and_delete_in_background(
update_name=update_name,
table=table,
constraint_name=constraint_name,
constraint=constraint,
unique_columns=unique_columns,
progress=progress,
batch_size=batch_size,
)

self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
updater, oneshot=True
)

async def validate_constraint_and_delete_in_background(
self,
update_name: str,
table: str,
constraint_name: str,
constraint: Constraint,
unique_columns: Sequence[str],
progress: JsonDict,
batch_size: int,
) -> int:
"""Validates a table constraint that has been marked as `NOT VALID`,
deleting rows that don't pass the constraint check.

This will delete rows that do not meet the validation check.

update_name: str,
table: str,
constraint_name: str,
constraint: Constraint,
unique_columns: Sequence[str],
"""

# We validate the constraint by:
# 1. Trying to validate the constraint as is. If this succeeds then
# we're done.
# 2. Otherwise, we manually scan the table to remove rows that don't
# match the constraint.
# 3. We try re-validating the constraint.

parsed_progress = ValidateConstraintProgress.parse_obj(progress)

if parsed_progress.state == ValidateConstraintProgress.State.check:
return_columns = ", ".join(unique_columns)
order_columns = ", ".join(unique_columns)

where_clause = ""
args: List[Any] = []
if parsed_progress.lower_bound:
where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})"""
args.extend(parsed_progress.lower_bound)

args.append(batch_size)

sql = f"""
SELECT
{return_columns},
{constraint.make_check_clause(table)} AS check
FROM {table}
{where_clause}
ORDER BY {order_columns}
LIMIT ?
"""

def validate_constraint_in_background_check(
txn: "LoggingTransaction",
) -> None:
txn.execute(sql, args)
rows = txn.fetchall()

new_progress = parsed_progress.copy()

if not rows:
new_progress.state = ValidateConstraintProgress.State.validate
self._background_update_progress_txn(
txn, update_name, new_progress.dict()
)
return

new_progress.lower_bound = rows[-1][:-1]

to_delete = [row[:-1] for row in rows if not row[-1]]

if to_delete:
logger.warning(
"Deleting %d rows that do not pass new constraint",
len(to_delete),
)

self.db_pool.simple_delete_many_batch_txn(
txn, table=table, keys=unique_columns, values=to_delete
)

self._background_update_progress_txn(
txn, update_name, new_progress.dict()
)

await self.db_pool.runInteraction(
"validate_constraint_in_background_check",
validate_constraint_in_background_check,
)

return batch_size

elif parsed_progress.state == ValidateConstraintProgress.State.validate:
sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}"

def validate_constraint_in_background_validate(
txn: "LoggingTransaction",
) -> None:
txn.execute(sql)

try:
await self.db_pool.runInteraction(
"validate_constraint_in_background_validate",
validate_constraint_in_background_validate,
)

await self._end_background_update(update_name)
except self.db_pool.engine.module.IntegrityError as e:
# If we get an integrity error here, then we go back and recheck the table.
logger.warning("Integrity error when validating constraint: %s", e)
await self._background_update_progress(
update_name,
ValidateConstraintProgress(
state=ValidateConstraintProgress.State.check
).dict(),
)

return batch_size
else:
raise Exception(f"Unrecognized state '{parsed_progress.state}'")
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.

Expand Down Expand Up @@ -795,3 +1036,76 @@ def _background_update_progress_txn(
keyvalues={"update_name": update_name},
updatevalues={"progress_json": progress_json},
)


def run_validate_constraint_and_delete_rows_schema_delta(
txn: "LoggingTransaction",
ordering: int,
update_name: str,
table: str,
constraint_name: str,
constraint: Constraint,
sqlite_table_name: str,
sqlite_table_schema: str,
sqlite_post_schema: Optional[str],
) -> None:
"""Runs a schema delta to add a constraint to the table. This should be run
in a schema delta file.

For PostgreSQL the constraint is added and validated in the background.

For SQLite the table is recreated and data copied across immediately. This
is done by the caller passing in a script to create the new table.

There must be a corresponding call to
`register_background_validate_constraint_and_delete_rows`.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

Attributes:
txn ordering, update_name: For adding a row to background_updates table.
table: The table to add constraint to. constraint_name: The name of the
new constraint constraint: A `Constraint` object describing the
constraint sqlite_table_name: For SQLite the name of the empty copy of
table sqlite_table_schema: A SQL script for creating the above table.
sqlite_post_schema: A SQL script run after migration, to add back
indices and the like.
"""

if isinstance(txn.database_engine, PostgresEngine):
# For postgres we can just add the constraint and mark it as NOT VALID,
# and then insert a background update to go and check the validity in
# the background.
txn.execute(
f"""
ALTER TABLE {table}
ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()}
NOT VALID
"""
)

txn.execute(
"INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')",
(ordering, update_name),
)
else:
# For SQLite, we:
# 1. create an empty copy of the table
# 2. copy across the rows (that satisfy the check)
# 3. replace the old table with the new able.

# We import this here to avoid circular imports.
from synapse.storage.prepare_database import execute_statements_from_stream

execute_statements_from_stream(txn, StringIO(sqlite_table_schema))

sql = f"""
INSERT INTO {sqlite_table_name} SELECT * FROM {table}
WHERE {constraint.make_check_clause(table)}
"""

txn.execute(sql)

txn.execute(f"DROP TABLE {table}")
txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}")

if sqlite_post_schema:
execute_statements_from_stream(txn, StringIO(sqlite_post_schema))
Loading