From 64f8ae29cddd50a43bf6c4b005a3efade05e634c Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 8 Mar 2022 14:50:10 +0000 Subject: [PATCH 1/7] Add `delay_cancellation` utility function `delay_cancellation` behaves like `stop_cancellation`, except it delays `CancelledError`s until the original `Deferred` resolves. This is handy for unifying cleanup paths and ensuring that uncancelled coroutines don't use finished logcontexts. Signed-off-by: Sean Quah --- synapse/util/async_helpers.py | 56 +++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a9f67dcbac6a..2389f2e0b2ba 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -695,3 +695,59 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": new_deferred: defer.Deferred[T] = defer.Deferred() deferred.chainDeferred(new_deferred) return new_deferred + + +def delay_cancellation(deferred: "defer.Deferred[T]", all: bool) -> "defer.Deferred[T]": + """Delay cancellation of a `Deferred` until it resolves. + + Has the same effect as `stop_cancellation`, but the returned `Deferred` will not + resolve with a `CancelledError` until the original `Deferred` resolves. + + Args: + deferred: The `Deferred` to protect against cancellation. Must not follow the + Synapse logcontext rules if `all` is `False`. + all: `True` to delay multiple cancellations. `False` to delay only the first + cancellation. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will wait until the original `Deferred` + resolves before failing with a `CancelledError`. + + The new `Deferred` will only follow the Synapse logcontext rules if `all` is + `True` and `deferred` follows the Synapse logcontext rules. Otherwise the new + `Deferred` should be wrapped with `make_deferred_yieldable`. + """ + + def cancel_errback(failure: Failure) -> Union[Failure, "defer.Deferred[T]"]: + """Insert another `Deferred` into the chain to delay cancellation. + + Called when the original `Deferred` resolves or the new `Deferred` is + cancelled. + """ + failure.trap(CancelledError) + + if deferred.called and not deferred.paused: + # The `CancelledError` came from the original `Deferred`. Pass it through. + return failure + + # Construct another `Deferred` that will only fail with the `CancelledError` + # once the original `Deferred` resolves. + delay_deferred: "defer.Deferred[T]" = defer.Deferred() + deferred.chainDeferred(delay_deferred) + + if all: + # Intercept cancellations recursively. Each cancellation will cause another + # `Deferred` to be inserted into the chain. + delay_deferred.addErrback(cancel_errback) + + # Override the result with the `CancelledError`. + delay_deferred.addBoth(lambda _: failure) + + return delay_deferred + + new_deferred: "defer.Deferred[T]" = defer.Deferred() + deferred.chainDeferred(new_deferred) + new_deferred.addErrback(cancel_errback) + return new_deferred From 9854e294641a166902c9f5fe4a36340fb2d8b37e Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 8 Mar 2022 17:24:14 +0000 Subject: [PATCH 2/7] Fix logcontexts when `@cached` and `@cachedList` lookups are cancelled `@cached` and `@cachedList` must wait until the wrapped method has completed before raising `CancelledError`s, otherwise the wrapped method will continue running in the background with a logging context that has been marked as finished. Signed-off-by: Sean Quah --- synapse/util/caches/descriptors.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1cdead02f14b..89509ba9c5e7 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -40,6 +40,7 @@ from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError +from synapse.util.async_helpers import delay_cancellation from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.lrucache import LruCache @@ -322,6 +323,11 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any: ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) ret = cache.set(cache_key, ret, callback=invalidate_callback) + # We started a new call to `self.orig`, so we must always wait for it to + # complete. Otherwise we might mark our current logging context as + # finished while `self.orig` is still using it in the background. + ret = delay_cancellation(ret, all=True) + return make_deferred_yieldable(ret) wrapped = cast(_CachedFunction, _wrapped) @@ -482,6 +488,11 @@ def errback_all(f: Failure) -> None: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( lambda _: results, unwrapFirstError ) + if missing: + # We started a new call to `self.orig`, so we must always wait for it to + # complete. Otherwise we might mark our current logging context as + # finished while `self.orig` is still using it in the background. + d = delay_cancellation(d, all=True) return make_deferred_yieldable(d) else: return defer.succeed(results) From 44d93c1c55b5094ba6277f52397b534aacf428d9 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 2 Mar 2022 17:11:16 +0000 Subject: [PATCH 3/7] Add basic cancellation tests for `@cached` and `@cachedList` decorators Signed-off-by: Sean Quah --- tests/util/caches/test_descriptors.py | 57 ++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 19741ffcdaf1..b511dfc94d70 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -17,7 +17,7 @@ from unittest import mock from twisted.internet import defer, reactor -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from synapse.api.errors import SynapseError from synapse.logging.context import ( @@ -28,7 +28,7 @@ make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, lru_cache +from synapse.util.caches.descriptors import cached, cachedList, lru_cache from tests import unittest from tests.test_utils import get_awaitable_result @@ -415,6 +415,31 @@ def func3(self, key, cache_context): obj.invalidate() top_invalidate.assert_called_once() + def test_cancel(self): + """Test that cancelling a lookup does not cancel other lookups""" + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + @cached() + async def fn(self, arg1): + await complete_lookup + return str(arg1) + + obj = Cls() + + d1 = obj.fn(123) + d2 = obj.fn(123) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + # Cancel `d1`, which is the lookup that caused `fn` to run. + d1.cancel() + + # `d2` should complete normally. + complete_lookup.callback(None) + self.failureResultOf(d1, CancelledError) + self.assertEqual(d2.result, "123") + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached @@ -787,3 +812,31 @@ async def list_fn(self, args1, arg2): obj.fn.invalidate((10, 2)) invalidate0.assert_called_once() invalidate1.assert_called_once() + + def test_cancel(self): + """Test that cancelling a lookup does not cancel other lookups""" + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + @cached() + def fn(self, arg1): + pass + + @cachedList("fn", "args") + async def list_fn(self, args): + await complete_lookup + return {arg: str(arg) for arg in args} + + obj = Cls() + + d1 = obj.list_fn([123, 456]) + d2 = obj.list_fn([123, 456, 789]) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + d1.cancel() + + # `d2` should complete normally. + complete_lookup.callback(None) + self.failureResultOf(d1, CancelledError) + self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) From 9dc3a321e0e0591e00ce6e7cedbd1634a7662054 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Thu, 3 Mar 2022 13:23:43 +0000 Subject: [PATCH 4/7] Add tests for logcontexts during `@cached` and `@cachedList` cancellation Signed-off-by: Sean Quah --- tests/util/caches/test_descriptors.py | 90 +++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index b511dfc94d70..f48f6fd54773 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -440,6 +440,49 @@ async def fn(self, arg1): self.failureResultOf(d1, CancelledError) self.assertEqual(d2.result, "123") + def test_cancel_logcontexts(self): + """Test that cancellation does not break logcontexts. + + * The `CancelledError` must be raised with the correct logcontext. + * The inner lookup must not resume with a finished logcontext. + * The inner lookup must not restore a finished logcontext when done. + """ + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + inner_context_was_finished = False + + @cached() + async def fn(self, arg1): + await make_deferred_yieldable(complete_lookup) + self.inner_context_was_finished = current_context().finished + return str(arg1) + + obj = Cls() + + async def do_lookup(): + with LoggingContext("c1") as c1: + try: + await obj.fn(123) + self.fail("No CancelledError thrown") + except CancelledError: + self.assertEqual( + current_context(), + c1, + "CancelledError was not raised with the correct logcontext", + ) + # suppress the error and succeed + + d = defer.ensureDeferred(do_lookup()) + d.cancel() + + complete_lookup.callback(None) + self.successResultOf(d) + self.assertFalse( + obj.inner_context_was_finished, "Tried to restart a finished logcontext" + ) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached @@ -840,3 +883,50 @@ async def list_fn(self, args): complete_lookup.callback(None) self.failureResultOf(d1, CancelledError) self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) + + def test_cancel_logcontexts(self): + """Test that cancellation does not break logcontexts. + + * The `CancelledError` must be raised with the correct logcontext. + * The inner lookup must not resume with a finished logcontext. + * The inner lookup must not restore a finished logcontext when done. + """ + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + inner_context_was_finished = False + + @cached() + def fn(self, arg1): + pass + + @cachedList("fn", "args") + async def list_fn(self, args): + await make_deferred_yieldable(complete_lookup) + self.inner_context_was_finished = current_context().finished + return {arg: str(arg) for arg in args} + + obj = Cls() + + async def do_lookup(): + with LoggingContext("c1") as c1: + try: + await obj.list_fn([123]) + self.fail("No CancelledError thrown") + except CancelledError: + self.assertEqual( + current_context(), + c1, + "CancelledError was not raised with the correct logcontext", + ) + # suppress the error and succeed + + d = defer.ensureDeferred(do_lookup()) + d.cancel() + + complete_lookup.callback(None) + self.successResultOf(d) + self.assertFalse( + obj.inner_context_was_finished, "Tried to restart a finished logcontext" + ) + self.assertEqual(current_context(), SENTINEL_CONTEXT) From 0f29d56dcdc8bd4a20f7d3cf408bdb5b8032979e Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 8 Mar 2022 18:07:27 +0000 Subject: [PATCH 5/7] Add newsfile Signed-off-by: Sean Quah --- changelog.d/12183.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12183.misc diff --git a/changelog.d/12183.misc b/changelog.d/12183.misc new file mode 100644 index 000000000000..dd441bb64ff7 --- /dev/null +++ b/changelog.d/12183.misc @@ -0,0 +1 @@ +Add cancellation support to `@cached` and `@cachedList` decorators. From 966cdc2654b975d52fded3b880ad2d1921e39bc7 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 14 Mar 2022 18:07:51 +0000 Subject: [PATCH 6/7] Fixup merge and use latest version of `delay_cancellation` --- synapse/util/async_helpers.py | 56 ------------------------------ synapse/util/caches/descriptors.py | 4 +-- 2 files changed, 2 insertions(+), 58 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 55900922abd1..69c8c1baa9fc 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -731,59 +731,3 @@ def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel) deferred.chainDeferred(new_deferred) return new_deferred - - -def delay_cancellation(deferred: "defer.Deferred[T]", all: bool) -> "defer.Deferred[T]": - """Delay cancellation of a `Deferred` until it resolves. - - Has the same effect as `stop_cancellation`, but the returned `Deferred` will not - resolve with a `CancelledError` until the original `Deferred` resolves. - - Args: - deferred: The `Deferred` to protect against cancellation. Must not follow the - Synapse logcontext rules if `all` is `False`. - all: `True` to delay multiple cancellations. `False` to delay only the first - cancellation. - - Returns: - A new `Deferred`, which will contain the result of the original `Deferred`. - The new `Deferred` will not propagate cancellation through to the original. - When cancelled, the new `Deferred` will wait until the original `Deferred` - resolves before failing with a `CancelledError`. - - The new `Deferred` will only follow the Synapse logcontext rules if `all` is - `True` and `deferred` follows the Synapse logcontext rules. Otherwise the new - `Deferred` should be wrapped with `make_deferred_yieldable`. - """ - - def cancel_errback(failure: Failure) -> Union[Failure, "defer.Deferred[T]"]: - """Insert another `Deferred` into the chain to delay cancellation. - - Called when the original `Deferred` resolves or the new `Deferred` is - cancelled. - """ - failure.trap(CancelledError) - - if deferred.called and not deferred.paused: - # The `CancelledError` came from the original `Deferred`. Pass it through. - return failure - - # Construct another `Deferred` that will only fail with the `CancelledError` - # once the original `Deferred` resolves. - delay_deferred: "defer.Deferred[T]" = defer.Deferred() - deferred.chainDeferred(delay_deferred) - - if all: - # Intercept cancellations recursively. Each cancellation will cause another - # `Deferred` to be inserted into the chain. - delay_deferred.addErrback(cancel_errback) - - # Override the result with the `CancelledError`. - delay_deferred.addBoth(lambda _: failure) - - return delay_deferred - - new_deferred: "defer.Deferred[T]" = defer.Deferred() - deferred.chainDeferred(new_deferred) - new_deferred.addErrback(cancel_errback) - return new_deferred diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 34ff030b61fc..eda92d864dea 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -354,7 +354,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any: # We started a new call to `self.orig`, so we must always wait for it to # complete. Otherwise we might mark our current logging context as # finished while `self.orig` is still using it in the background. - ret = delay_cancellation(ret, all=True) + ret = delay_cancellation(ret) return make_deferred_yieldable(ret) @@ -520,7 +520,7 @@ def errback_all(f: Failure) -> None: # We started a new call to `self.orig`, so we must always wait for it to # complete. Otherwise we might mark our current logging context as # finished while `self.orig` is still using it in the background. - d = delay_cancellation(d, all=True) + d = delay_cancellation(d) return make_deferred_yieldable(d) else: return defer.succeed(results) From 356e92a90bd90a514c50c906d1add8c88ecf77df Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 14 Mar 2022 18:28:34 +0000 Subject: [PATCH 7/7] @cachedList now takes kwargs --- tests/util/caches/test_descriptors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 3f357cd61b10..48e616ac7419 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -943,7 +943,7 @@ class Cls: def fn(self, arg1): pass - @cachedList("fn", "args") + @cachedList(cached_method_name="fn", list_name="args") async def list_fn(self, args): await complete_lookup return {arg: str(arg) for arg in args} @@ -978,7 +978,7 @@ class Cls: def fn(self, arg1): pass - @cachedList("fn", "args") + @cachedList(cached_method_name="fn", list_name="args") async def list_fn(self, args): await make_deferred_yieldable(complete_lookup) self.inner_context_was_finished = current_context().finished