Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cache empty responses from /user/devices #11587

Merged
merged 20 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11587.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices.
10 changes: 9 additions & 1 deletion synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,16 @@ async def user_device_resync(
devices = []
ignore_devices = True
else:
sid = await self.store.get_device_list_last_stream_id_for_remote(user_id)
clokep marked this conversation as resolved.
Show resolved Hide resolved
cached_devices = await self.store.get_cached_devices_for_user(user_id)
if cached_devices == {d["device_id"]: d for d in devices}:

# If this is the first time we've queried this user, then `sid is None` and
# `cached_devices` will be empty. If the remote user has no devices (i.e.
# `devices` is empty), we should cache this fact. For this reason, we skip
# only if `sid is not None`.
clokep marked this conversation as resolved.
Show resolved Hide resolved
if sid is not None and cached_devices == {
d["device_id"]: d for d in devices
}:
logging.info(
"Skipping device list resync for %s, as our cache matches already",
user_id,
Expand Down
8 changes: 6 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def _get_all_device_list_changes_for_remotes(txn):
@cached(max_entries=10000)
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
) -> Optional[Any]:
) -> Optional[str]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
Expand All @@ -729,7 +729,9 @@ async def get_device_list_last_stream_id_for_remote(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Dict[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
Expand Down Expand Up @@ -1316,6 +1318,7 @@ def _update_remote_device_list_cache_entry_txn(
content: JsonDict,
stream_id: str,
) -> None:
"""Delete, update or insert a cache entry for this (user, device) pair."""
if content.get("deleted"):
self.db_pool.simple_delete_txn(
txn,
Expand Down Expand Up @@ -1375,6 +1378,7 @@ async def update_remote_device_list_cache(
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None:
"""Replace the list of cached devices for this user with the given list."""
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
Expand Down
92 changes: 92 additions & 0 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable
from unittest import mock
from unittest.mock import patch

from parameterized import parameterized
from signedjson import key as key, sign as sign

from twisted.internet import defer
Expand All @@ -23,6 +26,7 @@
from synapse.api.errors import Codes, SynapseError

from tests import unittest
from tests.test_utils import make_awaitable


class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -765,6 +769,8 @@ def test_query_devices_remote_sync(self):
remote_user_id = "@test:other"
local_user_id = "@test:test"

# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
)
Expand Down Expand Up @@ -831,3 +837,89 @@ def test_query_devices_remote_sync(self):
}
},
)

@parameterized.expand(
[
# The remote homeserver's response indicates that this user has 0/1/2 devices.
([],),
(["device_1"],),
(["device_1", "device_2"],),
]
)
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
"""Test that requests for all of a remote user's devices are cached.

We do this by asserting that only one call over federation was made, and that
the two queries to the local homeserver produce the same response.
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
request_body = {"device_keys": {remote_user_id: []}}

response_devices = [
{
"device_id": device_id,
"keys": {
"algorithms": ["dummy"],
"device_id": device_id,
"keys": {f"dummy:{device_id}": "dummy"},
"signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
"unsigned": {},
"user_id": "@test:other",
},
}
for device_id in device_ids
]

response_body = {
"devices": response_devices,
"user_id": remote_user_id,
"stream_id": 12345, # an integer, according to the spec
}

e2e_handler = self.hs.get_e2e_keys_handler()

# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
mock_get_rooms = patch.object(
self.store,
"get_rooms_for_user",
return_value=make_awaitable("some_room_id"),
)
mock_request = patch.object(
self.hs.get_federation_client(),
"query_user_devices",
return_value=make_awaitable(response_body),
)

with mock_get_rooms, mock_request as mocked_federation_request:
# Make the first query.
response_1 = self.get_success(
e2e_handler.query_devices(
request_body,
timeout=10,
from_user_id=local_user_id,
from_device_id="some_device_id",
)
)

# We should have made a federation request to do so.
mocked_federation_request.assert_called_once()

# Repeat the query.
response_2 = self.get_success(
e2e_handler.query_devices(
request_body,
timeout=10,
from_user_id=local_user_id,
from_device_id="some_device_id",
)
)

# We should not have made a second federation request.
mocked_federation_request.assert_called_once()
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

# The two requests to the local homeserver should be identical, and should
# not indicate any errors.
self.assertEqual(response_1, response_2)
self.assertEqual(response_1["failures"], {})