Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Faster joins: filter out non local events when a room doesn't have its full state #14404

Merged
merged 9 commits into from
Nov 21, 2022
1 change: 1 addition & 0 deletions changelog.d/14404.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Faster joins: filter out non local events when a room doesn't have its full state.
1 change: 1 addition & 0 deletions synapse/federation/sender/per_destination_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ async def _catch_up_transmission_loop(self) -> None:
new_pdus = await filter_events_for_server(
self._storage_controllers,
self._destination,
self._server_name,
new_pdus,
redact=False,
)
Expand Down
7 changes: 4 additions & 3 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ async def _maybe_backfill_inner(
filtered_extremities = await filter_events_for_server(
self._storage_controllers,
self.server_name,
self.server_name,
events_to_check,
redact=False,
check_history_visibility_only=True,
Expand Down Expand Up @@ -1252,7 +1253,7 @@ async def on_backfill_request(
)

events = await filter_events_for_server(
self._storage_controllers, origin, events
self._storage_controllers, origin, self.server_name, events
)

return events
Expand Down Expand Up @@ -1283,7 +1284,7 @@ async def get_persisted_pdu(
await self._event_auth_handler.assert_host_in_room(event.room_id, origin)

events = await filter_events_for_server(
self._storage_controllers, origin, [event]
self._storage_controllers, origin, self.server_name, [event]
)
event = events[0]
return event
Expand All @@ -1309,7 +1310,7 @@ async def on_get_missing_events(
)

missing_events = await filter_events_for_server(
self._storage_controllers, origin, missing_events
self._storage_controllers, origin, self.server_name, missing_events
)

return missing_events
Expand Down
18 changes: 15 additions & 3 deletions synapse/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,8 @@ def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str:

async def filter_events_for_server(
storage: StorageControllers,
server_name: str,
target_server_name: str,
local_server_name: str,
events: List[EventBase],
redact: bool = True,
check_history_visibility_only: bool = False,
Expand Down Expand Up @@ -603,7 +604,7 @@ def check_event_is_visible(
# if the server is either in the room or has been invited
# into the room.
for ev in memberships.values():
assert get_domain_from_id(ev.state_key) == server_name
assert get_domain_from_id(ev.state_key) == target_server_name

memtype = ev.membership
if memtype == Membership.JOIN:
Expand Down Expand Up @@ -636,7 +637,7 @@ def check_event_is_visible(
if event_to_history_vis[e.event_id]
not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
],
server_name,
target_server_name,
)

to_return = []
Expand All @@ -645,6 +646,17 @@ def check_event_is_visible(
visible = check_event_is_visible(
event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
)

# Filter out non-local events when we are in the middle of a partial join,
# since our servers list can be out of date and we could leak events
# to servers not in the room anymore.
# This can also be true for local events but we consider it to be
# an acceptable risk in this case.
if e.origin != local_server_name and await storage.main.is_partial_state_room(
e.room_id
):
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
visible = False

if visible and not erased:
to_return.append(e)
elif redact:
Expand Down
10 changes: 5 additions & 5 deletions tests/test_visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_filtering(self) -> None:

filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "test_server", events_to_filter
self._storage_controllers, "test_server", "hs", events_to_filter
)
)

Expand All @@ -83,7 +83,7 @@ def test_filter_outlier(self) -> None:
self.assertEqual(
self.get_success(
filter_events_for_server(
self._storage_controllers, "remote_hs", [outlier]
self._storage_controllers, "remote_hs", "hs", [outlier]
)
),
[outlier],
Expand All @@ -94,7 +94,7 @@ def test_filter_outlier(self) -> None:

filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "remote_hs", [outlier, evt]
self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
)
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
Expand All @@ -106,7 +106,7 @@ def test_filter_outlier(self) -> None:
# be redacted)
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "other_server", [outlier, evt]
self._storage_controllers, "other_server", "local_hs", [outlier, evt]
)
)
self.assertEqual(filtered[0], outlier)
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_erased_user(self) -> None:
# ... and the filtering happens.
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "test_server", events_to_filter
self._storage_controllers, "test_server", "local_hs", events_to_filter
)
)

Expand Down