Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Aggregate unread notif count query for badge count calculation #14255

Merged
merged 22 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14255.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optimise push badge count calculations. Contributed by Nick @ Beeper (@fizzadar).
28 changes: 9 additions & 19 deletions synapse/push/push_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.util.async_helpers import concurrently_execute


async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
Expand All @@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -

badge = len(invites)

room_notifs = []

async def get_room_unread_count(room_id: str) -> None:
room_notifs.append(
await store.get_unread_event_push_actions_by_room_for_user(
room_id,
user_id,
)
)

await concurrently_execute(get_room_unread_count, joins, 10)

for notifs in room_notifs:
# Combine the counts from all the threads.
notify_count = notifs.main_timeline.notify_count + sum(
n.notify_count for n in notifs.threads.values()
)
room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
for room_id, notify_count in room_to_count.items():
# room_to_count may include rooms which the user has left,
# ignore those.
if room_id not in joins:
continue
Fizzadar marked this conversation as resolved.
Show resolved Hide resolved

if notify_count == 0:
continue
Expand All @@ -51,8 +39,10 @@ async def get_room_unread_count(room_id: str) -> None:
# return one badge count per conversation
badge += 1
else:
# increment the badge count by the number of unread messages in the room
# Increase badge by number of notifications in room
# NOTE: this includes threaded and unthreaded notifications.
badge += notify_count

return badge


Expand Down
149 changes: 149 additions & 0 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"""

import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
Expand All @@ -95,6 +96,7 @@
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
PostgresEngine,
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
Expand Down Expand Up @@ -463,6 +465,153 @@ def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:

return result

async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
"""Get the notification count by room for a user. Only considers notifications,
not highlight or unread counts, and threads are currently aggregated under their room.

This function is intentionally not cached because it is called to calculate the
unread badge for push notifications and thus the result is expected to change.

Note that this function assumes the user is a member of the room. Because
summary rows are not removed when a user leaves a room, the caller must
filter out those results from the result.
clokep marked this conversation as resolved.
Show resolved Hide resolved

