Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cache empty responses from /user/devices #11587

Merged
merged 20 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11587.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where Synapse wouldn't cache the fact that a remote user has no devices.
8 changes: 6 additions & 2 deletions synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,13 @@ async def _check_for_unknown_devices(
)
return

# If we are tracking check that we know about the sending
# devices.
# If we are tracking, check that we know about the sending devices.
cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
if cached_devices is None:
# This means we've never requested details of the remote user's devices.
# Odd, given that we're processing an update for the devices---but carry on
# as if we've not heard of the device.
cached_devices = {}

unknown_devices = requesting_device_ids - set(cached_devices)
if unknown_devices:
Expand Down
2 changes: 2 additions & 0 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,8 @@ async def _process_received_pdu(
sender_key = event.content.get("sender_key")

cached_devices = await self._store.get_cached_devices_for_user(event.sender)
if cached_devices is None:
cached_devices = {}
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

resync = False # Whether we should resync device lists.

Expand Down
64 changes: 53 additions & 11 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ async def get_user_devices_from_cache(
"""Get the devices (and keys if any) for remote users from the cache.

Args:
query_list: List of (user_id, device_ids), if device_ids is
query_list: List of (user_id, device_ids) pairs. If device_ids is
falsey then return all device ids for that user.

Returns:
Expand Down Expand Up @@ -555,7 +555,10 @@ async def get_user_devices_from_cache(
device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)
user_devices = await self.get_cached_devices_for_user(user_id)
if user_devices is None:
user_devices = {}
results[user_id] = user_devices

set_tag("in_cache", results)
set_tag("not_in_cache", user_ids_not_in_cache)
Expand All @@ -573,16 +576,53 @@ async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDic
return db_to_json(content)

@cached()
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
async def get_cached_devices_for_user(
self, user_id: str
) -> Optional[Dict[str, JsonDict]]:
"""Retrieve the most recent cached devices data.

We can be in three states, depending on the latest response from the remote
homeserver.

- We could never requested devices for this user. In this case, return `None`.
- We could have requested devices for this user, only to be told they don't
have any devices. In this case, return an empty dictionary.
- Otherwise, we've cached details of 1 or more devices for this user. Return
a a dictionary from device id to the device data.
"""
return await self.db_pool.runInteraction(
"get_cached_devices_for_user",
self._get_cached_devices_for_user_txn,
user_id,
)
return {
device["device_id"]: db_to_json(device["content"]) for device in devices
}

def _get_cached_devices_for_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Optional[Dict[str, JsonDict]]:
# Four cases:
# 1. No stream id, no cached devices. Query yields no rows. Return None.
# 2. No stream id, >= 1 cached devices. Invalid state. Query will yield no rows.
# return None.
# 3. Stream id, no cached devices. Return empty dict. Query returns one row
# (non-NULL stream_id, NULL, NULL). Return empty dict.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# 4. Stream id, >= 1 cached devices. Query return 1 or more row
# (non-NULL stream_id, non-NULL device_id, non-NULL content). Return dict.
query = """
SELECT stream_id, device_id, content
FROM device_lists_remote_extremeties
clokep marked this conversation as resolved.
Show resolved Hide resolved
LEFT JOIN device_lists_remote_cache USING(user_id)
WHERE user_id = ?
"""
txn.execute(query, (user_id,))
devices: List[Tuple[str, Optional[str], Optional[str]]] = txn.fetchall()
if not devices:
return None
elif devices[0][1] is None:
return {}
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
else:
return {
device_id: db_to_json(content) for (_, device_id, content) in devices
}

async def get_users_whose_devices_changed(
self, from_key: int, user_ids: Iterable[str]
Expand Down Expand Up @@ -1316,6 +1356,7 @@ def _update_remote_device_list_cache_entry_txn(
content: JsonDict,
stream_id: str,
) -> None:
"""Delete, update or insert a cache entry for this (user, device) pair."""
if content.get("deleted"):
self.db_pool.simple_delete_txn(
txn,
Expand Down Expand Up @@ -1375,6 +1416,7 @@ async def update_remote_device_list_cache(
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None:
"""Replace all cached devices for this user with the given list of devices."""
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
Expand Down
65 changes: 65 additions & 0 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from unittest import mock
from unittest.mock import patch

from parameterized import parameterized
from signedjson import key as key, sign as sign

from twisted.internet import defer

from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from synapse.types import JsonDict

from tests import unittest

Expand Down Expand Up @@ -765,6 +769,8 @@ def test_query_devices_remote_sync(self):
remote_user_id = "@test:other"
local_user_id = "@test:test"

# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
)
Expand Down Expand Up @@ -831,3 +837,62 @@ def test_query_devices_remote_sync(self):
}
},
)

@parameterized.expand(
[([],), ([{"device_id": "device_1"}, {"device_id": "device_2"}],)]
clokep marked this conversation as resolved.
Show resolved Hide resolved
)
def test_query_all_devices_caches_result(self, response_devices: List[JsonDict]):
"""Test that requests for all of a remote user's devices are cached.

We do this by asserting that only one call over federation was made.
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
request_body = {"device_keys": {remote_user_id: []}}
response_body = {
"devices": response_devices,
"user_id": remote_user_id,
"stream_id": "remote_stream_id_1234",
}

e2e_handler = self.hs.get_e2e_keys_handler()

# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
mock_get_rooms = patch.object(
self.store,
"get_rooms_for_user",
return_value=defer.succeed(["some_room_id"]),
)
mock_request = patch.object(
self.hs.get_federation_client(),
"query_user_devices",
return_value=defer.succeed(response_body),
)

with mock_get_rooms, mock_request as mocked_federation_request:
# Make the first query.
self.get_success(
e2e_handler.query_devices(
request_body,
timeout=10,
from_user_id=local_user_id,
from_device_id="some_device_id",
)
)

# We should have made a federation request to do so.
mocked_federation_request.assert_called_once()

# Repeat the query.
self.get_success(
e2e_handler.query_devices(
request_body,
timeout=10,
from_user_id=local_user_id,
from_device_id="some_device_id",
)
)

# We should not have made a second federation request.
mocked_federation_request.assert_called_once()
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved