Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates. #11730

Merged
merged 17 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11730.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
94 changes: 78 additions & 16 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def get_devices_by_auth_provider_session_id(
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
) -> Tuple[int, List[Tuple[str, dict]]]:
) -> Tuple[int, List[Tuple[str, JsonDict]]]:
"""Get a stream of device updates to send to the given remote server.

Args:
Expand All @@ -200,9 +200,10 @@ async def get_device_updates_by_remote(
limit: Maximum number of device updates to return

Returns:
A mapping from the current stream id (ie, the stream id of the last
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
- The current stream id (i.e. the stream id of the last update included
in the response); and
- The list of updates, where each update is a pair of EDU type and
EDU contents.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""
now_stream_id = self.get_device_stream_token()

Expand All @@ -221,6 +222,9 @@ async def get_device_updates_by_remote(
limit,
)

# We need to ensure `updates` doesn't grow too big.
# Currently: `len(updates) <= limit`.

# Return an empty list if there are no updates
if not updates:
return now_stream_id, []
Expand Down Expand Up @@ -277,40 +281,88 @@ async def get_device_updates_by_remote(
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
if (
# Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user`
# gives rise to two device updates in the result, so those cost twice
# as much (and are the whole reason we need to separately calculate
# the budget; we know len(updates) <= limit otherwise!)
# N.B. len() on dicts is cheap since they store their size.
remaining_length_budget = limit - (
len(query_map) + 2 * len(cross_signing_keys_by_user)
)
assert remaining_length_budget >= 0

is_master_key_update = (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif (
)
is_self_signing_key_update = (
user_id in self_signing_key_by_user
and device_id == self_signing_key_by_user[user_id]["device_id"]
)

is_cross_signing_key_update = (
is_master_key_update or is_self_signing_key_update
)

if (
is_cross_signing_key_update
and user_id not in cross_signing_keys_by_user
):
# This will give rise to 2 device updates.
# If we don't have the budget, stop here!
if remaining_length_budget < 2:
break
Comment on lines +314 to +315
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we drop these edus here (or the other one on +329), what ensures they'll be picked up the next time we come to work out which device updates need sending over federation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only advance the stream position as far as the loop goes (see a little bit down). I added a comment to it to make it more noticeable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, I think I see. We tell the caller to update the stream id to last_processed_stream_id via return value. This only gets bumped at the end of a the loop body over updates.


if is_master_key_update:
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif is_self_signing_key_update:
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info"
]
else:
key = (user_id, device_id)

if key not in query_map and remaining_length_budget < 1:
# We don't have space for a new entry
break

previous_update_stream_id, _ = query_map.get(key, (0, None))

if update_stream_id > previous_update_stream_id:
# FIXME If this overwrites an older update, this discards the
# previous OpenTracing context.
# It might make it harder to track down issues using OpenTracing.
# If there's a good reason why it doesn't matter, a comment here
# about that would not hurt.
query_map[key] = (update_stream_id, update_context)

# As this update has been added to the response, advance the stream
# position.
last_processed_stream_id = update_stream_id

# In the worst case scenario, each update is for a distinct user and is
# added either to the query_map or to cross_signing_keys_by_user,
# but not both:
# len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here,
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# so len(query_map) + len(cross_signing_keys_by_user) <= limit.

results = await self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)

# add the updated cross-signing keys to the results list
# len(results) <= len(query_map) here,
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# so len(results) + len(cross_signing_keys_by_user) <= limit.
reivilibre marked this conversation as resolved.
Show resolved Hide resolved

# Add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
results.append(("m.signing_key_update", result))
# also send the unstable version
# FIXME: remove this when enough servers have upgraded
# and remove the length budgeting above.
results.append(("org.matrix.signing_key_update", result))

return last_processed_stream_id, results
Expand All @@ -322,7 +374,7 @@ def _get_device_updates_by_remote_txn(
from_stream_id: int,
now_stream_id: int,
limit: int,
):
) -> List[Tuple[str, str, int, Optional[str]]]:
"""Return device update information for a given remote destination

