Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Clarify that a method returns only unthreaded receipts #13937

Merged
merged 4 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13937.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
12 changes: 3 additions & 9 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,11 @@ def _get_unread_counts_by_receipt_txn(
user_id: str,
) -> NotifCounts:
# Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_receipt_for_user_txn(
result = self.get_last_unthreaded_receipt_for_user_txn(
txn,
user_id,
room_id,
receipt_types=(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
),
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)

if result:
Expand Down Expand Up @@ -574,10 +571,7 @@ def _get_receipts_by_room_txn(
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
),
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)

sql = f"""
Expand Down
36 changes: 5 additions & 31 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,48 +135,21 @@ def get_max_receipt_stream_id(self) -> int:
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()

async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_types: Collection[str]
) -> Optional[str]:
"""
Fetch the event ID for the latest receipt in a room with one of the given receipt types.

Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
receipt_type: The receipt types to fetch.

Returns:
The latest receipt, if one exists.
"""
result = await self.db_pool.runInteraction(
"get_last_receipt_event_id_for_user",
self.get_last_receipt_for_user_txn,
user_id,
room_id,
receipt_types,
)
if not result:
return None

event_id, _ = result
return event_id

def get_last_receipt_for_user_txn(
def get_last_unthreaded_receipt_for_user_txn(
self,
txn: LoggingTransaction,
user_id: str,
room_id: str,
receipt_types: Collection[str],
) -> Optional[Tuple[str, int]]:
"""
Fetch the event ID and stream_ordering for the latest receipt in a room
with one of the given receipt types.
Fetch the event ID and stream_ordering for the latest unthreaded receipt
in a room with one of the given receipt types.

Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
receipt_type: The receipt types to fetch.
receipt_types: The receipt types to fetch.

Returns:
The event ID and stream ordering of the latest receipt, if one exists.
Expand All @@ -193,6 +166,7 @@ def get_last_receipt_for_user_txn(
WHERE {clause}
AND user_id = ?
AND room_id = ?
AND thread_id IS NULL
ORDER BY stream_ordering DESC
LIMIT 1
"""
Expand Down
74 changes: 38 additions & 36 deletions tests/storage/test_receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Collection, Optional

from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester
Expand Down Expand Up @@ -84,6 +85,33 @@ def prepare(self, reactor, clock, homeserver) -> None:
)
)

def get_last_unthreaded_receipt(
self, receipt_types: Collection[str], room_id: Optional[str] = None
) -> Optional[str]:
"""
Fetch the event ID for the latest unthreaded receipt in the test room for the test user.

Args:
receipt_types: The receipt types to fetch.

Returns:
The latest receipt, if one exists.
"""
result = self.get_success(
self.store.db_pool.runInteraction(
"get_last_receipt_event_id_for_user",
self.store.get_last_unthreaded_receipt_for_user_txn,
OUR_USER_ID,
room_id or self.room_id1,
receipt_types,
)
)
if not result:
return None

event_id, _ = result
return event_id

def test_return_empty_with_no_data(self) -> None:
res = self.get_success(
self.store.get_receipts_for_user(
Expand All @@ -107,16 +135,10 @@ def test_return_empty_with_no_data(self) -> None:
)
self.assertEqual(res, {})

res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)

self.assertEqual(res, None)

def test_get_receipts_for_user(self) -> None:
Expand Down Expand Up @@ -228,29 +250,17 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
)

# Test we get the latest event when we want both private and public receipts
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
self.assertEqual(res, event1_2_id)

# Test we get the older event when we want only public receipt
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_1_id)

# Test we get the latest event when we want only the private receipt
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
self.assertEqual(res, event1_2_id)

# Test receipt updating
Expand All @@ -259,11 +269,7 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_2_id)

# Send some events into the second room
Expand All @@ -282,11 +288,7 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
{},
)
)
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id2,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
)
self.assertEqual(res, event2_1_id)