Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Speed up cached function access #2075

Merged
merged 7 commits into from
Mar 31, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions synapse/push/push_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event
)
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.get_invited_rooms_for_user)(user_id),
preserve_fn(store.get_rooms_for_user)(user_id),
], consumeErrors=True))
invites = yield store.get_invited_rooms_for_user(user_id)
joins = yield store.get_rooms_for_user(user_id)

my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",
Expand Down
7 changes: 6 additions & 1 deletion synapse/util/async.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def errback(f):
deferred.addCallbacks(callback, errback)

def observe(self):
"""Observe the underlying deferred.

Can return either a deferred if the underlying deferred is still pending
(or has failed), or the actual value. Callers may need to use maybeDeferred.
"""
if not self._result:
d = defer.Deferred()

Expand All @@ -101,7 +106,7 @@ def remove(r):
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
return res if success else defer.fail(res)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I bet this will break something somewhere :/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be clear: I'm not suggesting doing much about it, other than watching for breakage when it lands.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest I'm concerned about this, but it makes cache hits much cheaper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably worth documenting it in a docstring.


def observers(self):
return self._observers
Expand Down
42 changes: 35 additions & 7 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,20 @@ def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
)

self.num_args = num_args

# list of the names of the args used as the cache key
self.arg_names = all_args[1:num_args + 1]

# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
self.arg_defaults = dict(zip(
all_args[-len(arg_spec.defaults):],
arg_spec.defaults
))
else:
self.arg_defaults = {}

if "cache_context" in self.arg_names:
raise Exception(
"cache_context arg cannot be included among the cache keys"
Expand Down Expand Up @@ -289,18 +301,31 @@ def __get__(self, obj, objtype=None):
iterable=self.iterable,
)

def get_cache_key(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
the cache_key.

We loop through each arg name, looking up if its in the `kwargs`,
otherwise using the next argument in `args`. If there are no more
args then we try looking the arg name up in the defaults
"""
pos = 0
for nm in self.arg_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield self.arg_defaults[nm]

@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)

# Add temp cache_context so inspect.getcallargs doesn't explode
if self.add_cache_context:
kwargs["cache_context"] = None

arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
cache_key = tuple(get_cache_key(args, kwargs))

# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
Expand Down Expand Up @@ -341,7 +366,10 @@ def onErr(f):
cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe()

return logcontext.make_deferred_yieldable(observer)
if isinstance(observer, defer.Deferred):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_deferred_yieldable will work ok with a non-deferred, so I think this is redundant. otoh I guess it optimises the common path?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because some of the speed up comes from not having to bounce through all the deferred stuff (which is much more complicated than just unwrapping to get the value) at the call sites.

Though we can also leave that to another PR

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

understood, leave it how it is

return logcontext.make_deferred_yieldable(observer)
else:
return observer

wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
Expand Down
3 changes: 2 additions & 1 deletion synapse/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
events ([synapse.events.EventBase]): list of events to filter
"""
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.who_forgot_in_room)(
defer.maybeDeferred(
preserve_fn(store.who_forgot_in_room),
room_id,
)
for room_id in frozenset(e.room_id for e in events)
Expand Down
2 changes: 1 addition & 1 deletion tests/storage/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def func(self, key):

a.func.prefill(("foo",), ObservableDeferred(d))

self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(a.func("foo"), d.result)
self.assertEquals(callcount[0], 0)

@defer.inlineCallbacks
Expand Down
38 changes: 38 additions & 0 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,41 @@ def do_lookup():
logcontext.LoggingContext.sentinel)

return d1

@defer.inlineCallbacks
def test_cache_default_args(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()

@descriptors.cached()
def fn(self, arg1, arg2=2, arg3=3):
return self.mock(arg1, arg2, arg3)

obj = Cls()

obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2, 3)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2, 3)
obj.mock.reset_mock()

# a call with same params shouldn't call the mock again
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_not_called()
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(2, 3, 3)
obj.mock.reset_mock()

# the two values should now be cached
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
4 changes: 3 additions & 1 deletion tests/util/test_snapshot_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def test_get_set(self):
# before the cache expires returns a resolved deferred.
get_result_at_11 = self.cache.get(11, "key")
self.assertIsNotNone(get_result_at_11)
self.assertTrue(get_result_at_11.called)
if isinstance(get_result_at_11, Deferred):
# The cache may return the actual result rather than a deferred
self.assertTrue(get_result_at_11.called)

# Check that getting the key after the deferred has resolved
# after the cache expires returns None
Expand Down