Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make _get_e2e_device_keys_and_signatures_txn return an attrs (#8224)
Browse files Browse the repository at this point in the history
this makes it a bit clearer what's going on.
  • Loading branch information
richvdh authored Sep 2, 2020
1 parent b939251 commit abeab96
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
1 change: 1 addition & 0 deletions changelog.d/8224.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,17 +293,17 @@ async def _get_device_update_edus_by_remote(
prev_id = stream_id

if device is not None:
key_json = device.get("key_json", None)
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)

if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)

device_display_name = device.get("device_display_name", None)
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
Expand Down
52 changes: 36 additions & 16 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json

from twisted.enterprise.adbapi import Connection
Expand All @@ -33,6 +34,21 @@
from synapse.handlers.e2e_keys import SignatureListItem


@attr.s
class DeviceKeyLookupResult:
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""

display_name = attr.ib(type=Optional[str])

# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])

# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)


class EndToEndKeyWorkerStore(SQLBaseStore):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
Expand Down Expand Up @@ -61,17 +77,17 @@ def _get_e2e_device_keys_for_federation_query_txn(
for device_id, device in user_devices.items():
result = {"device_id": device_id}

key_json = device.get("key_json", None)
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)

if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)

device_display_name = device.get("device_display_name", None)
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name

Expand Down Expand Up @@ -109,13 +125,13 @@ async def get_e2e_device_keys_for_cs_api(
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = db_to_json(device_info.pop("key_json"))
r = db_to_json(device_info.key_json)
r["unsigned"] = {}
display_name = device_info["device_display_name"]
display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
if "signatures" in device_info:
for sig_user_id, sigs in device_info["signatures"].items():
if device_info.signatures:
for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
Expand All @@ -126,7 +142,7 @@ async def get_e2e_device_keys_for_cs_api(
@trace
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[Dict]]]:
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)

Expand Down Expand Up @@ -161,7 +177,7 @@ def _get_e2e_device_keys_and_signatures_txn(

sql = (
"SELECT user_id, device_id, "
" d.display_name AS device_display_name, "
" d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
Expand All @@ -172,13 +188,14 @@ def _get_e2e_device_keys_and_signatures_txn(
)

txn.execute(sql, query_params)
rows = self.db_pool.cursor_to_dict(txn)

result = {}
for row in rows:
result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, key_json
)

if include_deleted_devices:
for user_id, device_id in deleted_devices:
Expand Down Expand Up @@ -209,7 +226,10 @@ def _get_e2e_device_keys_and_signatures_txn(
# note that target_device_result will be None for deleted devices.
continue

target_device_signatures = target_device_result.setdefault("signatures", {})
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}

signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
Expand Down

0 comments on commit abeab96

Please sign in to comment.