Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Include whether the requesting user has participated in a thread. #11577

Merged
merged 6 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,10 @@ def _handle_event_relations(
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)
txn.call_after(
self.store.get_thread_participated.invalidate,
(parent_id, event.room_id),
)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
Expand Down
66 changes: 44 additions & 22 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,23 +382,21 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:

@cached()
async def get_thread_summary(
self, event_id: str, room_id: str, user_id: str
) -> Tuple[int, Optional[EventBase], bool]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies and the latest reply (if any) for the given event.

Args:
event_id: Summarize the thread related to this event ID.
room_id: The room the event belongs to.
user_id: The user requesting the summary.

Returns:
The number of items in the thread and the most recent response, if any.
"""

def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str], bool]:
) -> Tuple[int, Optional[str]]:
# Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events.
sql = """
Expand All @@ -416,7 +414,7 @@ def _get_thread_summary_txn(
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None, False
return 0, None

latest_event_id = row[0]

Expand All @@ -433,6 +431,37 @@ def _get_thread_summary_txn(
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = cast(Tuple[int], txn.fetchone())[0]

return count, latest_event_id

count, latest_event_id = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)

latest_event = None
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]

return count, latest_event

@cached()
async def get_thread_participated(
self, event_id: str, room_id: str, user_id: str
) -> bool:
"""Get whether the requesting user participated in a thread.

This is separate from get_thread_summary since that can be cached across
all users while this value is specific to the requeser.

Args:
event_id: The thread related to this event ID.
room_id: The room the event belongs to.
user_id: The user requesting the summary.

Returns:
True if the requesting user participated in the thread, otherwise false.
"""

def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
# Fetch whether the requester has participated or not.
sql = """
SELECT 1
Expand All @@ -446,20 +475,12 @@ def _get_thread_summary_txn(
"""

txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
participated = bool(txn.fetchone())

return count, latest_event_id, participated
return bool(txn.fetchone())

count, latest_event_id, participated = await self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)

latest_event = None
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]

return count, latest_event, participated

async def events_have_relations(
self,
parent_ids: List[str],
Expand Down Expand Up @@ -616,11 +637,12 @@ async def _get_bundled_aggregation_for_event(

# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
(
thread_count,
latest_thread_event,
participated,
) = await self.get_thread_summary(event_id, room_id, user_id)
thread_count, latest_thread_event = await self.get_thread_summary(
event_id, room_id
)
participated = await self.get_thread_participated(
event_id, room_id, user_id
)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
"latest_event": latest_thread_event,
Expand Down