Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cross-signing [4/4] -- federation edition #5727

Merged
merged 18 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5727.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add federation support for cross-signing.
4 changes: 2 additions & 2 deletions synapse/federation/sender/per_destination_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,10 @@ def _get_device_update_edus(self, limit):
Edu(
origin=self._server_name,
destination=self._destination,
edu_type="m.device_list_update",
edu_type=edu_type,
content=content,
)
for content in results
for (edu_type, content) in results
]

assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
Expand Down
13 changes: 12 additions & 1 deletion synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,18 @@ def notify_user_signature_update(self, from_user_id, user_ids):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = yield self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
)

return {
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
"master_key": master_key,
"self_signing_key": self_signing_key,
}

@defer.inlineCallbacks
def user_left_room(self, user, room_id):
Expand Down
137 changes: 128 additions & 9 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
get_verify_key_from_cross_signing_key,
)
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination

logger = logging.getLogger(__name__)
Expand All @@ -49,10 +51,19 @@ def __init__(self, hs):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()

self._edu_updater = SigningKeyEduUpdater(hs, self)

federation_registry = hs.get_federation_registry()

# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
federation_registry.register_edu_handler(
"org.matrix.signing_key_update",
self._edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
hs.get_federation_registry().register_query_handler(
federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)

Expand Down Expand Up @@ -207,13 +218,15 @@ def do_remote_query(destination):
if user_id in destination_query:
results[user_id] = keys

for user_id, key in remote_result["master_keys"].items():
if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key
if "master_keys" in remote_result:
for user_id, key in remote_result["master_keys"].items():
if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key

for user_id, key in remote_result["self_signing_keys"].items():
if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key
if "self_signing_keys" in remote_result:
for user_id, key in remote_result["self_signing_keys"].items():
if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key

except Exception as e:
failure = _exception_to_failure(e)
Expand Down Expand Up @@ -251,7 +264,7 @@ def get_cross_signing_keys_from_cache(self, query, from_user_id):

Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from
(master|self_signing|user_signing) -> user_id -> key
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
Expand Down Expand Up @@ -343,7 +356,16 @@ def on_federation_query_client_keys(self, query_body):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
return {"device_keys": res}
ret = {"device_keys": res}

# add in the cross-signing keys
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
device_keys_query, None
)

ret.update(cross_signing_keys)

return ret

@trace
@defer.inlineCallbacks
Expand Down Expand Up @@ -1047,3 +1069,100 @@ class SignatureListItem:
target_user_id = attr.ib()
target_device_id = attr.ib()
signature = attr.ib()


class SigningKeyEduUpdater(object):
"Handles incoming signing key updates from federation and updates the DB"

def __init__(self, hs, e2e_keys_handler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler

self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

# user_id -> list of updates waiting to be handled.
self._pending_updates = {}

# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
cache_name="signing_key_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)

@defer.inlineCallbacks
def incoming_signing_key_update(self, origin, edu_content):
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.

Args:
origin (string): the server that sent the EDU
edu_content (dict): the contents of the EDU
"""

user_id = edu_content.pop("user_id")
master_key = edu_content.pop("master_key", None)
self_signing_key = edu_content.pop("self_signing_key", None)

if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return

room_ids = yield self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return

self._pending_updates.setdefault(user_id, []).append(
(master_key, self_signing_key, edu_content)
)

yield self._handle_signing_key_updates(user_id)

@defer.inlineCallbacks
def _handle_signing_key_updates(self, user_id):
"""Actually handle pending updates.

Args:
user_id (string): the user whose updates we are processing
"""

device_handler = self.e2e_keys_handler.device_handler

with (yield self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
return

device_ids = []

logger.info("pending updates: %r", pending_updates)

for master_key, self_signing_key, edu_content in pending_updates:
if master_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "master", master_key
)
device_id = get_verify_key_from_cross_signing_key(master_key)[
1
].version
device_ids.append(device_id)
if self_signing_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
device_id = get_verify_key_from_cross_signing_key(self_signing_key)[
1
].version
device_ids.append(device_id)

yield device_handler.notify_device_update(user_id, device_ids)
88 changes: 78 additions & 10 deletions synapse/storage/data_stores/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList

Expand Down Expand Up @@ -94,9 +95,10 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers

Returns:
Deferred[tuple[int, list[dict]]]:
Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
response), and the list of updates
response), and the list of updates, where each update is a pair of EDU
type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()

Expand Down Expand Up @@ -129,6 +131,33 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
if not updates:
return now_stream_id, []

# get the cross-signing keys of the users the list
users = set(r[0] for r in updates)
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
)
master_key_by_user[user] = {
"key_info": cross_signing_key,
"pubkey": verify_key.version,
}

cross_signing_key = yield self.get_e2e_cross_signing_key(
user, "self_signing"
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
)
self_signing_key_by_user[user] = {
"key_info": cross_signing_key,
"pubkey": verify_key.version,
}

# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
Expand Down Expand Up @@ -158,6 +187,16 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
# Stop processing updates
break

# skip over cross-signing keys
if (
update[0] in master_key_by_user
and update[1] == master_key_by_user[update[0]]["pubkey"]
) or (
update[0] in master_key_by_user
and update[1] == self_signing_key_by_user[update[0]]["pubkey"]
):
continue

key = (update[0], update[1])

update_context = update[3]
Expand All @@ -172,16 +211,40 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
# means that there are more than limit updates all of which have the same
# steam_id.

# figure out which cross-signing keys were changed by intersecting the
# update list with the master/self-signing key by user maps
cross_signing_keys_by_user = {}
for user_id, device_id, stream, _opentracing_context in updates:
if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif device_id == self_signing_key_by_user.get(user_id, {}).get(
"pubkey", None
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info"
]

cross_signing_results = []

# add the updated cross-signing keys to the results list
for user_id, result in iteritems(cross_signing_keys_by_user):
result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
cross_signing_results.append(("org.matrix.signing_key_update", result))

# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map:
if not query_map and not cross_signing_results:
return stream_id_cutoff, []

results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
results.extend(cross_signing_results)

return now_stream_id, results

Expand All @@ -200,6 +263,7 @@ def _get_devices_by_remote_txn(
Returns:
List: List of device updates
"""
# get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
Expand All @@ -225,12 +289,16 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m
List[Dict]: List of objects representing an device update EDU

"""
devices = yield self.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
devices = (
yield self.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
)
if query_map
else {}
)

results = []
Expand Down Expand Up @@ -262,7 +330,7 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m
else:
result["deleted"] = True

results.append(result)
results.append(("m.device_list_update", result))

return results

Expand Down
Loading