From 17377a0d81af89b05e1a32d94443e343967cf511 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Mar 2023 15:27:55 -0400 Subject: [PATCH] Implement MSC3983 to proxy /keys/claim queries to appservices. --- changelog.d/15314.feature | 1 + synapse/config/experimental.py | 5 +++ synapse/handlers/appservice.py | 72 ++++++++++++++++++++++++++++++- synapse/handlers/e2e_keys.py | 25 +++++++++-- tests/handlers/test_e2e_keys.py | 76 ++++++++++++++++++++++++++++++++- 5 files changed, 174 insertions(+), 5 deletions(-) create mode 100644 changelog.d/15314.feature diff --git a/changelog.d/15314.feature b/changelog.d/15314.feature new file mode 100644 index 000000000000..68b289b0cc7c --- /dev/null +++ b/changelog.d/15314.feature @@ -0,0 +1 @@ +Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 99dcd27c7426..53e6fc2b54d6 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -74,6 +74,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: "msc3202_transaction_extensions", False ) + # MSC3983: Proxying OTK claim requests to exclusive ASes. + self.msc3983_appservice_otk_claims: bool = experimental.get( + "msc3983_appservice_otk_claims", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) # Synapse will always serve partial state responses to requests using the stable # query parameter `omit_members`. If this flag is set, Synapse will also serve diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index ec3ab968e925..51685a2edb70 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from prometheus_client import Counter @@ -829,3 +838,64 @@ async def _check_user_exists(self, user_id: str) -> bool: if unknown_user: return await self.query_user_exists(user_id) return True + + async def claim_e2e_one_time_keys( + self, query: Iterable[Tuple[str, str, str]] + ) -> Tuple[ + Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]] + ]: + """Claim one time keys from application services. + + Args: + query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A tuple of: + An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + + A copy of the input which has not been fulfilled (either because + they are not appservice users or the appservice does not support + providing OTKs). + """ + services = self.store.get_app_services() + + # Partition the users by appservice. + query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {} + missing = [] + for user_id, device, algorithm in query: + if not self.store.get_if_app_services_interested_in_user(user_id): + missing.append((user_id, device, algorithm)) + continue + + # Find the associated appservice. + for service in services: + if service.is_exclusive_user(user_id): + query_by_appservice.setdefault(service.id, []).append( + (user_id, device, algorithm) + ) + continue + + # Query each service in parallel. + results = await make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.claim_client_keys, + # We know this must be an app service. + self.store.get_app_service_by_id(service_id), # type: ignore[arg-type] + service_query, + ) + for service_id, service_query in query_by_appservice.items() + ], + consumeErrors=True, + ) + ) + + # Patch together the results. + claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = [] + for success, result in results: + if success: + claimed_keys.append(result[0]) + missing.extend(result[1]) + + return claimed_keys, missing diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 8d85e6856503..91f49c7e27f1 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple @@ -53,6 +52,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() + self._appservice_handler = hs.get_application_service_handler() self.is_mine = hs.is_mine self.clock = hs.get_clock() @@ -88,6 +88,10 @@ def __init__(self, hs: "HomeServer"): max_count=10, ) + self._query_appservices_for_otks = ( + hs.config.experimental.msc3983_appservice_otk_claims + ) + @trace @cancellable async def query_devices( @@ -548,7 +552,8 @@ async def claim_local_one_time_keys( """Claim one time keys for local users. 1. Attempt to claim OTKs from the database. - 2. Attempt to fetch fallback keys from the database. + 2. Ask application services if they provide OTKs. + 3. Attempt to fetch fallback keys from the database. Args: local_query: An iterable of tuples of (user ID, device ID, algorithm). @@ -559,10 +564,21 @@ async def claim_local_one_time_keys( otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) + # If the application services have not provided any keys via the C-S + # API, query it directly. + if self._query_appservices_for_otks: + # Query the appservices for any OTKs. + ( + appservice_results, + not_found, + ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) + else: + appservice_results = [] + # For any *still* remaining users, try fall-back keys. fallback_results = await self.store.claim_e2e_fallback_keys(not_found) - return (otk_results, fallback_results) + return (otk_results, *appservice_results, fallback_results) @trace async def claim_one_time_keys( @@ -593,6 +609,9 @@ async def claim_one_time_keys( for key_id, key in keys.items(): json_result.setdefault(user_id, {})[device_id] = {key_id: key} + # Remote failures. + failures: Dict[str, JsonDict] = {} + @trace async def claim_client_keys(destination: str) -> None: set_tag("destination", destination) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 6b4cba65d069..4ff04fc66bb7 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -23,18 +23,24 @@ from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.appservice import ApplicationService from synapse.handlers.device import DeviceHandler from synapse.server import HomeServer +from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.types import JsonDict from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable +from tests.unittest import override_config class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(federation_client=mock.Mock()) + self.appservice_api = mock.Mock() + return self.setup_test_homeserver( + federation_client=mock.Mock(), application_service_api=self.appservice_api + ) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() @@ -941,3 +947,71 @@ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> Non # The two requests to the local homeserver should be identical. self.assertEqual(response_1, response_2) + + @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}}) + def test_query_appservice(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id_1 = "xyz" + fallback_key = {"alg1:k1": "fallback_key1"} + device_id_2 = "abc" + otk = {"alg1:k2": "key2"} + + # Inject an appservice interested in this user. + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + self.hs.get_datastores().main.services_cache = [appservice] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [appservice] + ) + + # Setup a response, but only for device 2. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")]) + ) + + # we shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(res, []) + + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # claiming an OTK when no OTKs are available should ask the appservice, then + # query the fallback keys. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + { + "one_time_keys": { + local_user: {device_id_1: "alg1", device_id_2: "alg1"} + } + }, + timeout=None, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: fallback_key, device_id_2: otk} + }, + }, + )