Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor filter_events_for_server #15240

Merged
merged 8 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15240.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `filter_events_for_server`.
2 changes: 2 additions & 0 deletions synapse/federation/sender/per_destination_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,8 @@ async def _catch_up_transmission_loop(self) -> None:
self._server_name,
new_pdus,
redact=False,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)

# If we've filtered out all the extremities, fall back to
Expand Down
29 changes: 24 additions & 5 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,15 +392,16 @@ async def _maybe_backfill_inner(
get_prev_content=False,
)

# We set `check_history_visibility_only` as we might otherwise get false
# We unset `filter_out_erased_senders` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
self._storage_controllers,
self.server_name,
self.server_name,
events_to_check,
redact=False,
check_history_visibility_only=True,
filter_out_erased_senders=False,
filter_out_remote_partial_state_events=False,
)
if filtered_extremities:
extremities_to_request.append(bp.event_id)
Expand Down Expand Up @@ -1331,7 +1332,13 @@ async def on_backfill_request(
)

events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, events
self._storage_controllers,
origin,
self.server_name,
events,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)

return events
Expand Down Expand Up @@ -1362,7 +1369,13 @@ async def get_persisted_pdu(
await self._event_auth_handler.assert_host_in_room(event.room_id, origin)

events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, [event]
self._storage_controllers,
origin,
self.server_name,
[event],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
event = events[0]
return event
Expand Down Expand Up @@ -1390,7 +1403,13 @@ async def on_get_missing_events(
)

missing_events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, missing_events
self._storage_controllers,
origin,
self.server_name,
missing_events,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)

return missing_events
Expand Down
67 changes: 47 additions & 20 deletions synapse/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
# limitations under the License.
import logging
from enum import Enum, auto
from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
from typing import (
Collection,
Dict,
FrozenSet,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)

import attr
from typing_extensions import Final
Expand Down Expand Up @@ -565,29 +575,43 @@ async def filter_events_for_server(
storage: StorageControllers,
target_server_name: str,
local_server_name: str,
events: List[EventBase],
redact: bool = True,
check_history_visibility_only: bool = False,
events: Sequence[EventBase],
*,
redact: bool,
filter_out_erased_senders: bool,
filter_out_remote_partial_state_events: bool,
) -> List[EventBase]:
"""Filter a list of events based on whether given server is allowed to
"""Filter a list of events based on whether the target server is allowed to
see them.

For a fully stated room, the target server is allowed to see an event E if:
- the state at E has world readable or shared history vis, OR
- the state at E says that the target server is in the room.

For a partially stated room, the target server is allowed to see E if:
- E was created by this homeserver, AND:
- the partial state at E has world readable or shared history vis, OR
- the partial state at E says that the target server is in the room.

TODO: does "the state at E" mean the state before E or the state after E?

Args:
storage
server_name
target_server_name
local_server_name
events
redact: Whether to return a redacted version of the event, or
to filter them out entirely.
check_history_visibility_only: Whether to only check the
history visibility, rather than things like if the sender has been
redact: Controls what to do with events which have been filtered out.
If True, include their redacted forms; if False, omit them entirely.
filter_out_erased_senders: If True, also filter out events whose sender has been
erased. The backfill check during pagination passes
`filter_out_erased_senders=False` to avoid false positives from erased users.

filter_out_remote_partial_state_events: If True, also filter out events in
partial state rooms created by other homeservers.
Returns:
The filtered events.
"""

def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool:
if erased_senders and erased_senders[event.sender]:
logger.info("Sender of %s has been erased, redacting", event.event_id)
return True
Expand Down Expand Up @@ -616,7 +640,7 @@ def check_event_is_visible(
# server has no users in the room: redact
return False

if not check_history_visibility_only:
if filter_out_erased_senders:
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
else:
# We don't want to check whether users are erased, which is equivalent
Expand All @@ -631,15 +655,15 @@ def check_event_is_visible(
# otherwise a room could be fully joined after we retrieve those, which would then bypass
# this check but would base the filtering on an outdated view of the membership events.

partial_state_invisible_events = set()
if not check_history_visibility_only:
partial_state_invisible_event_ids: Set[str] = set()
if filter_out_remote_partial_state_events:
for e in events:
sender_domain = get_domain_from_id(e.sender)
if (
sender_domain != local_server_name
and await storage.main.is_partial_state_room(e.room_id)
):
partial_state_invisible_events.add(e)
partial_state_invisible_event_ids.add(e.event_id)

# Let's check to see if all the events have a history visibility
# of "shared" or "world_readable". If that's the case then we don't
Expand All @@ -658,17 +682,20 @@ def check_event_is_visible(
target_server_name,
)

to_return = []
for e in events:
def include_event_in_output(e: EventBase) -> bool:
erased = is_sender_erased(e, erased_senders)
visible = check_event_is_visible(
event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
)

if e in partial_state_invisible_events:
if e.event_id in partial_state_invisible_event_ids:
visible = False

if visible and not erased:
return visible and not erased

to_return = []
for e in events:
if include_event_in_output(e):
to_return.append(e)
elif redact:
to_return.append(prune_event(e))
Expand Down
40 changes: 35 additions & 5 deletions tests/test_visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ def test_filtering(self) -> None:

filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "test_server", "hs", events_to_filter
self._storage_controllers,
"test_server",
"hs",
events_to_filter,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
)

Expand All @@ -85,7 +91,13 @@ def test_filter_outlier(self) -> None:
self.assertEqual(
self.get_success(
filter_events_for_server(
self._storage_controllers, "remote_hs", "hs", [outlier]
self._storage_controllers,
"remote_hs",
"hs",
[outlier],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
),
[outlier],
Expand All @@ -96,7 +108,13 @@ def test_filter_outlier(self) -> None:

filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
self._storage_controllers,
"remote_hs",
"local_hs",
[outlier, evt],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
Expand All @@ -108,7 +126,13 @@ def test_filter_outlier(self) -> None:
# be redacted)
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "other_server", "local_hs", [outlier, evt]
self._storage_controllers,
"other_server",
"local_hs",
[outlier, evt],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
)
self.assertEqual(filtered[0], outlier)
Expand Down Expand Up @@ -143,7 +167,13 @@ def test_erased_user(self) -> None:
# ... and the filtering happens.
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers, "test_server", "local_hs", events_to_filter
self._storage_controllers,
"test_server",
"local_hs",
events_to_filter,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
)

Expand Down