Returns:
A map of room ID to notification counts for the given user.
"""
Fizzadar marked this conversation as resolved.
Show resolved Hide resolved
return await self.db_pool.runInteraction(
"get_unread_counts_by_room_for_user",
self._get_unread_counts_by_room_for_user_txn,
user_id,
)

def _get_unread_counts_by_room_for_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Dict[str, int]:
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
args.extend([user_id, user_id])

receipts_cte = f"""
WITH all_receipts AS (
SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE
{receipt_types_clause}
AND user_id = ?
GROUP BY room_id, thread_id
)
"""

receipts_joins = """
LEFT JOIN (
SELECT room_id, thread_id,
max_receipt_stream_ordering AS threaded_receipt_stream_ordering
FROM all_receipts
WHERE thread_id IS NOT NULL
) AS threaded_receipts USING (room_id, thread_id)
LEFT JOIN (
SELECT room_id, thread_id,
max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
FROM all_receipts
WHERE thread_id IS NULL
) AS unthreaded_receipts USING (room_id)
"""

# First get summary counts by room / thread for the user. We use the max receipt
# stream ordering of both threaded & unthreaded receipts to compare against the
# summary table.
#
# PostgreSQL and SQLite differ in comparing scalar numerics.
if isinstance(self.database_engine, PostgresEngine):
# GREATEST ignores NULLs.
max_clause = """GREATEST(
threaded_receipt_stream_ordering,
unthreaded_receipt_stream_ordering
)"""
else:
# MAX returns NULL if any are NULL, so COALESCE to 0 first.
max_clause = """MAX(
COALESCE(threaded_receipt_stream_ordering, 0),
COALESCE(unthreaded_receipt_stream_ordering, 0)
)"""

sql = f"""
{receipts_cte}
SELECT eps.room_id, eps.thread_id, notif_count
FROM event_push_summary AS eps
{receipts_joins}
WHERE user_id = ?
AND notif_count != 0
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
OR last_receipt_stream_ordering = {max_clause}
)
"""
txn.execute(sql, args)

seen_thread_ids = set()
room_to_count: Dict[str, int] = defaultdict(int)

for room_id, thread_id, notif_count in txn:
room_to_count[room_id] += notif_count
seen_thread_ids.add(thread_id)

# Now get any event push actions that haven't been rotated using the same OR
# join and filter by receipt and event push summary rotated up to stream ordering.
sql = f"""
{receipts_cte}
SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
FROM event_push_actions AS epa
{receipts_joins}
WHERE user_id = ?
AND epa.notif = 1
AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
GROUP BY epa.room_id, epa.thread_id
"""
txn.execute(sql, args)

for room_id, thread_id, notif_count in txn:
# Note: only count push actions we have valid summaries for with up to date receipt.
if thread_id not in seen_thread_ids:
continue
Comment on lines +581 to +583
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this skip any threads which take place wholly after the latest summary? (I'm unsure why we care about whether we've seen the thread or not -- I thought the SQL query was only returning things with valid summaries in the first place?)

Copy link
Contributor Author

@Fizzadar Fizzadar Nov 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The query after this one captures threads that only exist post-summary, this query is specifically to capture items in event_push_actions that do have valid summaries as well. This (all three queries in this method) was based off the thee single room calculations here:

# First we pull the counts from the summary table.
#
# We check that `last_receipt_stream_ordering` matches the stream ordering of the
# latest receipt for the thread (which may be either the unthreaded read receipt
# or the threaded read receipt).
#
# If it doesn't match then a new read receipt has arrived and we haven't yet
# updated the counts in `event_push_summary` to reflect that; in that case we
# simply ignore `event_push_summary` counts.
#
# We then do a manual count of all the rows in the `event_push_actions` table
# for any user/room/thread which did not have a valid summary found.
#
# If `last_receipt_stream_ordering` is null then that means it's up-to-date
# (as the row was written by an older version of Synapse that
# updated `event_push_summary` synchronously when persisting a new read
# receipt).
txn.execute(
f"""
SELECT notif_count, COALESCE(unread_count, 0), thread_id
FROM event_push_summary
LEFT JOIN (
SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE
user_id = ?
AND room_id = ?
AND stream_ordering > ?
AND {receipt_types_clause}
GROUP BY thread_id
) AS receipts USING (thread_id)
WHERE room_id = ? AND user_id = ?
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?))
OR last_receipt_stream_ordering = COALESCE(threaded_receipt_stream_ordering, ?)
) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0)
""",
(
user_id,
room_id,
unthreaded_receipt_stream_ordering,
*receipts_args,
room_id,
user_id,
unthreaded_receipt_stream_ordering,
unthreaded_receipt_stream_ordering,
),
)
summarised_threads = set()
for notif_count, unread_count, thread_id in txn:
summarised_threads.add(thread_id)
counts = _get_thread(thread_id)
counts.notify_count += notif_count
counts.unread_count += unread_count
# Next we need to count highlights, which aren't summarised
sql = f"""
SELECT COUNT(*), thread_id FROM event_push_actions
LEFT JOIN (
SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE
user_id = ?
AND room_id = ?
AND stream_ordering > ?
AND {receipt_types_clause}
GROUP BY thread_id
) AS receipts USING (thread_id)
WHERE user_id = ?
AND room_id = ?
AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)
AND highlight = 1
GROUP BY thread_id
"""
txn.execute(
sql,
(
user_id,
room_id,
unthreaded_receipt_stream_ordering,
*receipts_args,
user_id,
room_id,
unthreaded_receipt_stream_ordering,
),
)
for highlight_count, thread_id in txn:
_get_thread(thread_id).highlight_count += highlight_count
# For threads which were summarised we need to count actions since the last
# rotation.
thread_id_clause, thread_id_args = make_in_list_sql_clause(
self.database_engine, "thread_id", summarised_threads
)
# The (inclusive) event stream ordering that was previously summarised.
rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
retcol="stream_ordering",
)
unread_counts = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, rotated_upto_stream_ordering
)
for notif_count, unread_count, thread_id in unread_counts:
if thread_id not in summarised_threads:
continue
if thread_id == MAIN_TIMELINE:
counts.notify_count += notif_count
counts.unread_count += unread_count
elif thread_id in thread_counts:
thread_counts[thread_id].notify_count += notif_count
thread_counts[thread_id].unread_count += unread_count
else:
# Previous thread summaries of 0 are discarded above.
#
# TODO If empty summaries are deleted this can be removed.
thread_counts[thread_id] = NotifCounts(
notify_count=notif_count,
unread_count=unread_count,
highlight_count=0,
)
# Finally we need to count push actions that aren't included in the
# summary returned above. This might be due to recent events that haven't
# been summarised yet or the summary is out of date due to a recent read
# receipt.
sql = f"""
SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END),
COUNT(CASE WHEN unread = 1 THEN 1 END),
thread_id
FROM event_push_actions
LEFT JOIN (
SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE
user_id = ?
AND room_id = ?
AND stream_ordering > ?
AND {receipt_types_clause}
GROUP BY thread_id
) AS receipts USING (thread_id)
WHERE user_id = ?
AND room_id = ?
AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)
AND NOT {thread_id_clause}
GROUP BY thread_id
"""

room_to_count[room_id] += notif_count

thread_id_clause, thread_ids_args = make_in_list_sql_clause(
self.database_engine, "epa.thread_id", seen_thread_ids
)

# Finally re-check event_push_actions for any rooms not in the summary, ignoring
# the rotated up-to position. This handles the case where a read receipt has arrived
# but not been rotated meaning the summary table is out of date, so we go back to
# the push actions table.
sql = f"""
{receipts_cte}
SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
FROM event_push_actions AS epa
{receipts_joins}
WHERE user_id = ?
AND NOT {thread_id_clause}
AND epa.notif = 1
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
GROUP BY epa.room_id
"""

args.extend(thread_ids_args)
txn.execute(sql, args)

for room_id, notif_count in txn:
room_to_count[room_id] += notif_count

return room_to_count

@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
Expand Down
47 changes: 39 additions & 8 deletions tests/storage/test_event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_count_aggregation(self) -> None:

last_event_id: str

def _assert_counts(noitf_count: int, highlight_count: int) -> None:
def _assert_counts(notif_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
Expand All @@ -168,13 +168,22 @@ def _assert_counts(noitf_count: int, highlight_count: int) -> None:
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count,
notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
)
self.assertEqual(counts.threads, {})

aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(aggregate_counts[room_id], notif_count)

def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event(
room_id,
Expand Down Expand Up @@ -283,7 +292,7 @@ def test_count_aggregation_threads(self) -> None:
last_event_id: str

def _assert_counts(
noitf_count: int,
notif_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
Expand All @@ -299,7 +308,7 @@ def _assert_counts(
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count,
notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
Expand All @@ -318,6 +327,17 @@ def _assert_counts(
else:
self.assertEqual(counts.threads, {})

aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(
aggregate_counts[room_id], notif_count + thread_notif_count
)

def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
Expand Down Expand Up @@ -454,7 +474,7 @@ def test_count_aggregation_mixed(self) -> None:
last_event_id: str

def _assert_counts(
noitf_count: int,
notif_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
Expand All @@ -470,7 +490,7 @@ def _assert_counts(
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count,
notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
Expand All @@ -489,6 +509,17 @@ def _assert_counts(
else:
self.assertEqual(counts.threads, {})

aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(
aggregate_counts[room_id], notif_count + thread_notif_count
)

def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
Expand Down Expand Up @@ -646,7 +677,7 @@ def _create_event(type: str, content: JsonDict) -> str:
)
return result["event_id"]

def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
def _assert_counts(notif_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
Expand All @@ -658,7 +689,7 @@ def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count, unread_count=0, highlight_count=0
notify_count=notif_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:
Expand Down