
Commit

wip
realtyem committed Nov 5, 2023
1 parent 55dbbd3 commit a489ea3
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions synapse/storage/databases/main/event_federation.py
@@ -291,6 +291,7 @@ def _get_auth_chain_ids_using_cover_index_txn(
txn.execute(sql % (clause,), args)

for event_id, chain_id, sequence_number in txn:
logger.debug(f"JASON: s1d, {event_id}: {chain_id}, {sequence_number}")
section_1_rows.add((event_id, chain_id, sequence_number))

with Measure(
@@ -331,6 +332,59 @@ def _get_auth_chain_ids_using_cover_index_txn(

section_2_rows = set()

with Measure(
self.hs.get_clock(),
"_get_auth_chain_ids_using_cover_index_txn.section_2_cache_retrieval",
):
# Take a copy of the event_chains dict, as it will be mutated to remove
# entries that don't have to be pulled from the database later.
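# (Iterating over a plain dict while deleting entries from it would
# raise a RuntimeError in Python 3, hence the dict() copy.)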
logger.debug(f"JASON: s2c, event_chains before: {event_chains}")
for chain_id, seq_no in dict(event_chains).items():
logger.debug(f"JASON: looking for cache entry: {chain_id}, {seq_no}")
s2_cache_entry = self._authchain_links_list.get(chain_id)
# The seq_no above marks where processing of this chain should start.
# The cache entry (if one is there at all) will contain all links below
# that value. If newer information has been added to the database since
# the last time the cache was loaded, then seq_no will have no matching
# origin_sequence_number below (a literal 'equals'). Watch for that: if
# no match exists, it is time to reload.
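# Illustrative example (made-up numbers): if this chain's cache entry
# holds origin_sequence_numbers {4, 6} and seq_no is 7, then 6 equals
# seq_no - 1 and the entry is complete; if seq_no were 9, nothing
# equals 8 and the cache entry must be reloaded.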
s2_checkpoint = True
if s2_cache_entry is not None:
logger.debug(f"JASON: s2c, cache entry found: {chain_id}, {s2_cache_entry}")
for origin_seq_number, target_set_info in s2_cache_entry.items():
# If we hit the checkpoint below, this chain is removed from
# event_chains and will not be pulled from the database.

# NOTE: Always plan on a reload unless we hit a match where
# seq_no - 1 == origin_seq_number, indicating we found the max (this
# only applies when seq_no is greater than 1; when seq_no is 1 the
# match is seq_no == origin_seq_number).
logger.debug(f"JASON FOCUS: seq_no {seq_no} origin_seq_number {origin_seq_number}")
if (seq_no > 1 and (seq_no - 1) == origin_seq_number) or (seq_no == 1 and seq_no == origin_seq_number):
logger.debug("JASON: not reloading cache entry")
s2_checkpoint = False
# This condition ensures that a sequence number GREATER than what
# is needed is never used.
if origin_seq_number <= seq_no:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
for target_chain_id, target_seq_no in target_set_info:
# We use a (0, 0) tuple as a placeholder in the cache
# to represent that this particular target set doesn't
# exist in the database and therefore will never be
# in the cache. Typically, this is an origin event and
# will have nothing prior to it, hence no chain.
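# Illustrative: a cached entry of {1: {(0, 0)}} records that links for
# origin_sequence_number 1 were already looked up and none exist, so
# the database does not need to be consulted for them again.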
if (target_chain_id, target_seq_no) != (0, 0):
# This is slightly more optimized than using max()
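# (equivalent to: chains[target_chain_id] =
#     max(chains.get(target_chain_id, 0), target_seq_no))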
target_seq_max_result = chains.get(
target_chain_id, 0
)
if target_seq_no > target_seq_max_result:
chains[target_chain_id] = target_seq_no

if not s2_checkpoint:
del event_chains[chain_id]

logger.debug(f"JASON: s2c, event chains after: {event_chains}")
with Measure(
self.hs.get_clock(),
"_get_auth_chain_ids_using_cover_index_txn.section_2_database",
@@ -390,6 +444,31 @@ def _get_auth_chain_ids_using_cover_index_txn(
if max_sequence_result > 0:
chains[chain_id] = max_sequence_result

with Measure(
self.hs.get_clock(),
"_get_auth_chain_ids_using_cover_index_txn.section_2_postprocessing_cache",
):
# For this block, first build the cache entries in an efficient way, then
# set them into the cache itself.
cache_entries: Dict[int, Dict[int, Set[Tuple[int, int]]]] = {}
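# Shape: {origin_chain_id: {origin_sequence_number:
#     {(target_chain_id, target_sequence_number), ...}}}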
seen_during_batching = set()
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in section_2_rows:
logger.debug(f"JASON: s2dpc, processing row {origin_chain_id}, {origin_sequence_number}: {target_chain_id}, {target_sequence_number}")
seen_during_batching.add(origin_chain_id)
cache_entries.setdefault(origin_chain_id, {}).setdefault(
origin_sequence_number, set()
).add((target_chain_id, target_sequence_number))

for origin_chain_id, cache_entry in cache_entries.items():
logger.debug(f"JASON: s2dpc, adding to cache {origin_chain_id}: {cache_entry}")
self._authchain_links_list.set(origin_chain_id, cache_entry)
logger.debug(f"JASON: seen during batching {seen_during_batching}")

# Now for each chain we figure out the maximum sequence number reachable
# from *any* event ID. Events with a sequence less than that are in the
# auth chain.
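# Illustrative: if chains ends up as {12: 4}, then the events of chain
# 12 with a sequence number below that maximum are in the auth chain.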
@@ -453,6 +532,7 @@ def _get_auth_chain_ids_using_cover_index_txn(
"_get_auth_chain_ids_using_cover_index_txn.section_3_postprocessing_cache",
):
for event_id, chain_id, sequence_number in section_3_rows:
logger.debug(f"JASON: s3d, {event_id}: {chain_id}, {sequence_number}")
s3_cache_entry = self._authchain_chain_info_to_event_id.get(
chain_id, update_last_access=False
)
