Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Rename Cache to DeferredCache, and related changes #8548

Merged
merged 5 commits into from
Oct 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8548.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rename `Cache` to `DeferredCache`, to better reflect its purpose.
6 changes: 3 additions & 3 deletions synapse/replication/slave/storage/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.descriptors import Cache
from synapse.util.caches.deferred_cache import DeferredCache

from ._base import BaseSlavedStore

Expand All @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)

self.client_ip_last_seen = Cache(
self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
) # type: DeferredCache[tuple, int]

async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache
from synapse.util.caches.deferred_cache import DeferredCache

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -410,7 +410,7 @@ def _prune_old_user_ips_txn(txn):
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):

self.client_ip_last_seen = Cache(
self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)

Expand Down
5 changes: 3 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import Cache, cached, cachedList
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr

Expand Down Expand Up @@ -1004,7 +1005,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):

# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
self.device_id_exists_cache = DeferredCache(
name="device_id_exists", keylen=2, max_entries=10000
)

Expand Down
5 changes: 3 additions & 2 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure

Expand Down Expand Up @@ -145,7 +146,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._cleanup_old_transaction_ids,
)

self._get_event_cache = Cache(
self._get_event_cache = DeferredCache(
"*getEvent*",
keylen=3,
max_entries=hs.config.caches.event_cache_size,
Expand Down
292 changes: 292 additions & 0 deletions synapse/util/caches/deferred_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import threading
from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast

from prometheus_client import Gauge

from twisted.internet import defer

from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry

cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)


KT = TypeVar("KT")
VT = TypeVar("VT")


class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()


class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.

It expects that each entry added with set() will be a Deferred; likewise get()
may return an ObservableDeferred.
"""

__slots__ = (
"cache",
"name",
"keylen",
"thread",
"metrics",
"_pending_deferred_cache",
)

def __init__(
self,
name: str,
max_entries: int = 1000,
keylen: int = 1,
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
keylen: The length of the tuple used as the cache key. Ignored unless
`tree` is True.
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
"""
cache_type = TreeCache if tree else dict

# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = (
cache_type()
) # type: MutableMapping[KT, CacheEntry]

# cache is used for completed results and maps to the result itself, rather than
# a Deferred.
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)

self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)

@property
def max_entries(self):
return self.cache.max_size

def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)

def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))

def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)

def get(
self,
key: KT,
default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
):
"""Looks the key up in the caches.

Args:
key(tuple)
default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics

Returns:
Either an ObservableDeferred or the result itself
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred

val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
if val is not _Sentinel.sentinel:
self.metrics.inc_hits()
return val

if update_metrics:
self.metrics.inc_misses()

if default is _Sentinel.sentinel:
raise KeyError()
else:
return default

def set(
self,
key: KT,
value: defer.Deferred,
callback: Optional[Callable[[], None]] = None,
) -> ObservableDeferred:
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")

callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)

existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()

self._pending_deferred_cache[key] = entry

def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.

Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True

# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry

return False

def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()

def eb(_fail):
compare_and_pop()
entry.invalidate()

# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable

def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)

def invalidate(self, key):
self.check_thread()
self.cache.pop(key, None)

# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)

# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
entry.invalidate()

def invalidate_many(self, key: KT):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)

# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()

def invalidate_all(self):
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()


class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]

def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False

def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
Loading