
Add a thread relation type per MSC3440. #11088

Merged
merged 10 commits on Oct 21, 2021
1 change: 1 addition & 0 deletions changelog.d/11088.feature
@@ -0,0 +1 @@
Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
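
For context, MSC3440 expresses a thread through `m.relates_to` on the reply event. Below is a rough sketch of the content a client might send for a threaded reply; the body and parent event ID are placeholders, and the unstable `io.element.thread` identifier matches the constant added in this PR.

```python
# Hypothetical threaded-reply content per MSC3440 (unstable identifiers).
threaded_reply_content = {
    "msgtype": "m.text",
    "body": "A reply within the thread",
    "m.relates_to": {
        "rel_type": "io.element.thread",
        "event_id": "$thread_root_event_id",  # placeholder thread-root event ID
    },
}
```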
1 change: 1 addition & 0 deletions synapse/api/constants.py
@@ -176,6 +176,7 @@ class RelationTypes:
ANNOTATION = "m.annotation"
REPLACE = "m.replace"
REFERENCE = "m.reference"
THREAD = "io.element.thread"


class LimitBlockingTypes:
2 changes: 2 additions & 0 deletions synapse/config/experimental.py
@@ -26,6 +26,8 @@ def read_config(self, config: JsonDict, **kwargs):

# Whether to enable experimental MSC1849 (aka relations) support
self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True)
# MSC3440 (thread relation)
self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)

# MSC3026 (busy presence state)
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
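
For illustration, here is a minimal sketch of switching the flag on, assuming (as the surrounding code suggests) that `experimental` is the `experimental_features` section of the homeserver config; the same shape appears in the test override further down.

```python
# Hypothetical config snippet; mirrors the read above, not the full read_config.
config = {"experimental_features": {"msc3440_enabled": True}}

experimental = config.get("experimental_features") or {}  # assumed derivation
msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
assert msc3440_enabled  # thread relations will now be bundled by the serializer
```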
17 changes: 17 additions & 0 deletions synapse/events/utils.py
@@ -386,6 +386,7 @@ class EventClientSerializer:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self._msc1849_enabled = hs.config.experimental.msc1849_enabled
self._msc3440_enabled = hs.config.experimental.msc3440_enabled

async def serialize_event(
self,
@@ -462,6 +463,22 @@ async def serialize_event(
"sender": edit.sender,
}

# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
(
thread_count,
latest_thread_event,
) = await self.store.get_thread_summary(event_id)
if latest_thread_event:
r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": await self.serialize_event(
latest_thread_event, time_now, bundle_aggregations=False
),
"count": thread_count,
}

return serialized_event

async def serialize_events(
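
To make the bundling concrete, here is a rough sketch of the `unsigned` section a client would see on a thread root once `get_thread_summary` finds replies; the values are placeholders, not output from a real test.

```python
# Illustrative shape of the bundled thread summary added above.
serialized_unsigned = {
    "m.relations": {
        "io.element.thread": {  # RelationTypes.THREAD
            "latest_event": {
                # The most recent threaded reply, serialized without its own
                # bundled aggregations (bundle_aggregations=False above).
                "type": "m.room.message",
                "sender": "@alice:example.org",
                "content": {"body": "latest reply in the thread"},
            },
            "count": 2,
        }
    }
}
```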
3 changes: 2 additions & 1 deletion synapse/rest/client/relations.py
@@ -128,9 +128,10 @@ async def on_PUT_or_POST(

content["m.relates_to"] = {
"event_id": parent_id,
"key": aggregation_key,
"rel_type": relation_type,
}
if aggregation_key is not None:
content["m.relates_to"]["key"] = aggregation_key

event_dict = {
"type": event_type,
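
As a sketch of the effect of this change, the `m.relates_to` content now only carries a `key` when one was supplied (e.g. for annotations), instead of always including it, possibly as None:

```python
parent_id = "$parent_event_id"  # placeholder

# Annotation: an aggregation key is provided, so it is included.
relates_to_annotation = {
    "event_id": parent_id,
    "rel_type": "m.annotation",
    "key": "👍",
}

# Thread (or reference): no aggregation key, so the field is omitted
# rather than being sent as "key": None.
relates_to_thread = {
    "event_id": parent_id,
    "rel_type": "io.element.thread",
}
```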
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/events.py
@@ -1710,6 +1710,7 @@ def _handle_event_relations(self, txn, event):
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
RelationTypes.THREAD,

Member Author commented:

Originally I removed this check entirely, but I'm unsure whether that's the "right" thing to do or not.

If we removed this check, then unknown / unstable / experimental / custom relations would be stored in the database. This seems good since you can then query them using /relations, but they would not appear in bundled aggregations since they're not understood.

Additionally, if another field needs to be pulled out (e.g. if a relation of type blah has a foo key that needs to be stored in the database), then it wouldn't be added to the table. I think storing all relations would still be an improvement, as it would be easier to add support for blah via a database migration + a background task to fill a new foo column.

Member Author commented:

Also -- I didn't put this behind a configuration flag since I think we want to store it regardless (kind of in line with the last paragraph above).

Member Author commented:

@erikjohnston Any thoughts on whether we should allow any relation type here or not?

I think we might also need a background update to find any of these events which occurred before someone upgrades. 😢

Member commented:

I don't have a super strong opinion. I think I would prefer to only handle "known" types of relations, as otherwise we don't know how to handle them, but OTOH that makes the background updates take a lot longer 🤷

Member Author commented:

Fair enough! Well, we'll need to do a background update for this regardless, so I'm not going to worry about it.

):
# Unknown relation type
return
@@ -1740,6 +1741,9 @@ def _handle_event_relations(self, txn, event):
if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))

if rel_type == RelationTypes.THREAD:
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))

def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
Part of MSC2716.
59 changes: 58 additions & 1 deletion synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Optional
from typing import Optional, Tuple

import attr

@@ -269,6 +269,63 @@ def _get_applicable_edit_txn(txn):

return await self.get_event(edit_id, allow_none=True)

@cached()
async def get_thread_summary(
self, event_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.

Args:
event_id: The original event ID

Returns:
The number of items in the thread and the most recent response, if any.
"""

def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events?
sql = """
SELECT event_id
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""

txn.execute(sql, (event_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None

latest_event_id = row[0]

sql = """
SELECT COALESCE(COUNT(event_id), 0)
FROM event_relations
WHERE
relates_to_id = ?
AND relation_type = ?
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
count = txn.fetchone()[0]
Member Author commented on lines +307 to +315:

It came up in a discussion about whether this should consider event redactions or not.


return count, latest_event_id

count, latest_event_id = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)

latest_event = None
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True)

return count, latest_event

async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool:
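
A brief usage sketch (assumed store handle and helper name, not an actual Synapse call site) tying the cached summary to the invalidation added in events.py:

```python
from typing import Optional, Tuple

async def thread_summary(store, thread_root_id: str) -> Tuple[int, Optional[str]]:
    # First call runs the two queries above; subsequent calls are served from
    # the @cached() entry until a new THREAD relation for thread_root_id is
    # persisted, at which point _handle_event_relations invalidates the cache.
    count, latest_event = await store.get_thread_summary(thread_root_id)
    latest_event_id = latest_event.event_id if latest_event else None
    return count, latest_event_id
```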
40 changes: 34 additions & 6 deletions tests/rest/client/test_relations.py
@@ -101,10 +101,10 @@ def test_deny_double_react(self):

def test_basic_paginate_relations(self):
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
Member Author commented:

Annotations require a key to be provided.

self.assertEquals(200, channel.code, channel.json_body)

channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
self.assertEquals(200, channel.code, channel.json_body)
annotation_id = channel.json_body["event_id"]

@@ -141,8 +141,10 @@ def test_repeated_paginate_relations(self):
"""

expected_event_ids = []
for _ in range(10):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
for idx in range(10):
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
)
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])

@@ -386,8 +388,9 @@ def test_aggregation_must_be_annotation(self):
)
self.assertEquals(400, channel.code, channel.json_body)

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_aggregation_get_event(self):
"""Test that annotations and references get correctly bundled when
"""Test that annotations, references, and threads get correctly bundled when
getting the parent event.
"""

@@ -410,6 +413,13 @@ def test_aggregation_get_event(self):
self.assertEquals(200, channel.code, channel.json_body)
reply_2 = channel.json_body["event_id"]

channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)

channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"]

channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
@@ -429,6 +439,25 @@ def test_aggregation_get_event(self):
RelationTypes.REFERENCE: {
"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
},
RelationTypes.THREAD: {
"count": 2,
"latest_event": {
"age": 100,
"content": {
"m.relates_to": {
"event_id": self.parent_id,
"rel_type": RelationTypes.THREAD,
}
},
"event_id": thread_2,
"origin_server_ts": 1600,
"room_id": self.room,
"sender": self.user_id,
"type": "m.room.test",
"unsigned": {"age": 100},
"user_id": self.user_id,
},
Member Author commented on lines +444 to +459:

I debated slimming down the fields we're checking against; most of them don't really matter here.

},
},
)

@@ -559,7 +588,6 @@ def test_edit_reply(self):
{
"m.relates_to": {
"event_id": self.parent_id,
"key": None,
"rel_type": "m.reference",
}
},