diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 945be1b2f62c..eac7d8590595 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -291,6 +291,7 @@ def _get_auth_chain_ids_using_cover_index_txn(
             txn.execute(sql % (clause,), args)
 
             for event_id, chain_id, sequence_number in txn:
+                logger.debug(f"JASON: s1d, {event_id}: {chain_id}, {sequence_number}")
                 section_1_rows.add((event_id, chain_id, sequence_number))
 
         with Measure(
@@ -331,6 +332,59 @@ def _get_auth_chain_ids_using_cover_index_txn(
         section_2_rows = set()
 
+        with Measure(
+            self.hs.get_clock(),
+            "_get_auth_chain_ids_using_cover_index_txn.section_2_cache_retrieval",
+        ):
+            # Take a copy of the event_chains dict, as it will be mutated to remove
+            # entries that don't have to be pulled from the database later.
+            logger.debug(f"JASON: s2c, event_chains before: {event_chains}")
+            for chain_id, seq_no in dict(event_chains).items():
+                logger.debug(f"JASON: looking for cache entry: {chain_id}, {seq_no}")
+                s2_cache_entry = self._authchain_links_list.get(chain_id)
+                # The seq_no above references a specific set of chains to start
+                # processing at. The cache will contain (if an entry is there at all)
+                # all chains below that value. If newer information has been added
+                # since the last time the cache was loaded, then the seq_no will have
+                # no matching origin_sequence_number below (a literal 'equals'). Watch
+                # for that: if it doesn't exist, then it's time to reload.
+                s2_checkpoint = True
+                if s2_cache_entry is not None:
+                    logger.debug(f"JASON: s2c, cache entry found: {chain_id}, {s2_cache_entry}")
+                    for origin_seq_number, target_set_info in s2_cache_entry.items():
+                        # If we pass the checkpoint, this will be removed from the cache
+
+                        # NOTE: Always plan on a reload unless we hit a match where
+                        # seq_no - 1 == origin seq number, indicating we found the max,
+                        # but only if seq_no is greater than 1.
+                        logger.debug(f"JASON FOCUS: seq_no {seq_no} origin_seq_number {origin_seq_number}")
+                        if (seq_no > 1 and (seq_no - 1) == origin_seq_number) or (seq_no == 1 and seq_no == origin_seq_number):
+                            logger.debug("JASON: not reloading cache entry")
+                            s2_checkpoint = False
+                        # This condition ensures that a sequence number GREATER than
+                        # what is needed is not used.
+                        if origin_seq_number <= seq_no:
+                            # Chains are only reachable if the origin sequence number of
+                            # the link is less than the max sequence number in the
+                            # origin chain.
+                            for target_chain_id, target_seq_no in target_set_info:
+                                # We use a (0, 0) tuple as a placeholder in the cache
+                                # to represent that this particular target set doesn't
+                                # exist in the database and therefore will never be
+                                # in the cache. Typically, this is an origin event and
+                                # will have nothing prior to it, hence no chain.
+                                if (target_chain_id, target_seq_no) != (0, 0):
+                                    # This is slightly more optimized than using max()
+                                    target_seq_max_result = chains.get(
+                                        target_chain_id, 0
+                                    )
+                                    if target_seq_no > target_seq_max_result:
+                                        chains[target_chain_id] = target_seq_no
+
+                if not s2_checkpoint:
+                    del event_chains[chain_id]
+
+            logger.debug(f"JASON: s2c, event chains after: {event_chains}")
         with Measure(
             self.hs.get_clock(),
             "_get_auth_chain_ids_using_cover_index_txn.section_2_database",
         ):
@@ -390,6 +444,31 @@ def _get_auth_chain_ids_using_cover_index_txn(
                 if max_sequence_result > 0:
                     chains[chain_id] = max_sequence_result
 
+        with Measure(
+            self.hs.get_clock(),
+            "_get_auth_chain_ids_using_cover_index_txn.section_2_postprocessing_cache",
+        ):
+            # For this block, first build the cache entries in an efficient way, then
+            # set them into the cache itself.
+            cache_entries: Dict[int, Dict[int, Set[Tuple[int, int]]]] = {}
+            seen_during_batching = set()
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in section_2_rows:
+                logger.debug(f"JASON: s2dpc, processing row {origin_chain_id}, {origin_sequence_number}: {target_chain_id}, {target_sequence_number}")
+                seen_during_batching.add(origin_chain_id)
+                cache_entries.setdefault(origin_chain_id, {}).setdefault(
+                    origin_sequence_number, set()
+                ).add((target_chain_id, target_sequence_number))
+
+            for origin_chain_id, cache_entry in cache_entries.items():
+                logger.debug(f"JASON: s2dpc, adding to cache {origin_chain_id}: {cache_entry}")
+                self._authchain_links_list.set(origin_chain_id, cache_entry)
+            logger.debug(f"JASON: seen during batching {seen_during_batching}")
+
         # Now for each chain we figure out the maximum sequence number reachable
         # from *any* event ID. Events with a sequence less than that are in the
         # auth chain.
@@ -453,6 +532,7 @@ def _get_auth_chain_ids_using_cover_index_txn(
             "_get_auth_chain_ids_using_cover_index_txn.section_3_postprocessing_cache",
         ):
             for event_id, chain_id, sequence_number in section_3_rows:
+                logger.debug(f"JASON: s3d, {event_id}: {chain_id}, {sequence_number}")
                 s3_cache_entry = self._authchain_chain_info_to_event_id.get(
                     chain_id, update_last_access=False
                 )
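
The new `section_2_cache_retrieval` block reads from `_authchain_links_list`, which (per the `cache_entries` annotation in the `section_2_postprocessing_cache` block) appears to map an origin chain ID to `{origin_sequence_number: {(target_chain_id, target_sequence_number), ...}}`, with `(0, 0)` as the "no links stored" placeholder. Below is a minimal, self-contained sketch of that lookup logic, assuming a plain dict in place of Synapse's cache class; the names `links_cache` and `resolve_from_cache` are illustrative only and are not part of the patch.

```python
# Sketch only: the real code lives inside _get_auth_chain_ids_using_cover_index_txn
# and uses Synapse's cache classes; this strips it down to the core algorithm.
from typing import Dict, Set, Tuple

# origin_sequence_number -> {(target_chain_id, target_sequence_number), ...}
LinksEntry = Dict[int, Set[Tuple[int, int]]]


def resolve_from_cache(
    links_cache: Dict[int, LinksEntry],  # hypothetical stand-in for _authchain_links_list
    event_chains: Dict[int, int],        # chain_id -> starting sequence number
    chains: Dict[int, int],              # chain_id -> max reachable sequence number
) -> None:
    """Answer as many chain lookups as possible from the cache.

    Chains fully answered by the cache are removed from `event_chains`;
    whatever remains still has to be loaded from the database.
    """
    for chain_id, seq_no in dict(event_chains).items():
        entry = links_cache.get(chain_id)
        if entry is None:
            # Nothing cached for this chain; it must be read from the database.
            continue

        needs_reload = True
        for origin_seq_number, targets in entry.items():
            # The cached entry only counts as complete if it contains the link
            # directly below the requested sequence number (or the link at 1
            # when seq_no itself is 1).
            if (seq_no > 1 and origin_seq_number == seq_no - 1) or (
                seq_no == 1 and origin_seq_number == seq_no
            ):
                needs_reload = False

            # Only links at or below the requested sequence number are
            # reachable from the starting event.
            if origin_seq_number <= seq_no:
                for target_chain_id, target_seq_no in targets:
                    if (target_chain_id, target_seq_no) == (0, 0):
                        continue  # placeholder: this link has no targets stored
                    if target_seq_no > chains.get(target_chain_id, 0):
                        chains[target_chain_id] = target_seq_no

        if not needs_reload:
            del event_chains[chain_id]


# Example: chain 7 is answered entirely from the cache, chain 9 is not cached.
links_cache = {7: {1: {(3, 2)}, 2: {(4, 5), (0, 0)}}}
event_chains = {7: 3, 9: 1}
chains: Dict[int, int] = {}
resolve_from_cache(links_cache, event_chains, chains)
assert chains == {3: 2, 4: 5}
assert event_chains == {9: 1}
```

As the comments in the patch describe, the `seq_no - 1` checkpoint is what forces a database reload when the cached entry is stale, i.e. when new links were written for the chain after the entry was last cached.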