Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Move and rename get_devices_with_keys_by_user #8204

Merged
merged 6 commits into from
Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8204.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
4 changes: 3 additions & 1 deletion synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ async def get_user_ids_changed(self, user_id, from_token):
return result

async def on_federation_query_user_devices(self, user_id):
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
user_id
)
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
Expand Down
3 changes: 3 additions & 0 deletions synapse/replication/slave/storage/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
"DeviceListFederationStreamChangeCache", device_list_max
)

def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
Expand Down
3 changes: 3 additions & 0 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()

def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()

def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
Expand Down
52 changes: 5 additions & 47 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

Expand Down Expand Up @@ -101,7 +102,7 @@ async def get_device_updates_by_remote(
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
now_stream_id = self._device_list_id_gen.get_current_token()
now_stream_id = self.get_device_stream_token()

has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
Expand Down Expand Up @@ -412,8 +413,10 @@ def _add_user_signature_change_txn(
},
)

@abc.abstractmethod
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
"""Get the current stream id from the _device_list_id_gen"""
...

@trace
async def get_user_devices_from_cache(
Expand Down Expand Up @@ -481,51 +484,6 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
device["device_id"]: db_to_json(device["content"]) for device in devices
}

def get_devices_with_keys_by_user(self, user_id: str):
"""Get all devices (with any device keys) for a user

Returns:
Deferred which resolves to (stream_id, devices)
"""
return self.db_pool.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
)

def _get_devices_with_keys_by_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Tuple[int, List[JsonDict]]:
now_stream_id = self._device_list_id_gen.get_current_token()

devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])

if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}

key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)

if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)

device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name

results.append(result)

return now_stream_id, results

return now_stream_id, []

async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
Expand Down
53 changes: 52 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple

from canonicaljson import encode_canonical_json
Expand All @@ -22,7 +23,7 @@

from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
Expand All @@ -33,6 +34,51 @@


class EndToEndKeyWorkerStore(SQLBaseStore):
def get_e2e_device_keys_for_federation_query(self, user_id: str):
"""Get all devices (with any device keys) for a user

Returns:
Deferred which resolves to (stream_id, devices)
"""
return self.db_pool.runInteraction(
"get_e2e_device_keys_for_federation_query",
self._get_e2e_device_keys_for_federation_query_txn,
user_id,
)

def _get_e2e_device_keys_for_federation_query_txn(
self, txn: LoggingTransaction, user_id: str
) -> Tuple[int, List[JsonDict]]:
now_stream_id = self.get_device_stream_token()

devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])

if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}

key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)

if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)

device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name

results.append(result)

return now_stream_id, results

return now_stream_id, []

@trace
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
Expand Down Expand Up @@ -533,6 +579,11 @@ def _get_all_user_signature_changes_for_remotes_txn(txn):
_get_all_user_signature_changes_for_remotes_txn,
)

@abc.abstractmethod
def get_device_stream_token(self) -> int:
"""Get the current stream id from the _device_list_id_gen"""
...


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
Expand Down