Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator` (element-hq#17229)

Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`, which is safer.
erikjohnston authored and H-Shay committed May 31, 2024
1 parent 8e0e1d5 commit 641b6b2
Showing 10 changed files with 227 additions and 363 deletions.
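The shape of the change is the same in every store touched below: a positional `StreamIdGenerator` construction becomes a keyword-argument `MultiWriterIdGenerator` construction that additionally names the backing Postgres sequence, the owning instance, and the set of allowed writers. A before/after sketch, lifted from the push-rules diff at the bottom of this page:

    # Before: single-writer generator, fed by per-table stream_id columns.
    self._push_rules_stream_id_gen = StreamIdGenerator(
        db_conn,
        hs.get_replication_notifier(),
        "push_rules_stream",
        "stream_id",
        is_writer=self._is_push_writer,
    )

    # After: ids come from a shared Postgres sequence, and each row records
    # which instance wrote it, so several writers can allocate concurrently.
    self._push_rules_stream_id_gen = MultiWriterIdGenerator(
        db_conn=db_conn,
        db=database,
        notifier=hs.get_replication_notifier(),
        stream_name="push_rules_stream",
        instance_name=self._instance_name,
        tables=[("push_rules_stream", "instance_name", "stream_id")],
        sequence_name="push_rules_stream_sequence",
        writers=hs.config.worker.writers.push_rules,
    )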
1 change: 1 addition & 0 deletions changelog.d/17229.misc
@@ -0,0 +1 @@
+Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.
71 changes: 61 additions & 10 deletions synapse/_scripts/synapse_port_db.py
@@ -777,22 +777,74 @@ def alter_table(txn: LoggingTransaction) -> None:
         await self._setup_events_stream_seqs()
         await self._setup_sequence(
             "un_partial_stated_event_stream_sequence",
-            ("un_partial_stated_event_stream",),
+            [("un_partial_stated_event_stream", "stream_id")],
         )
         await self._setup_sequence(
-            "device_inbox_sequence", ("device_inbox", "device_federation_outbox")
+            "device_inbox_sequence",
+            [
+                ("device_inbox", "stream_id"),
+                ("device_federation_outbox", "stream_id"),
+            ],
         )
         await self._setup_sequence(
             "account_data_sequence",
-            ("room_account_data", "room_tags_revisions", "account_data"),
+            [
+                ("room_account_data", "stream_id"),
+                ("room_tags_revisions", "stream_id"),
+                ("account_data", "stream_id"),
+            ],
+        )
+        await self._setup_sequence(
+            "receipts_sequence",
+            [
+                ("receipts_linearized", "stream_id"),
+            ],
+        )
+        await self._setup_sequence(
+            "presence_stream_sequence",
+            [
+                ("presence_stream", "stream_id"),
+            ],
         )
-        await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
-        await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
         await self._setup_auth_chain_sequence()
         await self._setup_sequence(
             "application_services_txn_id_seq",
-            ("application_services_txns",),
-            "txn_id",
+            [
+                (
+                    "application_services_txns",
+                    "txn_id",
+                )
+            ],
         )
+        await self._setup_sequence(
+            "device_lists_sequence",
+            [
+                ("device_lists_stream", "stream_id"),
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
+                ("device_lists_remote_pending", "stream_id"),
+                ("device_lists_changes_converted_stream_position", "stream_id"),
+            ],
+        )
+        await self._setup_sequence(
+            "e2e_cross_signing_keys_sequence",
+            [
+                ("e2e_cross_signing_keys", "stream_id"),
+            ],
+        )
+        await self._setup_sequence(
+            "push_rules_stream_sequence",
+            [
+                ("push_rules_stream", "stream_id"),
+            ],
+        )
+        await self._setup_sequence(
+            "pushers_sequence",
+            [
+                ("pushers", "id"),
+                ("deleted_pushers", "stream_id"),
+            ],
+        )
 
         # Step 3. Get tables.
@@ -1101,12 +1153,11 @@ def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
     async def _setup_sequence(
         self,
         sequence_name: str,
-        stream_id_tables: Iterable[str],
-        column_name: str = "stream_id",
+        stream_id_tables: Iterable[Tuple[str, str]],
     ) -> None:
         """Set a sequence to the correct value."""
         current_stream_ids = []
-        for stream_id_table in stream_id_tables:
+        for stream_id_table, column_name in stream_id_tables:
            max_stream_id = cast(
                int,
                await self.sqlite_store.db_pool.simple_select_one_onecol(
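The rest of `_setup_sequence` is collapsed above; judging by the visible part, it collects the maximum id used in each (table, column) pair and then moves the named Postgres sequence past it. A minimal standalone sketch of that idea (the `port_sequence` helper and the `ALTER SEQUENCE ... RESTART` statement are assumptions for illustration, not the port script's literal code):

    from typing import Iterable, Tuple

    def port_sequence(cur, sequence_name: str, tables: Iterable[Tuple[str, str]]) -> None:
        # Collect the highest id already used in each source table.
        current_ids = [0]
        for table, column in tables:
            # Identifiers are interpolated directly; fine here because the
            # names come from a hard-coded list, never from user input.
            cur.execute(f"SELECT COALESCE(MAX({column}), 0) FROM {table}")
            current_ids.append(cur.fetchone()[0])
        # Restart the sequence just past the highest used id, so nextval()
        # can never hand out an id that already exists in the ported data.
        cur.execute(f"ALTER SEQUENCE {sequence_name} RESTART WITH {max(current_ids) + 1}")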
54 changes: 30 additions & 24 deletions synapse/storage/databases/main/devices.py
@@ -57,10 +57,7 @@
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import (
     JsonDict,
     JsonMapping,
@@ -99,19 +96,21 @@ def __init__(
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._device_list_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "device_lists_stream",
-            "stream_id",
-            extra_tables=[
-                ("user_signature_stream", "stream_id"),
-                ("device_lists_outbound_pokes", "stream_id"),
-                ("device_lists_changes_in_room", "stream_id"),
-                ("device_lists_remote_pending", "stream_id"),
-                ("device_lists_changes_converted_stream_position", "stream_id"),
+        self._device_list_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="device_lists_stream",
+            instance_name=self._instance_name,
+            tables=[
+                ("device_lists_stream", "instance_name", "stream_id"),
+                ("user_signature_stream", "instance_name", "stream_id"),
+                ("device_lists_outbound_pokes", "instance_name", "stream_id"),
+                ("device_lists_changes_in_room", "instance_name", "stream_id"),
+                ("device_lists_remote_pending", "instance_name", "stream_id"),
             ],
-            is_writer=hs.config.worker.worker_app is None,
+            sequence_name="device_lists_sequence",
+            writers=["master"],
         )
 
         device_list_max = self._device_list_id_gen.get_current_token()
@@ -762,6 +761,7 @@ def _add_user_signature_change_txn(
                 "stream_id": stream_id,
                 "from_user_id": from_user_id,
                 "user_ids": json_encoder.encode(user_ids),
+                "instance_name": self._instance_name,
             },
         )
 
@@ -1582,6 +1582,8 @@ def __init__(
     ):
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+
         self.db_pool.updates.register_background_index_update(
             "device_lists_stream_idx",
             index_name="device_lists_stream_user_id",
@@ -1694,6 +1696,7 @@ def _txn(txn: LoggingTransaction) -> int:
                 "device_lists_outbound_pokes",
                 {
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "destination": destination,
                     "user_id": user_id,
                     "device_id": device_id,
@@ -1730,10 +1733,6 @@ def _txn(txn: LoggingTransaction) -> int:
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    # Because we have write access, this will be a StreamIdGenerator
-    # (see DeviceWorkerStore.__init__)
-    _device_list_id_gen: AbstractStreamIdGenerator
-
     def __init__(
         self,
         database: DatabasePool,
@@ -2092,9 +2091,9 @@ def _add_device_change_to_stream_txn(
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_stream",
-            keys=("stream_id", "user_id", "device_id"),
+            keys=("instance_name", "stream_id", "user_id", "device_id"),
             values=[
-                (stream_id, user_id, device_id)
+                (self._instance_name, stream_id, user_id, device_id)
                 for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )
@@ -2124,6 +2123,7 @@ def _add_device_outbound_poke_to_stream_txn(
         values = [
             (
                 destination,
+                self._instance_name,
                 next(stream_id_iterator),
                 user_id,
                 device_id,
@@ -2139,6 +2139,7 @@ def _add_device_outbound_poke_to_stream_txn(
             table="device_lists_outbound_pokes",
             keys=(
                 "destination",
+                "instance_name",
                 "stream_id",
                 "user_id",
                 "device_id",
@@ -2157,7 +2158,7 @@ def _add_device_outbound_poke_to_stream_txn(
                 device_id,
                 {
                     stream_id: destination
-                    for (destination, stream_id, _, _, _, _, _) in values
+                    for (destination, _, stream_id, _, _, _, _, _) in values
                 },
             )
 
@@ -2210,6 +2211,7 @@ def _add_device_outbound_room_poke_txn(
                 "device_id",
                 "room_id",
                 "stream_id",
+                "instance_name",
                 "converted_to_destinations",
                 "opentracing_context",
             ),
@@ -2219,6 +2221,7 @@ def _add_device_outbound_room_poke_txn(
                     device_id,
                     room_id,
                     stream_id,
+                    self._instance_name,
                     # We only need to calculate outbound pokes for local users
                     not self.hs.is_mine_id(user_id),
                     encoded_context,
@@ -2338,7 +2341,10 @@ async def add_remote_device_list_to_pending(
                 "user_id": user_id,
                 "device_id": device_id,
             },
-            values={"stream_id": stream_id},
+            values={
+                "stream_id": stream_id,
+                "instance_name": self._instance_name,
+            },
             desc="add_remote_device_list_to_pending",
         )
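For callers, allocation looks the same as before: `MultiWriterIdGenerator` hands out ids through an async context manager, and an id is only published as the current position once the block exits cleanly. A hypothetical caller sketch (`record_device_change` is made up; `get_next()` and `simple_insert` are the APIs these stores use elsewhere):

    async def record_device_change(store, user_id: str, device_id: str) -> None:
        # get_next() reserves an id from the shared Postgres sequence; the
        # id becomes visible to readers only when the context exits cleanly.
        async with store._device_list_id_gen.get_next() as stream_id:
            await store.db_pool.simple_insert(
                table="device_lists_stream",
                values={
                    "instance_name": store._instance_name,  # column added in this commit
                    "stream_id": stream_id,
                    "user_id": user_id,
                    "device_id": device_id,
                },
                desc="record_device_change",
            )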
19 changes: 13 additions & 6 deletions synapse/storage/databases/main/end_to_end_keys.py
@@ -58,7 +58,7 @@
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -1448,11 +1448,17 @@ def __init__(
     ):
         super().__init__(database, db_conn, hs)
 
-        self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "e2e_cross_signing_keys",
-            "stream_id",
+        self._cross_signing_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="e2e_cross_signing_keys",
+            instance_name=self._instance_name,
+            tables=[
+                ("e2e_cross_signing_keys", "instance_name", "stream_id"),
+            ],
+            sequence_name="e2e_cross_signing_keys_sequence",
+            writers=["master"],
         )
 
     async def set_e2e_device_keys(
@@ -1627,6 +1633,7 @@ def _set_e2e_cross_signing_key_txn(
                 "keytype": key_type,
                 "keydata": json_encoder.encode(key),
                 "stream_id": stream_id,
+                "instance_name": self._instance_name,
             },
         )
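Each `tables` entry is a (table, instance-name column, id column) triple; on startup the generator presumably recovers one position per writer from those columns instead of a single global maximum, along these lines (an assumed sketch, not the generator's literal SQL):

    # Assumed start-up recovery for the cross-signing stream: the latest
    # id seen per writer instance, not just one global MAX(stream_id).
    RECOVER_POSITIONS_SQL = """
        SELECT instance_name, MAX(stream_id)
        FROM e2e_cross_signing_keys
        GROUP BY instance_name
    """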
24 changes: 14 additions & 10 deletions synapse/storage/databases/main/push_rule.py
@@ -53,7 +53,7 @@
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
 from synapse.util import json_encoder, unwrapFirstError
@@ -126,7 +126,7 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    _push_rules_stream_id_gen: StreamIdGenerator
+    _push_rules_stream_id_gen: MultiWriterIdGenerator
 
     def __init__(
         self,
@@ -140,14 +140,17 @@ def __init__(
             hs.get_instance_name() in hs.config.worker.writers.push_rules
         )
 
-        # In the worker store this is an ID tracker which we overwrite in the non-worker
-        # class below that is used on the main process.
-        self._push_rules_stream_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "push_rules_stream",
-            "stream_id",
-            is_writer=self._is_push_writer,
+        self._push_rules_stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="push_rules_stream",
+            instance_name=self._instance_name,
+            tables=[
+                ("push_rules_stream", "instance_name", "stream_id"),
+            ],
+            sequence_name="push_rules_stream_sequence",
+            writers=hs.config.worker.writers.push_rules,
         )
 
         push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -880,6 +883,7 @@ def _insert_push_rules_update_txn(
             raise Exception("Not a push writer")
 
         values = {
+            "instance_name": self._instance_name,
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
             "user_id": user_id,
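Unlike the device-list and cross-signing generators above, which pin `writers=["master"]`, the push-rules generator takes its writer set from `hs.config.worker.writers.push_rules`, so several instances may allocate ids at once. Readers can then ask for either the safe global position or a single writer's position (the `get_current_token*` accessors are the id generators' existing API; the worker name is illustrative):

    # Highest id such that every id up to and including it has been
    # persisted by all writers; safe for any reader to advance to.
    current = store._push_rules_stream_id_gen.get_current_token()

    # The stream position attributed to one specific writer instance.
    pos = store._push_rules_stream_id_gen.get_current_token_for_writer("pusher1")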
[Diffs for the remaining 5 changed files are not shown here.]
