Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Allow clients to upload one-time-keys with new sigs #2206

Merged
merged 3 commits into from
May 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,16 @@ def on_claim_client_keys(self, origin, content):
key_id: json.loads(json_bytes)
}

logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in json_result.iteritems()
for device_id, device_keys in user_keys.iteritems()
for key_id, _ in device_keys.iteritems()
)),
)

defer.returnValue({"one_time_keys": json_result})

@defer.inlineCallbacks
Expand Down
86 changes: 70 additions & 16 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -145,7 +145,7 @@ def do_remote_query(destination):
"status": 503, "message": e.message
}

yield preserve_context_over_deferred(defer.gatherResults([
yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries_not_in_cache
]))
Expand Down Expand Up @@ -257,11 +257,21 @@ def claim_client_keys(destination):
"status": 503, "message": e.message
}

yield preserve_context_over_deferred(defer.gatherResults([
yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))

logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in json_result.iteritems()
for device_id, device_keys in user_keys.iteritems()
for key_id, _ in device_keys.iteritems()
)),
)

defer.returnValue({
"one_time_keys": json_result,
"failures": failures
Expand All @@ -288,19 +298,8 @@ def upload_keys_for_user(self, user_id, device_id, keys):

one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))

yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
yield self._upload_one_time_keys_for_user(
user_id, device_id, time_now, one_time_keys,
)

# the device should have been registered already, but it may have been
Expand All @@ -313,3 +312,58 @@ def upload_keys_for_user(self, user_id, device_id, keys):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)

defer.returnValue({"one_time_key_counts": result})

@defer.inlineCallbacks
def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
one_time_keys):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(), device_id, user_id, time_now,
)

# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, key_obj
))

# First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)

new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
for algorithm, key_id, key in key_list:
ex_json = existing_key_map.get((algorithm, key_id), None)
if ex_json:
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
("One time key %s:%s already exists. "
"Old key: %s; new key: %r") %
(algorithm, key_id, ex_json, key)
)
else:
new_keys.append((algorithm, key_id, encode_canonical_json(key)))

yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
)


def _one_time_keys_match(old_key_json, new_key):
old_key = json.loads(old_key_json)

# if either is a string rather than an object, they must match exactly
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
return old_key == new_key

# otherwise, we strip off the 'signatures' if any, because it's legitimate
# for different upload attempts to have different signatures.
old_key.pop("signatures", None)
new_key_copy = dict(new_key)
new_key_copy.pop("signatures", None)

return old_key == new_key_copy
47 changes: 27 additions & 20 deletions synapse/storage/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.util.caches.descriptors import cached

from canonicaljson import encode_canonical_json
Expand Down Expand Up @@ -124,18 +123,24 @@ def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
return result

@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
"""Insert some new one time keys for a device.
def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
"""Retrieve a number of one-time keys for a user

Checks if any of the keys are already inserted, if they are then check
if they match. If they don't then we raise an error.
Args:
user_id(str): id of user to get keys for
device_id(str): id of device to get keys for
key_ids(list[str]): list of key ids (excluding algorithm) to
retrieve

Returns:
deferred resolving to Dict[(str, str), str]: map from (algorithm,
key_id) to json string for key
"""

# First we check if we have already persisted any of the keys.
rows = yield self._simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=[key_id for _, key_id, _ in key_list],
iterable=key_ids,
retcols=("algorithm", "key_id", "key_json",),
keyvalues={
"user_id": user_id,
Expand All @@ -144,20 +149,22 @@ def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
desc="add_e2e_one_time_keys_check",
)

existing_key_map = {
defer.returnValue({
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
}

new_keys = [] # Keys that we need to insert
for algorithm, key_id, json_bytes in key_list:
ex_bytes = existing_key_map.get((algorithm, key_id), None)
if ex_bytes:
if json_bytes != ex_bytes:
raise SynapseError(
400, "One time key with key_id %r already exists" % (key_id,)
)
else:
new_keys.append((algorithm, key_id, json_bytes))
})

@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
"""Insert some new one time keys for a device. Errors if any of the
keys already exist.

Args:
user_id(str): id of user to get keys for
device_id(str): id of device to get keys for
time_now(long): insertion time to record (ms since epoch)
new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
(algorithm, key_id, key json)
"""

def _add_e2e_one_time_keys(txn):
# We are protected from race between lookup and insertion due to
Expand Down
132 changes: 132 additions & 0 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import mock
from synapse.api import errors
from twisted.internet import defer

import synapse.api.errors
Expand Down Expand Up @@ -44,3 +45,134 @@ def test_query_local_devices_no_devices(self):
local_user = "@boris:" + self.hs.hostname
res = yield self.handler.query_local_devices({local_user: None})
self.assertDictEqual(res, {local_user: {}})

@defer.inlineCallbacks
def test_reupload_one_time_keys(self):
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {
"alg1:k1": "key1",
"alg2:k2": {
"key": "key2",
"signatures": {"k1": "sig1"}
},
"alg2:k3": {
"key": "key3",
},
}

res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys},
)
self.assertDictEqual(res, {
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})

# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys},
)
self.assertDictEqual(res, {
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})

@defer.inlineCallbacks
def test_change_one_time_keys(self):
"""attempts to change one-time-keys should be rejected"""

local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {
"alg1:k1": "key1",
"alg2:k2": {
"key": "key2",
"signatures": {"k1": "sig1"}
},
"alg2:k3": {
"key": "key3",
},
}

res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys},
)
self.assertDictEqual(res, {
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})

try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
)
self.fail("No error when changing string key")
except errors.SynapseError:
pass

try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
)
self.fail("No error when replacing dict key with string")
except errors.SynapseError:
pass

try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {
"one_time_keys": {"alg1:k1": {"key": "key"}}
},
)
self.fail("No error when replacing string key with dict")
except errors.SynapseError:
pass

try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {
"one_time_keys": {
"alg2:k2": {
"key": "key3",
"signatures": {"k1": "sig1"},
}
},
},
)
self.fail("No error when replacing dict key")
except errors.SynapseError:
pass

@unittest.DEBUG
@defer.inlineCallbacks
def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {
"alg1:k1": "key1",
}

res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys},
)
self.assertDictEqual(res, {
"one_time_key_counts": {"alg1": 1}
})

res2 = yield self.handler.claim_one_time_keys({
"one_time_keys": {
local_user: {
device_id: "alg1"
}
}
}, timeout=None)
self.assertEqual(res2, {
"failures": {},
"one_time_keys": {
local_user: {
device_id: {
"alg1:k1": "key1"
}
}
}
})