
Merge account data streams (#14826)
erikjohnston authored Jan 13, 2023
1 parent 1416096 · commit 73ff493
Showing 12 changed files with 75 additions and 83 deletions.
1 change: 1 addition & 0 deletions changelog.d/14826.misc
@@ -0,0 +1 @@
Merge tag and normal account data replication streams.
12 changes: 12 additions & 0 deletions docs/upgrade.md
@@ -88,6 +88,18 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
# Upgrading to v1.76.0
## Changes to the account data replication streams
Synapse has changed the format of the account data replication streams (between
workers). This is a forwards- and backwards-incompatible change: v1.75 workers
cannot process account data replicated by v1.76 workers, and vice versa.
Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data
replication will resume as normal.
# Upgrading to v1.74.0
## Unicode support in user search
1 change: 1 addition & 0 deletions synapse/api/constants.py
@@ -249,6 +249,7 @@ class RoomEncryptionAlgorithms:
class AccountDataTypes:
DIRECT: Final = "m.direct"
IGNORED_USER_LIST: Final = "m.ignored_user_list"
TAG: Final = "m.tag"


class HistoryVisibility:
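The new `TAG` constant replaces the bare `"m.tag"` literal at each call site below. A standalone sketch (not Synapse code) of what centralising the string buys:

```python
# Standalone sketch: with the event type behind a constant, a typo fails
# loudly at attribute lookup instead of silently producing a bad event type.
from typing import Final


class AccountDataTypes:
    DIRECT: Final = "m.direct"
    IGNORED_USER_LIST: Final = "m.ignored_user_list"
    TAG: Final = "m.tag"


event = {"type": AccountDataTypes.TAG, "content": {"tags": {"u.work": {"order": 0.5}}}}
assert event["type"] == "m.tag"
# AccountDataTypes.TAGS would raise AttributeError; the literal "m.tags" would not.
```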
7 changes: 6 additions & 1 deletion synapse/handlers/account_data.py
@@ -16,6 +16,7 @@
import random
from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple

from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import (
ReplicationAddRoomAccountDataRestServlet,
ReplicationAddTagRestServlet,
@@ -335,7 +336,11 @@ async def get_new_events(

for room_id, room_tags in tags.items():
results.append(
{"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
{
"type": AccountDataTypes.TAG,
"content": {"tags": room_tags},
"room_id": room_id,
}
)

(
8 changes: 5 additions & 3 deletions synapse/handlers/initial_sync.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
@@ -239,7 +239,7 @@ async def handle_room(event: RoomsForUser) -> None:
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append(
{"type": "m.tag", "content": {"tags": tags}}
{"type": AccountDataTypes.TAG, "content": {"tags": tags}}
)

account_data = account_data_by_room.get(event.room_id, {})
@@ -326,7 +326,9 @@ async def room_initial_sync(
account_data_events = []
tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
account_data_events.append(
{"type": AccountDataTypes.TAG, "content": {"tags": tags}}
)

account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
11 changes: 9 additions & 2 deletions synapse/handlers/sync.py
@@ -31,7 +31,12 @@
import attr
from prometheus_client import Counter

from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.constants import (
AccountDataTypes,
EventContentFields,
EventTypes,
Membership,
)
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -2331,7 +2336,9 @@ async def _generate_room_entry(

account_data_events = []
if tags is not None:
account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
account_data_events.append(
{"type": AccountDataTypes.TAG, "content": {"tags": tags}}
)

for account_data_type, content in account_data.items():
account_data_events.append(
3 changes: 1 addition & 2 deletions synapse/replication/tcp/client.py
@@ -33,7 +33,6 @@
PushersStream,
PushRulesStream,
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
UnPartialStatedEventStream,
@@ -168,7 +167,7 @@ async def on_rdata(
self.notifier.on_new_event(
StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
)
elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
elif stream_name == AccountDataStream.NAME:
self.notifier.on_new_event(
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
)
3 changes: 1 addition & 2 deletions synapse/replication/tcp/handler.py
@@ -58,7 +58,6 @@
PresenceStream,
ReceiptsStream,
Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
@@ -145,7 +144,7 @@ def __init__(self, hs: "HomeServer"):

continue

if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
if isinstance(stream, AccountDataStream):
# Only add AccountDataStream as a source on the
# instance in charge of account_data persistence.
if hs.get_instance_name() in hs.config.worker.writers.account_data:
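The guard above means only the account_data writer registers the merged stream as a source. A rough sketch, assuming `writers.account_data` resolves to a list of instance names (as with `stream_writers` in a worker config); the helper name is hypothetical:

```python
# Hypothetical helper illustrating the guard above: only instances named in
# the account_data writers list act as a source for the merged stream;
# everyone else only consumes it.
from typing import Dict, List


def should_register_stream(instance_name: str, writers: Dict[str, List[str]]) -> bool:
    return instance_name in writers.get("account_data", [])


assert should_register_stream("worker1", {"account_data": ["worker1"]})
assert not should_register_stream("worker2", {"account_data": ["worker1"]})
```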
3 changes: 0 additions & 3 deletions synapse/replication/tcp/streams/__init__.py
@@ -35,7 +35,6 @@
PushRulesStream,
ReceiptsStream,
Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
UserSignatureStream,
@@ -62,7 +61,6 @@
DeviceListsStream,
ToDeviceStream,
FederationStream,
TagAccountDataStream,
AccountDataStream,
UserSignatureStream,
UnPartialStatedRoomStream,
@@ -83,7 +81,6 @@
"CachesStream",
"DeviceListsStream",
"ToDeviceStream",
"TagAccountDataStream",
"AccountDataStream",
"UserSignatureStream",
"UnPartialStatedRoomStream",
49 changes: 24 additions & 25 deletions synapse/replication/tcp/streams/_base.py
@@ -28,8 +28,8 @@

import attr

from synapse.api.constants import AccountDataTypes
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -495,27 +495,6 @@ def __init__(self, hs: "HomeServer"):
)


class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room"""

@attr.s(slots=True, frozen=True, auto_attribs=True)
class TagAccountDataStreamRow:
user_id: str
room_id: str
data: JsonDict

NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow

def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_account_data_stream_id),
store.get_all_updated_tags,
)


class AccountDataStream(Stream):
"""Global or per room account data was changed"""

@@ -560,6 +539,19 @@ async def _update_function(
to_token = room_results[-1][0]
limited = True

tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
instance_name,
from_token,
to_token,
limit,
)

# again, if the tag results hit the limit, limit the global results to
# the same stream token.
if tags_limited:
to_token = tag_to_token
limited = True

# convert the global results to the right format, and limit them to the to_token
# at the same time
global_rows = (
@@ -568,11 +560,16 @@
if stream_id <= to_token
)

# we know that the room_results are already limited to `to_token` so no need
# for a check on `stream_id` here.
room_rows = (
(stream_id, (user_id, room_id, account_data_type))
for stream_id, user_id, room_id, account_data_type in room_results
if stream_id <= to_token
)

tag_rows = (
(stream_id, (user_id, room_id, AccountDataTypes.TAG))
for stream_id, user_id, room_id in tags
if stream_id <= to_token
)

# We need to return a sorted list, so merge them together.
Expand All @@ -582,7 +579,9 @@ async def _update_function(
# leading to a comparison between the data tuples. The comparison could
# fail due to attempting to compare the `room_id` which results in a
# `TypeError` from comparing a `str` vs `None`.
updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
updates = list(
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
)
return updates, to_token, limited


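The `_update_function` above is the core of the merge: three sources, each already sorted by stream ID, are clamped to a common upper token and combined with `heapq.merge`. A self-contained sketch of the same pattern (toy row values and simplified types; not Synapse's actual row classes):

```python
# Toy sketch of the merge above: clamp every source to one upper token, then
# merge on the stream ID alone. Merging on the whole row could compare the
# payload tuples and raise TypeError (e.g. a None room_id against a str).
import heapq
from typing import List, Tuple

Row = Tuple[int, Tuple[str, str, str]]  # (stream_id, (user_id, room_id, type))


def merge_updates(
    global_results: List[Row],
    room_results: List[Row],
    tag_results: List[Row],
    to_token: int,
) -> List[Row]:
    sources = (
        (row for row in rows if row[0] <= to_token)
        for rows in (room_results, global_results, tag_results)
    )
    return list(heapq.merge(*sources, key=lambda row: row[0]))


merged = merge_updates(
    [(1, ("@a:x", "", "m.ignored_user_list"))],
    [(2, ("@a:x", "!r:x", "m.fully_read"))],
    [(4, ("@a:x", "!r:x", "m.tag"))],
    to_token=3,
)
assert [row[0] for row in merged] == [1, 2]  # the tag row is past to_token
```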
6 changes: 2 additions & 4 deletions synapse/storage/databases/main/account_data.py
@@ -27,7 +27,7 @@
)

from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
DatabasePool,
@@ -454,9 +454,7 @@ def process_replication_rows(
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
elif stream_name == AccountDataStream.NAME:
if stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

54 changes: 13 additions & 41 deletions synapse/storage/databases/main/tags.py
@@ -17,7 +17,8 @@
import logging
from typing import Any, Dict, Iterable, List, Tuple, cast

from synapse.replication.tcp.streams import TagAccountDataStream
from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -54,7 +55,7 @@ async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]

async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
) -> Tuple[List[Tuple[int, str, str]], int, bool]:
"""Get updates for tags replication stream.
Args:
@@ -73,7 +74,7 @@ async def get_all_updated_tags(
The token returned can be used in a subsequent call to this
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
The updates are a list of tuples of stream ID, user ID and room ID
"""

if last_id == current_id:
@@ -96,38 +97,13 @@ def get_all_updated_tags_txn(
"get_all_updated_tags", get_all_updated_tags_txn
)

def get_tag_content(
txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn:
tags.append(json_encoder.encode(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, (user_id, room_id, tag_json)))

return results

batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
tags = await self.db_pool.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)

limited = False
upto_token = current_id
if len(results) >= limit:
upto_token = results[-1][0]
if len(tag_ids) >= limit:
upto_token = tag_ids[-1][0]
limited = True

return results, upto_token, limited
return tag_ids, upto_token, limited
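
The limiting logic above follows the usual stream-pagination contract: a full page means more rows may remain, so the returned token is pulled back to the last row actually seen. A toy sketch of just that contract:

```python
# Toy sketch of the pagination contract above: on a full page, the caller
# must not advance its token past the last row it actually received.
from typing import List, Tuple

TagRow = Tuple[int, str, str]  # (stream_id, user_id, room_id)


def clamp_page(
    tag_ids: List[TagRow], current_id: int, limit: int
) -> Tuple[List[TagRow], int, bool]:
    limited = False
    upto_token = current_id
    if len(tag_ids) >= limit:
        upto_token = tag_ids[-1][0]
        limited = True
    return tag_ids, upto_token, limited


rows = [(5, "@a:x", "!r1:x"), (6, "@a:x", "!r2:x")]
assert clamp_page(rows, current_id=10, limit=2) == (rows, 6, True)
assert clamp_page(rows, current_id=10, limit=3) == (rows, 10, False)
```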

async def get_updated_tags(
self, user_id: str, stream_id: int
@@ -299,20 +275,16 @@ def process_replication_rows(
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == TagAccountDataStream.NAME:
if stream_name == AccountDataStream.NAME:
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
if row.data_type == AccountDataTypes.TAG:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
row.user_id, token
)

super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)


class TagsStore(TagsWorkerStore):
pass
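
With tags travelling on the shared stream, consumers pick tag rows out by type, as in `process_replication_rows` above. A sketch of that dispatch; the row shape `(user_id, room_id, data_type)` is assumed from the stream's attrs row class:

```python
# Sketch of consumer-side filtering on the merged stream. The row shape is
# assumed (mirroring the stream's attrs row class); only tag rows should
# invalidate the per-user tag cache.
import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class AccountDataStreamRow:
    user_id: str
    room_id: str
    data_type: str


def process_rows(rows: list) -> None:
    for row in rows:
        if row.data_type == "m.tag":
            print(f"invalidate get_tags_for_user({row.user_id!r})")


process_rows([AccountDataStreamRow("@a:x", "!r:x", "m.tag")])
```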
