Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve tests for get_unread_push_actions_for_user_in_range #13893

Merged
merged 5 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13893.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
38 changes: 24 additions & 14 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,18 @@ def f(txn: LoggingTransaction) -> List[str]:

def _get_receipts_by_room_txn(
self, txn: LoggingTransaction, user_id: str
) -> List[Tuple[str, int]]:
) -> Dict[str, int]:
"""
Generate a map of room ID to the latest stream ordering that has been
read by the given user.

Args:
txn:
user_id: The user to fetch receipts for.

Returns:
A map of room ID to stream ordering for all rooms the user has a receipt in.
"""
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
Expand All @@ -580,7 +591,10 @@ def _get_receipts_by_room_txn(

args.extend((user_id,))
txn.execute(sql, args)
return cast(List[Tuple[str, int]], txn.fetchall())
return {
room_id: latest_stream_ordering
for room_id, latest_stream_ordering in txn.fetchall()
}

async def get_unread_push_actions_for_user_in_range_for_http(
self,
Expand All @@ -605,12 +619,10 @@ async def get_unread_push_actions_for_user_in_range_for_http(
The list will have between 0~limit entries.
"""

receipts_by_room = dict(
await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
),
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)

def get_push_actions_txn(
Expand Down Expand Up @@ -679,12 +691,10 @@ async def get_unread_push_actions_for_user_in_range_for_email(
The list will have between 0~limit entries.
"""

receipts_by_room = dict(
await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
),
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)

def get_push_actions_txn(
Expand Down
88 changes: 72 additions & 16 deletions tests/storage/test_event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

from twisted.test.proto_helpers import MemoryReactor

from synapse.rest import admin
Expand All @@ -22,8 +24,6 @@

from tests.unittest import HomeserverTestCase

USER_ID = "@user:example.com"


class EventPushActionsStoreTestCase(HomeserverTestCase):
servlets = [
Expand All @@ -38,21 +38,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
assert persist_events_store is not None
self.persist_events_store = persist_events_store

def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None:
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
def _create_users_and_room(self) -> Tuple[str, str, str, str, str]:
"""
Creates two users and a shared room.

def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None:
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)

def test_count_aggregation(self) -> None:
Returns:
Tuple of (user 1 ID, user 1 token, user 2 ID, user 2 token, room ID).
"""
# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")
Expand All @@ -65,6 +57,70 @@ def test_count_aggregation(self) -> None:
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)

return user_id, token, other_id, other_token, room_id

def test_get_unread_push_actions_for_user_in_range(self) -> None:
"""Test getting unread push actions for HTTP and email pushers."""
user_id, token, _, other_token, room_id = self._create_users_and_room()

# Create two events, one of which is a highlight.
self.helper.send_event(
room_id,
type="m.room.message",
content={"msgtype": "m.text", "body": "msg"},
tok=other_token,
)
event_id = self.helper.send_event(
room_id,
type="m.room.message",
content={"msgtype": "m.text", "body": user_id},
tok=other_token,
)["event_id"]

# Fetch unread actions for HTTP pushers.
http_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
user_id, 0, 1000, 20
)
)
self.assertEqual(2, len(http_actions))

# Fetch unread actions for email pushers.
email_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
user_id, 0, 1000, 20
)
)
self.assertEqual(2, len(email_actions))

# Send a receipt, which should clear any actions.
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
event_ids=[event_id],
thread_id=None,
data={},
)
)
http_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
user_id, 0, 1000, 20
)
)
self.assertEqual([], http_actions)
email_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
user_id, 0, 1000, 20
)
)
self.assertEqual([], email_actions)

def test_count_aggregation(self) -> None:
# Create a user to receive notifications and send receipts.
user_id, token, _, other_token, room_id = self._create_users_and_room()

last_event_id: str

def _assert_counts(noitf_count: int, highlight_count: int) -> None:
Expand Down