Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Do not allow cross-room relations, per MSC2674. (#11516)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Dec 9, 2021
1 parent 0cc3bf9 commit 3b88722
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 17 deletions.
1 change: 1 addition & 0 deletions changelog.d/11516.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event.
11 changes: 7 additions & 4 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,23 +454,26 @@ async def _injected_bundled_aggregations(
return

event_id = event.event_id
room_id = event.room_id

# The bundled aggregations to include.
aggregations = {}

annotations = await self.store.get_aggregation_groups_for_event(event_id)
annotations = await self.store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()

references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict()

edit = None
if event.type == EventTypes.Message:
edit = await self.store.get_applicable_edit(event_id)
edit = await self.store.get_applicable_edit(event_id, room_id)

if edit:
# If there is an edit replace the content, preserving existing
Expand Down Expand Up @@ -503,7 +506,7 @@ async def _injected_bundled_aggregations(
(
thread_count,
latest_thread_event,
) = await self.store.get_thread_summary(event_id)
) = await self.store.get_thread_summary(event_id, room_id)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
Expand Down
7 changes: 6 additions & 1 deletion synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ async def on_GET(

pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
limit=limit,
Expand Down Expand Up @@ -317,6 +318,7 @@ async def on_GET(

pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
room_id=room_id,
event_type=event_type,
limit=limit,
from_token=from_token,
Expand Down Expand Up @@ -383,7 +385,9 @@ async def on_GET(

# This checks that a) the event exists and b) the user is allowed to
# view it.
await self.event_handler.get_event(requester.user, room_id, parent_id)
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")

if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
Expand All @@ -402,6 +406,7 @@ async def on_GET(

result = await self.store.get_relations_for_event(
event_id=parent_id,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=key,
Expand Down
8 changes: 6 additions & 2 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,10 +1780,14 @@ def _handle_event_relations(
)

if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
txn.call_after(
self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
)

if rel_type == RelationTypes.THREAD:
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)

def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
Expand Down
36 changes: 26 additions & 10 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_relations_for_event(
self,
event_id: str,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
Expand All @@ -49,6 +50,7 @@ async def get_relations_for_event(
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
Expand All @@ -63,8 +65,8 @@ async def get_relations_for_event(
the form `{"event_id": "..."}`.
"""

where_clause = ["relates_to_id = ?"]
where_args: List[Union[str, int]] = [event_id]
where_clause = ["relates_to_id = ?", "room_id = ?"]
where_args: List[Union[str, int]] = [event_id, room_id]

if relation_type is not None:
where_clause.append("relation_type = ?")
Expand Down Expand Up @@ -199,6 +201,7 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool:
async def get_aggregation_groups_for_event(
self,
event_id: str,
room_id: str,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
Expand All @@ -213,6 +216,7 @@ async def get_aggregation_groups_for_event(
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or
Expand All @@ -225,8 +229,12 @@ async def get_aggregation_groups_for_event(
`type`, `key` and `count` fields.
"""

where_clause = ["relates_to_id = ?", "relation_type = ?"]
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
where_args: List[Union[str, int]] = [
event_id,
room_id,
RelationTypes.ANNOTATION,
]

if event_type:
where_clause.append("type = ?")
Expand Down Expand Up @@ -288,14 +296,17 @@ def _get_aggregation_groups_for_event_txn(
)

@cached()
async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
async def get_applicable_edit(
self, event_id: str, room_id: str
) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given
event.
Correctly handles checking whether edits were allowed to happen.
Args:
event_id: The original event ID
room_id: The original event's room ID
Returns:
The most recent edit, if any.
Expand All @@ -317,13 +328,14 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
WHERE
relates_to_id = ?
AND relation_type = ?
AND edit.room_id = ?
AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""

def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE))
txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
row = txn.fetchone()
if row:
return row[0]
Expand All @@ -340,13 +352,14 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:

@cached()
async def get_thread_summary(
self, event_id: str
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.
Args:
event_id: The original event ID
event_id: Summarize the thread related to this event ID.
room_id: The room the event belongs to.
Returns:
The number of items in the thread and the most recent response, if any.
Expand All @@ -363,12 +376,13 @@ def _get_thread_summary_txn(
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND room_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""

txn.execute(sql, (event_id, RelationTypes.THREAD))
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
Expand All @@ -378,11 +392,13 @@ def _get_thread_summary_txn(
sql = """
SELECT COALESCE(COUNT(event_id), 0)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND room_id = ?
AND relation_type = ?
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = txn.fetchone()[0] # type: ignore[index]

return count, latest_event_id
Expand Down
115 changes: 115 additions & 0 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
import itertools
import urllib.parse
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch

from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync

from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event


class RelationsTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -651,6 +654,118 @@ def test_aggregation_get_event_for_thread(self):
},
)

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_ignore_invalid_room(self):
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
res = self.helper.send(room2, body="Hi!", tok=self.user_token)
parent_id = res["event_id"]

# Disable the validation to pretend this came over federation.
with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation",
new=lambda self, event: make_awaitable(None),
):
# Generate a various relations from a different room.
self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.reaction",
sender=self.user_id,
content={
"m.relates_to": {
"rel_type": RelationTypes.ANNOTATION,
"event_id": parent_id,
"key": "A",
}
},
)
)

self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": RelationTypes.REFERENCE,
"event_id": parent_id,
},
},
)
)

self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": parent_id,
},
},
)
)

self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"new_content": {
"body": "new content",
"msgtype": "m.text",
},
"m.relates_to": {
"rel_type": RelationTypes.REPLACE,
"event_id": parent_id,
},
},
)
)

# They should be ignored when fetching relations.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])

# And when fetching aggregations.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])

# And for bundled aggregations.
channel = self.make_request(
"GET",
f"/rooms/{room2}/event/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])

def test_edit(self):
"""Test that a simple edit works."""

Expand Down

0 comments on commit 3b88722

Please sign in to comment.