Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add support for calling /keys/claim for appservices.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Mar 27, 2023
1 parent d8a0d93 commit 6cd7f9f
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 0 deletions.
56 changes: 56 additions & 0 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,62 @@ async def push_bulk(
failed_transactions_counter.labels(service.id).inc()
return False

async def claim_client_keys(
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from an application service.
Args:
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A tuple of:
A map of user ID -> a map device ID -> a map of key ID -> JSON dict.
A copy of the input which has not been fulfilled because the
appservice doesn't support this endpoint or has not returned
data for that tuple.
"""
if service.url is None:
return {}, query

# This is required by the configuration.
assert service.hs_token is not None

# Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {}
for user_id, device, algorithm in query:
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)

uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
try:
response = await self.post_json_get_json(
uri,
body,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
except CodeMessageException as e:
# The appservice doesn't support this endpoint.
if e.code == 404 or e.code == 405:
return {}, query
logger.warning("claim_keys to %s received %s", uri, e.code)
return {}, query
except Exception as ex:
logger.warning("claim_keys to %s threw exception %s", uri, ex)
return {}, query

# Check if the appservice fulfilled all of the queried user/device/algorithms
# or if some are still missing.
#
# TODO This places a lot of faith in the response shape being correct.
missing = [
(user_id, device, algorithm)
for user_id, device, algorithm in query
if algorithm not in response.get(user_id, {}).get(device, [])
]

return response, missing

def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
Expand Down
59 changes: 59 additions & 0 deletions tests/appservice/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,62 @@ async def get_json(
)
self.assertEqual(self.request_url, URL_LOCATION)
self.assertEqual(result, SUCCESS_RESULT_LOCATION)

def test_claim_keys(self) -> None:
"""
Tests that 3pe queries to the appservice are authenticated
with the appservice's token.
"""

RESPONSE: JsonDict = {
"@alice:example.org": {
"DEVICE_1": {
"signed_curve25519:AAAAHg": {
# We don't really care about the content of the keys,
# they get passed back transparently.
},
"signed_curve25519:BBBBHg": {},
},
"DEVICE_2": {"signed_curve25519:CCCCHg": {}},
},
}

async def post_json_get_json(
uri: str,
post_json: Any,
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> JsonDict:
# Ensure the access token is passed as both a header and query arg.
if not headers.get("Authorization"):
raise RuntimeError("Access token not provided")

self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
return RESPONSE

# We assign to a method, which mypy doesn't like.
self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment]

MISSING_KEYS = [
# Known user, known device, missing algorithm.
("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"),
# Known user, missing device.
("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"),
# Unknown user.
("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"),
]

claimed_keys, missing = self.get_success(
self.api.claim_client_keys(
self.service,
[
# Found devices
("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"),
("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"),
("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"),
]
+ MISSING_KEYS,
)
)

self.assertEqual(claimed_keys, RESPONSE)
self.assertEqual(missing, MISSING_KEYS)

0 comments on commit 6cd7f9f

Please sign in to comment.