Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preparatory work for tweaking performance of auth chain lookups #16833

Merged
merged 2 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16833.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparatory work for tweaking performance of auth chain lookups.
153 changes: 127 additions & 26 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ def __init__(
unique_columns=("event_id", "room_id"),
)

self.db_pool.updates.register_background_index_update(
update_name="event_auth_chain_links_origin_index",
index_name="event_auth_chain_links_origin_index",
table="event_auth_chain_links",
columns=("origin_chain_id", "origin_sequence_number"),
)

async def get_auth_chain(
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
Expand Down Expand Up @@ -271,38 +278,63 @@ def _get_auth_chain_ids_using_cover_index_txn(

# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the %s kind of buries the readability but I don't have a concrete suggestion to offer tbh

UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}

# Add all linked chains reachable from initial set of chains.
for batch2 in batch_iter(event_chains, 1000):
chains_to_fetch = set(event_chains.keys())
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 100))
chains_to_fetch.difference_update(batch2)
Comment on lines +312 to +313
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ouch, is this the best way to pop 100 items out of a set?
If it is, fine, but let's just say I am missing Rust here :-)

clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number,
chains.get(target_chain_id, 0),
)
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

for chain_id in links:
if chain_id not in event_chains:
continue

_materialize(chain_id, event_chains[chain_id], links, chains)

chains_to_fetch.difference_update(chains)

# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
Expand Down Expand Up @@ -529,41 +561,64 @@ def fetch_chain_info(events_to_fetch: Collection[str]) -> None:

chains[chain_id] = max(seq_no, chains.get(chain_id, 0))

# Now we look up all links for the chains we have, adding chains to
# set_to_chain that are reachable from each set.
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
for batch2 in batch_iter(set(seen_chains), 1000):
chains_to_fetch = set(seen_chains)
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 100))
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
for chains in set_to_chain:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number,
chains.get(target_chain_id, 0),
)
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
continue

seen_chains.add(target_chain_id)
_materialize(chain_id, chains[chain_id], links, chains)

chains_to_fetch.difference_update(chains)
seen_chains.update(chains)

# Now for each chain we figure out the maximum sequence number reachable
# from *any* state set and the minimum sequence number reachable from
Expand Down Expand Up @@ -2103,3 +2158,49 @@ def delete_event_auth(txn: LoggingTransaction) -> bool:
)

return batch_size


def _materialize(
origin_chain_id: int,
origin_sequence_number: int,
links: Dict[int, List[Tuple[int, int, int]]],
materialized: Dict[int, int],
) -> None:
"""Helper function for fetching auth chain links. For a given origin chain
ID / sequence number and a dictionary of links, updates the materialized
dict with the reachable chains.

To get a dict of all chains reachable from a set of chains this function can
be called in a loop, once per origin chain with the same links and
materialized args. The materialized dict will the result.

Args:
origin_chain_id, origin_sequence_number
links: map of the links between chains as a dict from origin chain ID
to list of 3-tuples of origin sequence number, target chain ID and
target sequence number.
materialized: dict to update with new reachability information, as a
map from chain ID to max sequence number reachable.
"""

# Do a standard graph traversal.
stack = [(origin_chain_id, origin_sequence_number)]

while stack:
c, s = stack.pop()

chain_links = links.get(c, [])
for (
sequence_number,
target_chain_id,
target_sequence_number,
) in chain_links:
# Ignore any links that are higher up the chain
if sequence_number > s:
continue

# Check if we have already visited the target chain before, if so we
# can skip it.
if materialized.get(target_chain_id, 0) < target_sequence_number:
stack.append((target_chain_id, target_sequence_number))
materialized[target_chain_id] = target_sequence_number
2 changes: 1 addition & 1 deletion synapse/storage/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#
#

SCHEMA_VERSION = 83 # remember to update the list below when updating
SCHEMA_VERSION = 84 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema

This should be incremented whenever the codebase changes its requirements on the
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2023 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.

-- Force the statistics for these tables to show that the number of distinct
-- chain IDs are proportional to the total rows, as postgres has trouble
-- figuring that out by itself.
ALTER TABLE event_auth_chain_links ALTER origin_chain_id SET (n_distinct = -0.5);
ALTER TABLE event_auth_chain_links ALTER target_chain_id SET (n_distinct = -0.5);
16 changes: 16 additions & 0 deletions synapse/storage/schema/main/delta/84/02_auth_links_index.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2023 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.


INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(8402, 'event_auth_chain_links_origin_index', '{}');
Loading