Skip to content

Commit

Permalink
Sliding sync: Add connection tracking to the account_data extension (
Browse files Browse the repository at this point in the history
…element-hq#17695)

This is basically exactly the same logic as for receipts. Essentially we
just need to track which room account data we have and haven't sent down
to clients, and use that when we pull stuff out.

I think this just needs a couple of extra tests written

---------

Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
  • Loading branch information
erikjohnston and MadLittleMods authored Sep 19, 2024
1 parent c2e5e9e commit a851f6b
Show file tree
Hide file tree
Showing 7 changed files with 742 additions and 67 deletions.
1 change: 1 addition & 0 deletions changelog.d/17695.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where room account data would not correctly be sent down sliding sync for old rooms.
183 changes: 131 additions & 52 deletions synapse/handlers/sliding_sync/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
AbstractSet,
ChainMap,
Dict,
List,
Mapping,
MutableMapping,
Optional,
Expand Down Expand Up @@ -119,6 +118,8 @@ async def get_extensions_response(
if sync_config.extensions.account_data is not None:
account_data_response = await self.get_account_data_extension_response(
sync_config=sync_config,
previous_connection_state=previous_connection_state,
new_connection_state=new_connection_state,
actual_lists=actual_lists,
actual_room_ids=actual_room_ids,
account_data_request=sync_config.extensions.account_data,
Expand Down Expand Up @@ -361,6 +362,8 @@ async def get_e2ee_extension_response(
async def get_account_data_extension_response(
self,
sync_config: SlidingSyncConfig,
previous_connection_state: "PerConnectionState",
new_connection_state: "MutablePerConnectionState",
actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
actual_room_ids: Set[str],
account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
Expand Down Expand Up @@ -425,25 +428,51 @@ async def get_account_data_extension_response(

# Fetch room account data
#
# List of -> Mapping from room_id to mapping of `type` to `content` of room
# account data events.
#
# This is is a list so we can avoid making copies of immutable data and instead
# just provide multiple maps that need to be combined. Normally, we could
# reach for `ChainMap` in this scenario, but this is a nested map and accessing
# the ChainMap by room_id won't combine the two maps for that room (we would
# need a new `NestedChainMap` type class).
account_data_by_room_maps: List[Mapping[str, Mapping[str, JsonMapping]]] = []
account_data_by_room_map: MutableMapping[str, Mapping[str, JsonMapping]] = {}
relevant_room_ids = self.find_relevant_room_ids_for_extension(
requested_lists=account_data_request.lists,
requested_room_ids=account_data_request.rooms,
actual_lists=actual_lists,
actual_room_ids=actual_room_ids,
)
if len(relevant_room_ids) > 0:
# We need to handle the different cases depending on if we have sent
# down account data previously or not, so we split the relevant
# rooms up into different collections based on status.
live_rooms = set()
previously_rooms: Dict[str, int] = {}
initial_rooms = set()

for room_id in relevant_room_ids:
if not from_token:
initial_rooms.add(room_id)
continue

room_status = previous_connection_state.account_data.have_sent_room(
room_id
)
if room_status.status == HaveSentRoomFlag.LIVE:
live_rooms.add(room_id)
elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
assert room_status.last_token is not None
previously_rooms[room_id] = room_status.last_token
elif room_status.status == HaveSentRoomFlag.NEVER:
initial_rooms.add(room_id)
else:
assert_never(room_status.status)

# We fetch all room account data since the from_token. This is so
# that we can record which rooms have updates that haven't been sent
# down.
#
# Mapping from room_id to mapping of `type` to `content` of room account
# data events.
all_updates_since_the_from_token: Mapping[
str, Mapping[str, JsonMapping]
] = {}
if from_token is not None:
# TODO: This should take into account the `from_token` and `to_token`
account_data_by_room_map = (
all_updates_since_the_from_token = (
await self.store.get_updated_room_account_data_for_user(
user_id, from_token.stream_token.account_data_key
)
Expand All @@ -456,58 +485,108 @@ async def get_account_data_extension_response(
user_id, from_token.stream_token.account_data_key
)
for room_id, tags in tags_by_room.items():
account_data_by_room_map.setdefault(room_id, {})[
all_updates_since_the_from_token.setdefault(room_id, {})[
AccountDataTypes.TAG
] = {"tags": tags}

account_data_by_room_maps.append(account_data_by_room_map)
else:
# TODO: This should take into account the `to_token`
immutable_account_data_by_room_map = (
await self.store.get_room_account_data_for_user(user_id)
)
account_data_by_room_maps.append(immutable_account_data_by_room_map)
# For live rooms we just get the updates from `all_updates_since_the_from_token`
if live_rooms:
for room_id in all_updates_since_the_from_token.keys() & live_rooms:
account_data_by_room_map[room_id] = (
all_updates_since_the_from_token[room_id]
)

# Add room tags
#
# TODO: This should take into account the `to_token`
tags_by_room = await self.store.get_tags_for_user(user_id)
account_data_by_room_maps.append(
{
room_id: {AccountDataTypes.TAG: {"tags": tags}}
for room_id, tags in tags_by_room.items()
}
# For previously and initial rooms we query each room individually.
if previously_rooms or initial_rooms:

async def handle_previously(room_id: str) -> None:
# Either get updates or all account data in the room
# depending on if the room state is PREVIOUSLY or NEVER.
previous_token = previously_rooms.get(room_id)
if previous_token is not None:
room_account_data = await (
self.store.get_updated_room_account_data_for_user_for_room(
user_id=user_id,
room_id=room_id,
from_stream_id=previous_token,
to_stream_id=to_token.account_data_key,
)
)

# Add room tags
changed = await self.store.has_tags_changed_for_room(
user_id=user_id,
room_id=room_id,
from_stream_id=previous_token,
to_stream_id=to_token.account_data_key,
)
if changed:
# XXX: Ideally, this should take into account the `to_token`
# and return the set of tags at that time but we don't track
# changes to tags so we just have to return all tags for the
# room.
immutable_tag_map = await self.store.get_tags_for_room(
user_id, room_id
)
room_account_data[AccountDataTypes.TAG] = {
"tags": immutable_tag_map
}

# Only add an entry if there were any updates.
if room_account_data:
account_data_by_room_map[room_id] = room_account_data
else:
# TODO: This should take into account the `to_token`
immutable_room_account_data = (
await self.store.get_account_data_for_room(user_id, room_id)
)

# Add room tags
#
# XXX: Ideally, this should take into account the `to_token`
# and return the set of tags at that time but we don't track
# changes to tags so we just have to return all tags for the
# room.
immutable_tag_map = await self.store.get_tags_for_room(
user_id, room_id
)

account_data_by_room_map[room_id] = ChainMap(
{AccountDataTypes.TAG: {"tags": immutable_tag_map}}
if immutable_tag_map
else {},
# Cast is safe because `ChainMap` only mutates the top-most map,
# see https://github.com/python/typeshed/issues/8430
cast(
MutableMapping[str, JsonMapping],
immutable_room_account_data,
),
)

# We handle these rooms concurrently to speed it up.
await concurrently_execute(
handle_previously,
previously_rooms.keys() | initial_rooms,
limit=20,
)

# Filter down to the relevant rooms ... and combine the maps
relevant_account_data_by_room_map: MutableMapping[
str, Mapping[str, JsonMapping]
] = {}
for room_id in relevant_room_ids:
# We want to avoid adding empty maps for relevant rooms that have no room
# account data so do a quick check to see if it's in any of the maps.
is_room_in_maps = False
for room_map in account_data_by_room_maps:
if room_id in room_map:
is_room_in_maps = True
break
# Now record which rooms are now up to data, and which rooms have
# pending updates to send.
new_connection_state.account_data.record_sent_rooms(relevant_room_ids)
missing_updates = (
all_updates_since_the_from_token.keys() - relevant_room_ids
)
if missing_updates:
# If we have missing updates then we must have had a from_token.
assert from_token is not None

# If we found the room in any of the maps, combine the maps for that room
if is_room_in_maps:
relevant_account_data_by_room_map[room_id] = ChainMap(
{},
*(
# Cast is safe because `ChainMap` only mutates the top-most map,
# see https://github.com/python/typeshed/issues/8430
cast(MutableMapping[str, JsonMapping], room_map[room_id])
for room_map in account_data_by_room_maps
if room_map.get(room_id)
),
new_connection_state.account_data.record_unsent_rooms(
missing_updates, from_token.stream_token.account_data_key
)

return SlidingSyncResult.Extensions.AccountDataExtension(
global_account_data_map=global_account_data_map,
account_data_by_room_map=relevant_account_data_by_room_map,
account_data_by_room_map=account_data_by_room_map,
)

@trace
Expand Down
50 changes: 50 additions & 0 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,56 @@ def get_updated_room_account_data_for_user_txn(
get_updated_room_account_data_for_user_txn,
)

async def get_updated_room_account_data_for_user_for_room(
self,
# Since there are multiple arguments with the same type, force keyword arguments
# so people don't accidentally swap the order
*,
user_id: str,
room_id: str,
from_stream_id: int,
to_stream_id: int,
) -> Dict[str, JsonMapping]:
"""Get the room account_data that's changed for a user in a room.
(> `from_stream_id` and <= `to_stream_id`)
Args:
user_id: The user to get the account_data for.
room_id: The room to check
from_stream_id: The point in the stream to fetch from
to_stream_id: The point in the stream to fetch to
Returns:
A dict of the room account data.
"""

def get_updated_room_account_data_for_user_for_room_txn(
txn: LoggingTransaction,
) -> Dict[str, JsonMapping]:
sql = """
SELECT account_data_type, content FROM room_account_data
WHERE user_id = ? AND room_id = ? AND stream_id > ? AND stream_id <= ?
"""
txn.execute(sql, (user_id, room_id, from_stream_id, to_stream_id))

room_account_data: Dict[str, JsonMapping] = {}
for row in txn:
room_account_data[row[0]] = db_to_json(row[1])

return room_account_data

changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(from_stream_id)
)
if not changed:
return {}

return await self.db_pool.runInteraction(
"get_updated_room_account_data_for_user_for_room",
get_updated_room_account_data_for_user_for_room_txn,
)

@cached(max_entries=5000, iterable=True)
async def ignored_by(self, user_id: str) -> FrozenSet[str]:
"""
Expand Down
37 changes: 37 additions & 0 deletions synapse/storage/databases/main/sliding_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ def persist_per_connection_state_txn(
(have_sent_room.status.value, have_sent_room.last_token)
)

for (
room_id,
have_sent_room,
) in per_connection_state.account_data._statuses.items():
key_values.append((connection_position, "account_data", room_id))
value_values.append(
(have_sent_room.status.value, have_sent_room.last_token)
)

self.db_pool.simple_upsert_many_txn(
txn,
table="sliding_sync_connection_streams",
Expand Down Expand Up @@ -407,6 +416,7 @@ def _get_and_clear_connection_positions_txn(
# Now look up the per-room stream data.
rooms: Dict[str, HaveSentRoom[str]] = {}
receipts: Dict[str, HaveSentRoom[str]] = {}
account_data: Dict[str, HaveSentRoom[str]] = {}

receipt_rows = self.db_pool.simple_select_list_txn(
txn,
Expand All @@ -427,6 +437,8 @@ def _get_and_clear_connection_positions_txn(
rooms[room_id] = have_sent_room
elif stream == "receipts":
receipts[room_id] = have_sent_room
elif stream == "account_data":
account_data[room_id] = have_sent_room
else:
# For forwards compatibility we ignore unknown streams, as in
# future we want to be able to easily add more stream types.
Expand All @@ -435,6 +447,7 @@ def _get_and_clear_connection_positions_txn(
return PerConnectionStateDB(
rooms=RoomStatusMap(rooms),
receipts=RoomStatusMap(receipts),
account_data=RoomStatusMap(account_data),
room_configs=room_configs,
)

Expand All @@ -452,6 +465,7 @@ class PerConnectionStateDB:

rooms: "RoomStatusMap[str]"
receipts: "RoomStatusMap[str]"
account_data: "RoomStatusMap[str]"

room_configs: Mapping[str, "RoomSyncConfig"]

Expand Down Expand Up @@ -484,17 +498,29 @@ async def from_state(
for room_id, status in per_connection_state.receipts.get_updates().items()
}

account_data = {
room_id: HaveSentRoom(
status=status.status,
last_token=(
str(status.last_token) if status.last_token is not None else None
),
)
for room_id, status in per_connection_state.account_data.get_updates().items()
}

log_kv(
{
"rooms": rooms,
"receipts": receipts,
"account_data": account_data,
"room_configs": per_connection_state.room_configs.maps[0],
}
)

return PerConnectionStateDB(
rooms=RoomStatusMap(rooms),
receipts=RoomStatusMap(receipts),
account_data=RoomStatusMap(account_data),
room_configs=per_connection_state.room_configs.maps[0],
)

Expand Down Expand Up @@ -524,8 +550,19 @@ async def to_state(self, store: "DataStore") -> "PerConnectionState":
for room_id, status in self.receipts._statuses.items()
}

account_data = {
room_id: HaveSentRoom(
status=status.status,
last_token=(
int(status.last_token) if status.last_token is not None else None
),
)
for room_id, status in self.account_data._statuses.items()
}

return PerConnectionState(
rooms=RoomStatusMap(rooms),
receipts=RoomStatusMap(receipts),
account_data=RoomStatusMap(account_data),
room_configs=self.room_configs,
)
Loading

0 comments on commit a851f6b

Please sign in to comment.