Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Back out in-flight state caching changes. #12126

Merged
merged 5 commits into from
Mar 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion changelog.d/10870.misc

This file was deleted.

1 change: 0 additions & 1 deletion changelog.d/11608.misc

This file was deleted.

1 change: 0 additions & 1 deletion changelog.d/11610.misc

This file was deleted.

1 change: 0 additions & 1 deletion changelog.d/12033.misc

This file was deleted.

1 change: 1 addition & 0 deletions changelog.d/12126.removal
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Back out in-flight state caching changes.
243 changes: 25 additions & 218 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,11 @@
# limitations under the License.

import logging
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
Optional,
Sequence,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple

import attr
from sortedcontainers import SortedDict

from twisted.internet import defer

from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
Expand All @@ -42,12 +29,6 @@
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import (
AbstractObservableDeferred,
ObservableDeferred,
yieldable_gather_results,
)
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache

Expand All @@ -56,8 +37,8 @@

logger = logging.getLogger(__name__)


MAX_STATE_DELTA_HOPS = 100
MAX_INFLIGHT_REQUESTS_PER_GROUP = 5


@attr.s(slots=True, frozen=True, auto_attribs=True)
Expand All @@ -73,24 +54,6 @@ def __len__(self) -> int:
return len(self.delta_ids) if self.delta_ids else 0


def state_filter_rough_priority_comparator(
state_filter: StateFilter,
) -> Tuple[int, int]:
"""
Returns a comparable value that roughly indicates the relative size of this
state filter compared to others.
'Larger' state filters should sort first when using ascending order, so
this is essentially the opposite of 'size'.
It should be treated as a rough guide only and should not be interpreted to
have any particular meaning. The representation may also change

The current implementation returns a tuple of the form:
* -1 for include_others, 0 otherwise
* -(number of entries in state_filter.types)
"""
return -int(state_filter.include_others), -len(state_filter.types)


class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups."""

Expand Down Expand Up @@ -143,12 +106,6 @@ def __init__(
500000,
)

# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
self._state_group_inflight_requests: Dict[
int, SortedDict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
] = {}

def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore
Expand Down Expand Up @@ -200,7 +157,7 @@ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
)

async def _get_state_groups_from_groups(
self, groups: Sequence[int], state_filter: StateFilter
self, groups: List[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
Expand Down Expand Up @@ -271,170 +228,6 @@ def _get_state_for_group_using_cache(

return state_filter.filter_state(state_dict_ids), not missing_types

def _get_state_for_group_gather_inflight_requests(
self, group: int, state_filter_left_over: StateFilter
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
"""
Attempts to gather in-flight requests and re-use them to retrieve state
for the given state group, filtered with the given state filter.

If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests,
and there *still* isn't enough information to complete the request by solely
reusing others, a full state filter will be requested to ensure that subsequent
requests can reuse this request.

Used as part of _get_state_for_group_using_inflight_cache.

Returns:
Tuple of two values:
A sequence of ObservableDeferreds to observe
A StateFilter representing what else needs to be requested to fulfill the request
"""

inflight_requests = self._state_group_inflight_requests.get(group)
if inflight_requests is None:
# no requests for this group, need to retrieve it all ourselves
return (), state_filter_left_over

# The list of ongoing requests which will help narrow the current request.
reusable_requests = []

# Iterate over existing requests in roughly biggest-first order.
for request_state_filter in inflight_requests:
request_deferred = inflight_requests[request_state_filter]
new_state_filter_left_over = state_filter_left_over.approx_difference(
request_state_filter
)
if new_state_filter_left_over == state_filter_left_over:
# Reusing this request would not gain us anything, so don't bother.
continue

reusable_requests.append(request_deferred)
state_filter_left_over = new_state_filter_left_over
if state_filter_left_over == StateFilter.none():
# we have managed to collect enough of the in-flight requests
# to cover our StateFilter and give us the state we need.
break

if (
state_filter_left_over != StateFilter.none()
and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP
):
# There are too many requests for this group.
# To prevent even more from building up, we request the whole
# state filter to guarantee that we can be reused by any subsequent
# requests for this state group.
return (), StateFilter.all()

return reusable_requests, state_filter_left_over

async def _get_state_for_group_fire_request(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""
Fires off a request to get the state at a state group,
potentially filtering by type and/or state key.

This request will be tracked in the in-flight request cache and automatically
removed when it is finished.

Used as part of _get_state_for_group_using_inflight_cache.

Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
"""
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence

# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()

async def _the_request() -> StateMap[str]:
group_to_state_dict = await self._get_state_groups_from_groups(
(group,), state_filter=db_state_filter
)

# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)

# Remove ourselves from the in-flight cache
group_request_dict = self._state_group_inflight_requests[group]
del group_request_dict[db_state_filter]
if not group_request_dict:
# If there are no more requests in-flight for this group,
# clean up the cache by removing the empty dictionary
del self._state_group_inflight_requests[group]

return group_to_state_dict[group]

# We don't immediately await the result, so must use run_in_background
# But we DO await the result before the current log context (request)
# finishes, so don't need to run it as a background process.
request_deferred = run_in_background(_the_request)
observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)

# Insert the ObservableDeferred into the cache
group_request_dict = self._state_group_inflight_requests.setdefault(
group, SortedDict(state_filter_rough_priority_comparator)
)
group_request_dict[db_state_filter] = observable_deferred

return await make_deferred_yieldable(observable_deferred.observe())

async def _get_state_for_group_using_inflight_cache(
self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
"""
Gets the state at a state group, potentially filtering by type and/or
state key.

1. Calls _get_state_for_group_gather_inflight_requests to gather any
ongoing requests which might overlap with the current request.
2. Fires a new request, using _get_state_for_group_fire_request,
for any state which cannot be gathered from ongoing requests.

Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
Returns:
state map
"""

# first, figure out whether we can re-use any in-flight requests
# (and if so, what would be left over)
(
reusable_requests,
state_filter_left_over,
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)

if state_filter_left_over != StateFilter.none():
# Fetch remaining state
remaining = await self._get_state_for_group_fire_request(
group, state_filter_left_over
)
assembled_state: MutableStateMap[str] = dict(remaining)
else:
assembled_state = {}

gathered = await make_deferred_yieldable(
defer.gatherResults(
(r.observe() for r in reusable_requests), consumeErrors=True
)
).addErrback(unwrapFirstError)

# assemble our result.
for result_piece in gathered:
assembled_state.update(result_piece)

# Filter out any state that may be more than what we asked for.
return state_filter.filter_state(assembled_state)

async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
Expand Down Expand Up @@ -476,17 +269,31 @@ async def _get_state_for_groups(
if not incomplete_groups:
return state

async def get_from_cache(group: int, state_filter: StateFilter) -> None:
state[group] = await self._get_state_for_group_using_inflight_cache(
group, state_filter
)
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence

await yieldable_gather_results(
get_from_cache,
incomplete_groups,
state_filter,
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()

group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
)

# Now lets update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)

# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)

return state

def _get_state_for_groups_using_cache(
Expand Down
Loading