Args:
Expand All @@ -333,7 +385,11 @@ def _get_device_updates_by_remote_txn(
limit: Maximum number of device updates to return

Returns:
List: List of device updates
List: List of device update tuples:
- user_id
- device_id
- stream_id
- opentracing_context
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""
# get the list of device updates that need to be sent
sql = """
Expand All @@ -357,15 +413,21 @@ async def _get_device_update_edus_by_remote(
Args:
destination: The host the device updates are intended for
from_stream_id: The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
user_id/device_id to update stream_id and the relevant json-encoded
opentracing context
query_map: Dictionary mapping (user_id, device_id) to
(update stream_id, the relevant json-encoded opentracing context)

Returns:
List of objects representing an device update EDU
List of objects representing a device update EDU.

Postconditions:
The returned list has a length not exceeding that of the query_map:
len(result) <= len(query_map)
"""
devices = (
await self.get_e2e_device_keys_and_signatures(
# Because these are (user_id, device_id) tuples with all
# device_ids not being None, the returned list's length will not
# exceed that of query_map.
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
Expand Down
112 changes: 111 additions & 1 deletion tests/storage/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_get_device_updates_by_remote_can_limit_properly(self):
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)

# Get all device updates ever meant for this remote
# Get device updates meant for this remote
next_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=3)
)
Expand Down Expand Up @@ -155,6 +155,116 @@ def test_get_device_updates_by_remote_can_limit_properly(self):
# Check the newly-added device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)

# Check there are no more device updates left.
_, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
)
self.assertEqual(device_updates, [])

def test_get_device_updates_by_remote_cross_signing_key_updates(
self,
) -> None:
"""
Tests that `get_device_updates_by_remote` limits the length of the return value
properly when cross-signing key updates are present.
Current behaviour is that the cross-signing key updates will always come in pairs,
even if that means leaving an earlier batch one EDU short of the limit.
"""

assert self.hs.is_mine_id(
"@user_id:test"
), "Test not valid: this MXID should be considered local"

self.get_success(
self.store.set_e2e_cross_signing_key(
"@user_id:test",
"master",
{
"keys": {
"ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
},
"signatures": {
"@user_id:test": {
"ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
}
},
},
)
)
self.get_success(
self.store.set_e2e_cross_signing_key(
"@user_id:test",
"self_signing",
{
"keys": {
"ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
},
"signatures": {
"@user_id:test": {
"ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
}
},
},
)
)

# Add some device updates with sequential `stream_id`s
# Note that the public cross-signing keys occupy the same space as device IDs,
# so also notify that those have updated.
device_ids = [
"device_id1",
"device_id2",
"fakeMaster",
"fakeSelfSigning",
]

self.get_success(
self.store.add_device_change_to_streams(
"@user_id:test", device_ids, ["somehost"]
)
)

# Get device updates meant for this remote
next_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=3)
)

# Here we expect the device updates for `device_id1` and `device_id2`.
# That means we only receive 2 updates this time around.
# If we had a higher limit, we would expect to see the pair of
# (unstable-prefixed & unprefixed) signing key updates for the device
# represented by `fakeMaster` and `fakeSelfSigning`.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# Our implementation only sends these two variants together, so we get
# a short batch.
self.assertEqual(len(device_updates), 2, device_updates)

# Check the first two devices (device_id1, device_id2) came out.
self._check_devices_in_updates(device_ids[:2], device_updates)

# Get more device updates meant for this remote
next_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
)

# The next 2 updates should be a cross-signing key update
# (the master key update and the self-signing key update are combined into
# one 'signing key update', but the cross-signing key update is emitted
# twice, once with an unprefixed type and once again with an unstable-prefixed type)
# (This is a temporary arrangement for backwards compatibility!)
self.assertEqual(len(device_updates), 2, device_updates)
self.assertEqual(
device_updates[0][0], "m.signing_key_update", device_updates[0]
)
self.assertEqual(
device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
)
reivilibre marked this conversation as resolved.
Show resolved Hide resolved

# Check there are no more device updates left.
_, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
)
self.assertEqual(device_updates, [])

def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
Expand Down