diff --git a/changelog.d/12468.misc b/changelog.d/12468.misc new file mode 100644 index 000000000000..3d5d25247f64 --- /dev/null +++ b/changelog.d/12468.misc @@ -0,0 +1 @@ +Update `delay_cancellation` to accept any awaitable, rather than just `Deferred`s. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 5eb545c86e2a..df1e9c1b831d 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -41,7 +41,6 @@ from typing_extensions import Literal from twisted.enterprise import adbapi -from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -794,7 +793,7 @@ async def _runInteraction() -> R: # We also wait until everything above is done before releasing the # `CancelledError`, so that logging contexts won't get used after they have been # finished. - return await delay_cancellation(defer.ensureDeferred(_runInteraction())) + return await delay_cancellation(_runInteraction()) async def runWithConnection( self, diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 650e44de2287..e27c5d298f6a 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -14,6 +14,7 @@ # limitations under the License. import abc +import asyncio import collections import inspect import itertools @@ -25,6 +26,7 @@ Awaitable, Callable, Collection, + Coroutine, Dict, Generic, Hashable, @@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": return new_deferred -def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": - """Delay cancellation of a `Deferred` until it resolves. +@overload +def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]": + ... + + +@overload +def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]": + ... + + +@overload +def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: + ... + + +def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: + """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves. Has the same effect as `stop_cancellation`, but the returned `Deferred` will not - resolve with a `CancelledError` until the original `Deferred` resolves. + resolve with a `CancelledError` until the original awaitable resolves. Args: - deferred: The `Deferred` to protect against cancellation. May optionally follow - the Synapse logcontext rules. + deferred: The coroutine or `Deferred` to protect against cancellation. May + optionally follow the Synapse logcontext rules. Returns: - A new `Deferred`, which will contain the result of the original `Deferred`. - The new `Deferred` will not propagate cancellation through to the original. - When cancelled, the new `Deferred` will wait until the original `Deferred` - resolves before failing with a `CancelledError`. + A new `Deferred`, which will contain the result of the original coroutine or + `Deferred`. The new `Deferred` will not propagate cancellation through to the + original coroutine or `Deferred`. - The new `Deferred` will follow the Synapse logcontext rules if `deferred` + When cancelled, the new `Deferred` will wait until the original coroutine or + `Deferred` resolves before failing with a `CancelledError`. + + The new `Deferred` will follow the Synapse logcontext rules if `awaitable` follows the Synapse logcontext rules. Otherwise the new `Deferred` should be wrapped with `make_deferred_yieldable`. """ + # First, convert the awaitable into a `Deferred`. + if isinstance(awaitable, defer.Deferred): + deferred = awaitable + elif asyncio.iscoroutine(awaitable): + # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant + # type-checking, but we'd need Twisted >= 21.2. + deferred = defer.ensureDeferred(awaitable) + else: + # We have no idea what to do with this awaitable. + # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from + # `make_awaitable`, and let the caller `await` it normally. + return awaitable + def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: # before the new deferred is cancelled, we `pause` it to stop the cancellation # propagating. we then `unpause` it once the wrapped deferred completes, to diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index e5bc416de12f..daacc54c72bd 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -382,7 +382,7 @@ def test_cancellation(self): class DelayCancellationTests(TestCase): """Tests for the `delay_cancellation` function.""" - def test_cancellation(self): + def test_deferred_cancellation(self): """Test that cancellation of the new `Deferred` waits for the original.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = delay_cancellation(deferred) @@ -403,6 +403,35 @@ def test_cancellation(self): # Now that the original `Deferred` has failed, we should get a `CancelledError`. self.failureResultOf(wrapper_deferred, CancelledError) + def test_coroutine_cancellation(self): + """Test that cancellation of the new `Deferred` waits for the original.""" + blocking_deferred: "Deferred[None]" = Deferred() + completion_deferred: "Deferred[None]" = Deferred() + + async def task(): + await blocking_deferred + completion_deferred.callback(None) + # Raise an exception. Twisted should consume it, otherwise unwanted + # tracebacks will be printed in logs. + raise ValueError("abc") + + wrapper_deferred = delay_cancellation(task()) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + blocking_deferred.called, "Cancellation was propagated too deep" + ) + self.assertFalse(completion_deferred.called) + + # Unblock the task. + blocking_deferred.callback(None) + self.assertTrue(completion_deferred.called) + + # Now that the original coroutine has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + def test_suppresses_second_cancellation(self): """Test that a second cancellation is suppressed. @@ -451,7 +480,7 @@ async def inner(): async def outer(): with LoggingContext("c") as c: try: - await delay_cancellation(defer.ensureDeferred(inner())) + await delay_cancellation(inner()) self.fail("`CancelledError` was not raised") except CancelledError: self.assertEqual(c, current_context())