Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implements part of MSC 3944 by dropping cancelled&duplicated `m.room_…
Browse files Browse the repository at this point in the history
…key_request`
  • Loading branch information
Mathieu Velten committed Jun 27, 2023
1 parent 14c1bfd commit e25c15e
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 8 deletions.
1 change: 1 addition & 0 deletions changelog.d/15842.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implements bullets 1 and 2 of [MSC 3944](https://github.com/matrix-org/matrix-spec-proposals/pull/3944) related to dropping cancelled and duplicated `m.room_key_request`.
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.msc4010_push_rules_account_data = experimental.get(
"msc4010_push_rules_account_data", False
)

# MSC3944: Dropping stale send-to-device messages
self.msc3944_enabled: bool = experimental.get("msc3944_enabled", False)
57 changes: 50 additions & 7 deletions synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from typing import TYPE_CHECKING, Any, Dict

Expand Down Expand Up @@ -90,6 +91,8 @@ def __init__(self, hs: "HomeServer"):
burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
)

self._msc3944_enabled = hs.config.experimental.msc3944_enabled

async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
"""
Handle receiving to-device messages from remote homeservers.
Expand Down Expand Up @@ -220,7 +223,7 @@ async def send_device_message(

set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
local_messages = {}
local_messages: Dict[str, Dict[str, JsonDict]] = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
# add an opentracing log entry for each message
Expand Down Expand Up @@ -255,16 +258,56 @@ async def send_device_message(

# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
device_id: {
for device_id, message_content in by_device.items():
# Drop any previous identical (same request_id and requesting_device_id)
# room_key_request, ignoring the action property when comparing.
# This handles dropping previous identical and cancelled requests.
if (
self._msc3944_enabled
and message_type == ToDeviceEventTypes.RoomKeyRequest
and user_id == sender_user_id
):
req_id = message_content.get("request_id")
requesting_device_id = message_content.get(
"requesting_device_id"
)
if req_id and requesting_device_id:
previous_request_deleted = False
for (
stream_id,
message_json,
) in await self.store.get_all_device_messages(
user_id, device_id
):
orig_message = json.loads(message_json)
if (
orig_message["type"]
== ToDeviceEventTypes.RoomKeyRequest
):
content = orig_message.get("content", {})
if (
content.get("request_id") == req_id
and content.get("requesting_device_id")
== requesting_device_id
):
if await self.store.delete_device_message(
stream_id
):
previous_request_deleted = True

if (
message_content.get("action") == "request_cancellation"
and previous_request_deleted
):
# Do not store the cancellation since we deleted the matching
# request(s) before it reaches the device.
continue
message = {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
local_messages.setdefault(user_id, {})[device_id] = message
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
Expand Down
41 changes: 41 additions & 0 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

from synapse.api.constants import EventContentFields
from synapse.api.errors import StoreError
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import (
SynapseTags,
Expand Down Expand Up @@ -891,6 +892,46 @@ def _add_messages_to_local_device_inbox_txn(
],
)

async def delete_device_message(self, stream_id: int) -> bool:
"""Delete a specific device message from the message inbox.
Args:
stream_id: the stream ID identifying the message.
Returns:
True if the message has been deleted, False if it didn't exist.
"""
try:
await self.db_pool.simple_delete_one(
"device_inbox",
keyvalues={"stream_id": stream_id},
desc="delete_device_message",
)
except StoreError:
# Deletion failed because device message does not exist
return False
return True

async def get_all_device_messages(
self,
user_id: str,
device_id: str,
) -> List[Tuple[int, str]]:
"""Get all device messages in the inbox from a specific device.
Args:
user_id: the user ID of the device we want to query.
device_id: the device ID of the device we want to query.
Returns:
A list of (stream ID, message content) tuples.
"""
rows = await self.db_pool.simple_select_list(
table="device_inbox",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("stream_id", "message_json"),
desc="get_all_device_messages",
)
return [(r["stream_id"], r["message_json"]) for r in rows]


class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
Expand Down
121 changes: 120 additions & 1 deletion tests/handlers/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@

from twisted.test.proto_helpers import MemoryReactor

import synapse
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict
from synapse.types import JsonDict, create_requester
from synapse.util import Clock

from tests import unittest
Expand All @@ -37,6 +38,11 @@


class DeviceTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
]

def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock()
hs = self.setup_test_homeserver(
Expand All @@ -47,6 +53,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.msg_handler = hs.get_device_message_handler()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main
return hs

Expand Down Expand Up @@ -398,6 +406,117 @@ def test_on_federation_query_user_devices_appservice(self) -> None:
],
)

@override_config({"experimental_features": {"msc3944_enabled": True}})
def test_duplicated_and_cancelled_room_key_request(self) -> None:
myuser = self.register_user("myuser", "pass")
self.login("myuser", "pass", "device")
self.login("myuser", "pass", "device2")
self.login("myuser", "pass", "device3")

requester = requester = create_requester(myuser)

from_token = self.event_sources.get_current_token()

# This room_key_request is for device3 and should not be deleted.
self.get_success(
self.msg_handler.send_device_message(
requester,
"m.room_key_request",
{
myuser: {
"device3": {
"action": "request",
"request_id": "request_id",
"requesting_device_id": "device",
}
}
},
)
)

for _ in range(0, 2):
self.get_success(
self.msg_handler.send_device_message(
requester,
"m.room_key_request",
{
myuser: {
"device2": {
"action": "request",
"request_id": "request_id",
"requesting_device_id": "device",
}
}
},
)
)

to_token = self.event_sources.get_current_token()

# Test that if we queue 2 identical room_key_request,
# only one is delivered to the device.
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device2",
from_token.to_device_key,
to_token.to_device_key,
)
)
self.assertEqual(len(res[0]), 1)

# room_key_request for device3 should still be around.
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device3",
from_token.to_device_key,
to_token.to_device_key,
)
)
self.assertEqual(len(res[0]), 1)

self.get_success(
self.msg_handler.send_device_message(
requester,
"m.room_key_request",
{
myuser: {
"device2": {
"action": "request_cancellation",
"request_id": "request_id",
"requesting_device_id": "device",
}
}
},
)
)

to_token = self.event_sources.get_current_token()

# Test that if we cancel a room_key_request, both previous matching
# requests and the cancelled request are not delivered to the device.
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device2",
from_token.to_device_key,
to_token.to_device_key,
)
)
self.assertEqual(len(res[0]), 0)

# room_key_request for device3 should still be around.
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device3",
from_token.to_device_key,
to_token.to_device_key,
)
)
self.assertEqual(len(res[0]), 1)


class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
Expand Down

0 comments on commit e25c15e

Please sign in to comment.