From ee4b420561564995d08d950b87c738fd876a0e3c Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 6 May 2022 20:30:12 +0100 Subject: [PATCH 01/25] Add `_test_cancellation_at_every_await` helper method Signed-off-by: Sean Quah --- tests/http/server/_base.py | 425 ++++++++++++++++++++++++++++++++++++- 1 file changed, 424 insertions(+), 1 deletion(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index b9f1a381aa2b..f62f76328c1b 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -12,21 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect +import itertools +import logging from http import HTTPStatus -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union from unittest import mock +from twisted.internet.defer import Deferred from twisted.internet.error import ConnectionDone +from twisted.python.failure import Failure from synapse.http.server import ( HTTP_STATUS_REQUEST_CANCELLED, respond_with_html_bytes, respond_with_json, ) +from synapse.http.site import SynapseRequest +from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.types import JsonDict from tests import unittest from tests.server import FakeChannel, ThreadedMemoryReactorClock +from tests.unittest import logcontext_clean + +logger = logging.getLogger(__name__) + + +T = TypeVar("T") class EndpointCancellationTestHelperMixin(unittest.TestCase): @@ -98,3 +111,413 @@ def _test_disconnect( self.assertEqual(code, expected_code) self.assertEqual(request.code, expected_code) self.assertEqual(body, expected_body) + + @logcontext_clean + def _test_cancellation_at_every_await( + self, + reactor: ThreadedMemoryReactorClock, + make_request: Callable[[], FakeChannel], + test_name: str, + expected_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Performs a request repeatedly, disconnecting at successive `await`s, until + one completes. + + Fails if: + * A logging context is lost during cancellation. + * A logging context get restarted after it is marked as finished, eg. if + a request's logging context is used by some processing started by the + request, but the request neglects to cancel that processing or wait for it + to complete. + + Note that "Re-starting finished log context" errors get caught by twisted + and will manifest in a different logging context error at a later point. + When debugging logging context failures, setting a breakpoint in + `logcontext_error` can prove useful. + * A request gets stuck, possibly due to a previous cancellation. + * The request does not return a 499 when the client disconnects. + This implies that a `CancelledError` was swallowed somewhere. + * The request does not return a 200 when allowed to run to completion. + + It is up to the caller to verify that the request returns the correct data when + it finally runs to completion. + + Note that this function can only cover a single code path and does not guarantee + that an endpoint is compatible with cancellation on every code path. + To allow inspection of the code path that is being tested, this function will + log the stack trace at every `await` that gets cancelled. To view these log + lines, `trial` can be run with the `SYNAPSE_TEST_LOG_LEVEL=INFO` environment + variable, which will include the log lines in `_trial_temp/test.log`. + Alternatively, `_log_for_request` can be modified to write to `sys.stdout`. 
+ + Args: + reactor: The twisted reactor running the request handler. + make_request: A function that initiates the request and returns a + `FakeChannel`. + test_name: The name of the test, which will be logged. + expected_code: The expected status code for the final request that runs to + completion. Defaults to `200`. + + Returns: + The JSON response of the final request that runs to completion. + """ + # To process a request, a coroutine run is created for the async method handling + # the request. That method may then start other coroutine runs, wrapped in + # `Deferred`s. + # + # We would like to trigger a cancellation at the first `await`, re-run the + # request and cancel at the second `await`, and so on. By patching + # `Deferred.__next__`, we can intercept `await`s, track which ones we have or + # have not seen, and force them to block when they wouldn't have. + + # The set of previously seen `await`s. + # Each element is a stringified stack trace. + seen_awaits: Set[Tuple[str, ...]] = set() + + self._log_for_request( + 0, f"Running _test_cancellation_at_every_await for {test_name}..." + ) + + for request_number in itertools.count(1): + ( + Deferred___next__, + unblock_awaits, + get_awaits_seen, + has_seen_new_await, + ) = self.create_deferred___next___patch(seen_awaits, request_number) + + try: + with mock.patch( + "synapse.http.server.respond_with_json", wraps=respond_with_json + ) as respond_mock: + with mock.patch.object( + Deferred, + "__next__", + new=Deferred___next__, + ): + # Start the request. + channel = make_request() + request = channel.request + + if request_number == 0: + self.assertFalse( + respond_mock.called, + "Request finished before we could disconnect - " + "was `await_result=False` passed to `make_request`?", + ) + else: + # Requests after the first may be lucky enough to hit caches + # all the way through and never have to block. + pass + + # Run the request until we see a new `await` which we have not + # yet cancelled at, or it completes. + while not respond_mock.called and not has_seen_new_await(): + previous_awaits_seen = get_awaits_seen() + + reactor.pump([0.0]) + + if get_awaits_seen() == previous_awaits_seen: + # We didn't see any progress. Try advancing the clock. + reactor.pump([1.0]) + + if get_awaits_seen() == previous_awaits_seen: + # We still didn't see any progress. The request might be stuck. + self.fail( + "Request appears to be stuck, possibly due to " + "a previous cancelled request" + ) + + if respond_mock.called: + # The request ran to completion and we are done with testing it. + + # `respond_with_json` writes the response asynchronously, so we + # might have to give the reactor a kick before the channel gets + # the response. + reactor.pump([1.0]) + + self.assertEqual(channel.code, expected_code) + return channel.json_body + else: + # Disconnect the client and wait for the response. + request.connectionLost(reason=ConnectionDone()) + + self._log_for_request(request_number, "--- disconnected ---") + + # We may need to pump the reactor to allow `delay_cancellation`s + # to finish. + if not respond_mock.called: + reactor.pump([0.0]) + + # Try advancing the clock if that didn't work. + if not respond_mock.called: + reactor.pump([1.0]) + + # Mark the request's logging context as finished. If it gets + # activated again, an `AssertionError` will be raised. This + # `AssertionError` will likely be caught by twisted and turned + # into a `Failure`. 
Instead, a different `AssertionError` will + # be observed when the logging context is deactivated, as it + # wouldn't have tracked resource usage correctly. + if isinstance(request, SynapseRequest) and request.logcontext: + request.logcontext.finished = True + + # Check that the request finished with a 499, + # ie. the `CancelledError` wasn't swallowed. + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code = args[1] + self.assertEqual(code, HTTP_STATUS_REQUEST_CANCELLED) + finally: + # Unblock any processing that might be shared between requests. + unblock_awaits() + + assert False, "unreachable" # noqa: B011 + + def create_deferred___next___patch( + self, seen_awaits: Set[Tuple[str, ...]], request_number: int + ) -> Tuple[ + Callable[["Deferred[T]"], "Deferred[T]"], + Callable[[], None], + Callable[[], int], + Callable[[], bool], + ]: + """Creates a function to patch `Deferred.__next__` with, for + `_test_cancellation_at_every_await`. + + Produces a `Deferred.__next__` patch that will intercept `await`s and force them + to block once it sees a new `await`. + + Args: + seen_awaits: The set of stack traces of `await`s that have been previously + seen. When the `Deferred.__next__` patch sees a new `await`, it will add + it to the set. + request_number: The request number to log against. + + Returns: + A tuple containing: + * The method to replace `Deferred.__next__` with. + * A method that will clean up after any `await`s that were forced to block. + This method must be called when done, otherwise processing shared between + multiple requests, such as database queries started by `@cached`, will + become permanently stuck. + * A method returning the running total of intercepted `await`s on + `Deferred`s. + * A method returning `True` once a new `await` has been seen. + """ + original_Deferred___next__ = Deferred.__next__ + + # The number of `await`s on `Deferred`s we have seen so far. + awaits_seen = 0 + + # Whether we have seen a new `await` not in `seen_awaits`. + new_await_seen = False + + # To force `await`s on resolved `Deferred`s to block, we make up a new + # unresolved `Deferred` and return it out of `Deferred.__next__` / + # `coroutine.send()`. We have to resolve it later, in case the + # `await`ing coroutine is part of some shared processing, such as + # `@cached`. + to_unblock: Dict[Deferred, Union[object, Failure]] = {} + + # The last stack we logged. + previous_stack: List[inspect.FrameInfo] = [] + + def unblock_awaits() -> None: + """Unblocks any shared processing that we forced to block.""" + for deferred, result in to_unblock.items(): + deferred.callback(result) + + def get_awaits_seen() -> int: + return awaits_seen + + def has_seen_new_await() -> bool: + return new_await_seen + + def Deferred___next__( + deferred: "Deferred[T]", value: object = None + ) -> "Deferred[T]": + """Intercepts `await`s on `Deferred`s and rigs them to block once we have + seen enough of them. + + `Deferred.__next__` will normally: + * return `self` if unresolved, which will come out of + `coroutine.send()`. + * raise a `StopIteration(result)`, containing the result of the + `await`. + * raise another exception, which will come out of the `await`. + """ + nonlocal awaits_seen + nonlocal new_await_seen + nonlocal previous_stack + + awaits_seen += 1 + + stack = self._get_stack(skip_frames=1) + stack_hash = self._hash_stack(stack) + + if stack_hash not in seen_awaits: + # Block at the current `await` onwards. 
+ seen_awaits.add(stack_hash) + new_await_seen = True + + if not new_await_seen: + # This `await` isn't interesting. Let it proceed normally. + + # Don't log the stack. It's been seen before in a previous run. + previous_stack = stack + + return original_Deferred___next__(deferred, value) + else: + # We want to block at the current `await`. + if deferred.called and not deferred.paused: + # This `Deferred` already has a result. + # We return a new, unresolved, `Deferred` for `_inlineCallbacks` + # to wait on. This blocks the coroutine that did this `await`. + # We queue it up for unblocking later. + new_deferred: "Deferred[T]" = Deferred() + to_unblock[new_deferred] = deferred.result + + self._log_await_stack( + stack, previous_stack, request_number, "force-blocked await" + ) + previous_stack = stack + + return make_deferred_yieldable(new_deferred) + else: + # This `Deferred` does not have a result yet. + # The `await` will block normally, so we don't have to do + # anything. + self._log_await_stack( + stack, previous_stack, request_number, "blocking await" + ) + previous_stack = stack + + return original_Deferred___next__(deferred, value) + + return Deferred___next__, unblock_awaits, get_awaits_seen, has_seen_new_await + + def _log_for_request(self, request_number: int, message: str) -> None: + """Logs a message for an iteration of `_test_cancellation_at_every_await`.""" + # We want consistent alignment when logging stack traces, so ensure the + # logging context has a fixed width name. + with LoggingContext(name=f"request-{request_number:<2}"): + logger.info(message) + + def _log_await_stack( + self, + stack: List[inspect.FrameInfo], + previous_stack: List[inspect.FrameInfo], + request_number: int, + note: str, + ) -> None: + """Logs the stack for an `await` in `_test_cancellation_at_every_await`. + + Only logs the part of the stack that has changed since the previous call. + + Example output looks like: + ``` + delay_cancellation:750 (synapse/util/async_helpers.py:750) + DatabasePool._runInteraction:768 (synapse/storage/database.py:768) + > *blocked on await* at DatabasePool.runWithConnection:891 (synapse/storage/database.py:891) + ``` + + Args: + stack: The stack to log, as returned by `_get_stack()`. + previous_stack: The previous stack logged, with callers appearing before + callees. + request_number: The request number to log against. + note: A note to attach to the last stack frame, eg. "blocked on await". + """ + for i, frame_info in enumerate(stack[:-1]): + # Skip any frames in common with the previous logging. + if i < len(previous_stack) and frame_info == previous_stack[i]: + continue + + frame = self._format_stack_frame(frame_info) + message = f"{' ' * i}{frame}" + self._log_for_request(request_number, message) + + # Always print the final frame with the `await`. + # If the frame with the `await` started another coroutine run, we may have + # already printed a deeper stack which includes our final frame. We want to + # log where all `await`s happen, so we reprint the frame in this case. + i = len(stack) - 1 + frame_info = stack[i] + frame = self._format_stack_frame(frame_info) + message = f"{' ' * i}> *{note}* at {frame}" + self._log_for_request(request_number, message) + + def _format_stack_frame(self, frame_info: inspect.FrameInfo) -> str: + """Returns a string representation of a stack frame. + + Used for debug logging. + + Returns: + A string, formatted like + "JsonResource._async_render:559 (synapse/http/server.py:559)". 
+ """ + method_name = self._get_stack_frame_method_name(frame_info) + + return ( + f"{method_name}:{frame_info.lineno} " + f"({frame_info.filename}:{frame_info.lineno})" + ) + + def _get_stack(self, skip_frames: int) -> List[inspect.FrameInfo]: + """Captures the stack for a request. + + Skips any twisted frames and stops at `JsonResource.wrapped_async_request_handler`. + + Used for debug logging. + + Returns: + A list of `inspect.FrameInfo`s, with callers appearing before callees. + """ + stack = [] + + skip_frames += 1 # Also skip `get_stack` itself. + + for frame_info in inspect.stack()[skip_frames:]: + # Skip any twisted `inlineCallbacks` gunk. + if "/twisted/" in frame_info.filename: + continue + + # Exclude the reactor frame, upwards. + method_name = self._get_stack_frame_method_name(frame_info) + if method_name == "ThreadedMemoryReactorClock.advance": + break + + stack.append(frame_info) + + # Stop at `JsonResource`'s `wrapped_async_request_handler`, which is the + # entry point for request handling. + if frame_info.function == "wrapped_async_request_handler": + break + + return stack[::-1] + + def _get_stack_frame_method_name(self, frame_info: inspect.FrameInfo) -> str: + """Returns the name of a stack frame's method. + + eg. "JsonResource._async_render". + """ + method_name = frame_info.function + + # Prefix the class name for instance methods. + frame_self = frame_info.frame.f_locals.get("self") + if frame_self: + method = getattr(frame_self, method_name, None) + if method: + method_name = method.__qualname__ + else: + # We couldn't find the method on `self`. + # Make something up. It's useful to know which class "contains" a + # function anyway. + method_name = f"{type(frame_self).__name__} {method_name}" + + return method_name + + def _hash_stack(self, stack: List[inspect.FrameInfo]): + """Turns a stack into a hashable value that can be put into a set.""" + return tuple(self._format_stack_frame(frame) for frame in stack) From 92045c8a62ee9eee02e946fa412a464ea62bec42 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 6 May 2022 20:51:25 +0100 Subject: [PATCH 02/25] Fix mypy thinking `RoomBase.servlets` is a `List[function]` Signed-off-by: Sean Quah --- tests/rest/client/test_rooms.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index d0197aca94a0..1357af738ff6 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,7 +18,7 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, ClassVar, Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse @@ -33,7 +33,7 @@ ) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus -from synapse.rest import admin +from synapse.rest import RegisterServletsFunc, admin from synapse.rest.client import account, directory, login, profile, room, sync from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester @@ -49,7 +49,12 @@ class RoomBase(unittest.HomeserverTestCase): rmcreator_id: Optional[str] = None - servlets = [room.register_servlets, room.register_deprecated_servlets] + # mypy: `room.register_servlets` has an extra parameter, so mypy thinks `servlets` + # is a `List[function]` without the hint. 
+ servlets: ClassVar[List[RegisterServletsFunc]] = [ + room.register_servlets, + room.register_deprecated_servlets, + ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: From 97b6b1e97800cd4d0d8e524414f7a124615dc7e7 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 6 May 2022 20:52:38 +0100 Subject: [PATCH 03/25] Add tests for `/rooms//members` cancellation Signed-off-by: Sean Quah --- tests/rest/client/test_rooms.py | 59 ++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 1357af738ff6..9448f7f13477 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -41,6 +41,7 @@ from synapse.util.stringutils import random_string from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -475,7 +476,7 @@ def test_member_event_from_ban(self) -> None: ) -class RoomsMemberListTestCase(RoomBase): +class RoomsMemberListTestCase(RoomBase, EndpointCancellationTestHelperMixin): """Tests /rooms/$room_id/members/list REST events.""" servlets = RoomBase.servlets + [sync.register_servlets] @@ -595,6 +596,62 @@ def test_get_member_list_mixed_memberships(self) -> None: channel = self.make_request("GET", room_path) self.assertEqual(200, channel.code, msg=channel.result["body"]) + def test_get_member_list_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members` request.""" + room_id = self.helper.create_room_as(self.user_id) + body = self._test_cancellation_at_every_await( + self.reactor, + lambda: self.make_request( + "GET", "/rooms/%s/members" % room_id, await_result=False + ), + test_name="test_get_member_list_cancellation", + ) + + self.assertEqual(len(body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + body["chunk"][0].items(), + ) + + def test_get_member_list_with_at_token_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members?at=` request.""" + room_id = self.helper.create_room_as(self.user_id) + + # first sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEqual(200, channel.code) + sync_token = channel.json_body["next_batch"] + + body = self._test_cancellation_at_every_await( + self.reactor, + lambda: self.make_request( + "GET", + "/rooms/%s/members?at=%s" % (room_id, sync_token), + await_result=False, + ), + test_name="test_get_member_list_with_at_token_cancellation", + ) + + self.assertEqual(len(body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + body["chunk"][0].items(), + ) + class RoomsCreateTestCase(RoomBase): """Tests /rooms and /rooms/$room_id REST events.""" From 8d7b49578fb7504926edfa9386a07d5035d2b854 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 6 May 2022 22:59:35 +0100 Subject: [PATCH 04/25] Add tests for `/rooms//state` cancellation Signed-off-by: Sean Quah --- tests/rest/client/test_rooms.py | 43 +++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 
9448f7f13477..01eb6d28fa53 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -476,6 +476,49 @@ def test_member_event_from_ban(self) -> None: ) +class RoomStateTestCase(RoomBase, EndpointCancellationTestHelperMixin): + """Tests /rooms/$room_id/state.""" + + user_id = "@sid1:red" + + def test_get_state_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state` request.""" + room_id = self.helper.create_room_as(self.user_id) + body = self._test_cancellation_at_every_await( + self.reactor, + lambda: self.make_request( + "GET", "/rooms/%s/state" % room_id, await_result=False + ), + test_name="test_state_cancellation", + ) + + self.assertCountEqual( + [state_event["type"] for state_event in body], + { + "m.room.create", + "m.room.power_levels", + "m.room.join_rules", + "m.room.member", + "m.room.history_visibility", + }, + ) + + def test_get_state_event_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state/$event_type` request.""" + room_id = self.helper.create_room_as(self.user_id) + body = self._test_cancellation_at_every_await( + self.reactor, + lambda: self.make_request( + "GET", + "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), + await_result=False, + ), + test_name="test_state_cancellation", + ) + + self.assertEqual(body, {"membership": "join"}) + + class RoomsMemberListTestCase(RoomBase, EndpointCancellationTestHelperMixin): """Tests /rooms/$room_id/members/list REST events.""" From b7aa03988ddc926edd8ed8d3642bf99d571781cd Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 9 May 2022 14:08:49 +0100 Subject: [PATCH 05/25] Add dummy newsfile --- changelog.d/12674.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12674.misc diff --git a/changelog.d/12674.misc b/changelog.d/12674.misc new file mode 100644 index 000000000000..28e4c406b898 --- /dev/null +++ b/changelog.d/12674.misc @@ -0,0 +1 @@ +Dummy newsfile to make CI happy. From c980eff45c10e43684923dd8ef401c0e4b2fee9a Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 27 May 2022 16:27:35 +0100 Subject: [PATCH 06/25] Write a proper newsfile --- changelog.d/12674.misc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/12674.misc b/changelog.d/12674.misc index 28e4c406b898..c8a8f32f0a88 100644 --- a/changelog.d/12674.misc +++ b/changelog.d/12674.misc @@ -1 +1 @@ -Dummy newsfile to make CI happy. +Add tests for cancellation of `GET /rooms/$room_id/members` and `GET /rooms/$room_id/state` requests. 
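The core trick introduced in patch 01 is to patch `Deferred.__next__` so that an `await` on an already-resolved `Deferred` can be forced to block, giving the test a precise point at which to cancel the request. The following is a minimal standalone sketch of that idea, not the helper itself: `handler`, `blocking_next` and the single-blocker bookkeeping are made up for illustration, whereas the real patch hashes stack traces so that each run blocks at a different, previously unseen `await`.

```
from unittest import mock

from twisted.internet.defer import CancelledError, Deferred, ensureDeferred, succeed


async def handler() -> str:
    # Would normally complete immediately, since the awaited Deferred is
    # already resolved.
    await succeed(None)
    return "done"


original_next = Deferred.__next__
blockers: list = []


def blocking_next(deferred: Deferred, value: object = None) -> Deferred:
    """Force the first `await` on a resolved Deferred to block instead."""
    if deferred.called and not blockers:
        blocker: Deferred = Deferred()
        blockers.append(blocker)
        # Yield an unresolved Deferred out of the `await`, so the coroutine
        # suspends here instead of continuing with the real result.
        return blocker
    return original_next(deferred, value)


with mock.patch.object(Deferred, "__next__", new=blocking_next):
    d = ensureDeferred(handler())
    assert not d.called  # the handler is now stuck at its first `await`
    d.cancel()  # simulate the client disconnecting at exactly this point

# The cancellation was raised inside `handler` at the blocked `await` and
# propagated out, so the wrapping Deferred failed with `CancelledError`.
failures: list = []
d.addErrback(lambda f: failures.append(f.trap(CancelledError)))
assert failures == [CancelledError]
```

The real helper additionally records the result of the displaced `Deferred` and fires the replacement later (`unblock_awaits`), so that shared processing such as `@cached` lookups is not left permanently stuck between runs.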
From 4bf4738dfc2155a2c91bdccd824a813937e6e4d3 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 16:01:59 +0100 Subject: [PATCH 07/25] Reword failure message about `await_result=False` --- tests/http/server/_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index f62f76328c1b..50f5c72c3276 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -202,8 +202,8 @@ def _test_cancellation_at_every_await( if request_number == 0: self.assertFalse( respond_mock.called, - "Request finished before we could disconnect - " - "was `await_result=False` passed to `make_request`?", + "Request finished before we could disconnect - ensure " + "`await_result=False` is passed to `make_request`.", ) else: # Requests after the first may be lucky enough to hit caches From 85a051e5e4ac06dbec4a23d0b11d58538ec3859e Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 16:02:41 +0100 Subject: [PATCH 08/25] `request_number` starts at 1 --- tests/http/server/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 50f5c72c3276..e373ffd791bc 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -199,7 +199,7 @@ def _test_cancellation_at_every_await( channel = make_request() request = channel.request - if request_number == 0: + if request_number == 1: self.assertFalse( respond_mock.called, "Request finished before we could disconnect - ensure " From 1ec26893a1fe787dddf12ee3465ccd2a56a29256 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 16:04:41 +0100 Subject: [PATCH 09/25] Use `reactor.advance()` instead of `reactor.pump()` --- tests/http/server/_base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index e373ffd791bc..8699d6b631f1 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -215,11 +215,11 @@ def _test_cancellation_at_every_await( while not respond_mock.called and not has_seen_new_await(): previous_awaits_seen = get_awaits_seen() - reactor.pump([0.0]) + reactor.advance(0.0) if get_awaits_seen() == previous_awaits_seen: # We didn't see any progress. Try advancing the clock. - reactor.pump([1.0]) + reactor.advance(1.0) if get_awaits_seen() == previous_awaits_seen: # We still didn't see any progress. The request might be stuck. @@ -234,7 +234,7 @@ def _test_cancellation_at_every_await( # `respond_with_json` writes the response asynchronously, so we # might have to give the reactor a kick before the channel gets # the response. - reactor.pump([1.0]) + reactor.advance(1.0) self.assertEqual(channel.code, expected_code) return channel.json_body @@ -247,11 +247,11 @@ def _test_cancellation_at_every_await( # We may need to pump the reactor to allow `delay_cancellation`s # to finish. if not respond_mock.called: - reactor.pump([0.0]) + reactor.advance(0.0) # Try advancing the clock if that didn't work. if not respond_mock.called: - reactor.pump([1.0]) + reactor.advance(1.0) # Mark the request's logging context as finished. If it gets # activated again, an `AssertionError` will be raised. 
This From 2cc45c347b6229b5869d7e2eb285aaf327a122bc Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 16:07:17 +0100 Subject: [PATCH 10/25] Outdent `else` branch --- tests/http/server/_base.py | 54 +++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 8699d6b631f1..3fc3553ad981 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -238,36 +238,36 @@ def _test_cancellation_at_every_await( self.assertEqual(channel.code, expected_code) return channel.json_body - else: - # Disconnect the client and wait for the response. - request.connectionLost(reason=ConnectionDone()) - self._log_for_request(request_number, "--- disconnected ---") + # Disconnect the client and wait for the response. + request.connectionLost(reason=ConnectionDone()) - # We may need to pump the reactor to allow `delay_cancellation`s - # to finish. - if not respond_mock.called: - reactor.advance(0.0) + self._log_for_request(request_number, "--- disconnected ---") + + # We may need to pump the reactor to allow `delay_cancellation`s to + # finish. + if not respond_mock.called: + reactor.advance(0.0) + + # Try advancing the clock if that didn't work. + if not respond_mock.called: + reactor.advance(1.0) - # Try advancing the clock if that didn't work. - if not respond_mock.called: - reactor.advance(1.0) - - # Mark the request's logging context as finished. If it gets - # activated again, an `AssertionError` will be raised. This - # `AssertionError` will likely be caught by twisted and turned - # into a `Failure`. Instead, a different `AssertionError` will - # be observed when the logging context is deactivated, as it - # wouldn't have tracked resource usage correctly. - if isinstance(request, SynapseRequest) and request.logcontext: - request.logcontext.finished = True - - # Check that the request finished with a 499, - # ie. the `CancelledError` wasn't swallowed. - respond_mock.assert_called_once() - args, _kwargs = respond_mock.call_args - code = args[1] - self.assertEqual(code, HTTP_STATUS_REQUEST_CANCELLED) + # Mark the request's logging context as finished. If it gets + # activated again, an `AssertionError` will be raised. This + # `AssertionError` will likely be caught by twisted and turned into + # a `Failure`. Instead, a different `AssertionError` will be + # observed when the logging context is deactivated, as it wouldn't + # have tracked resource usage correctly. + if isinstance(request, SynapseRequest) and request.logcontext: + request.logcontext.finished = True + + # Check that the request finished with a 499, + # ie. the `CancelledError` wasn't swallowed. + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code = args[1] + self.assertEqual(code, HTTP_STATUS_REQUEST_CANCELLED) finally: # Unblock any processing that might be shared between requests. unblock_awaits() From 9543d43f71012c19faa85a50092ebcdb43e71577 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 16:13:11 +0100 Subject: [PATCH 11/25] Un-instance method a bunch of functions --- tests/http/server/_base.py | 224 +++++++++++++++++++------------------ 1 file changed, 114 insertions(+), 110 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 3fc3553ad981..e7046975dcc8 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -174,7 +174,7 @@ def _test_cancellation_at_every_await( # Each element is a stringified stack trace. 
seen_awaits: Set[Tuple[str, ...]] = set() - self._log_for_request( + _log_for_request( 0, f"Running _test_cancellation_at_every_await for {test_name}..." ) @@ -242,7 +242,7 @@ def _test_cancellation_at_every_await( # Disconnect the client and wait for the response. request.connectionLost(reason=ConnectionDone()) - self._log_for_request(request_number, "--- disconnected ---") + _log_for_request(request_number, "--- disconnected ---") # We may need to pump the reactor to allow `delay_cancellation`s to # finish. @@ -353,8 +353,8 @@ def Deferred___next__( awaits_seen += 1 - stack = self._get_stack(skip_frames=1) - stack_hash = self._hash_stack(stack) + stack = _get_stack(skip_frames=1) + stack_hash = _hash_stack(stack) if stack_hash not in seen_awaits: # Block at the current `await` onwards. @@ -378,7 +378,7 @@ def Deferred___next__( new_deferred: "Deferred[T]" = Deferred() to_unblock[new_deferred] = deferred.result - self._log_await_stack( + _log_await_stack( stack, previous_stack, request_number, "force-blocked await" ) previous_stack = stack @@ -388,7 +388,7 @@ def Deferred___next__( # This `Deferred` does not have a result yet. # The `await` will block normally, so we don't have to do # anything. - self._log_await_stack( + _log_await_stack( stack, previous_stack, request_number, "blocking await" ) previous_stack = stack @@ -397,127 +397,131 @@ def Deferred___next__( return Deferred___next__, unblock_awaits, get_awaits_seen, has_seen_new_await - def _log_for_request(self, request_number: int, message: str) -> None: - """Logs a message for an iteration of `_test_cancellation_at_every_await`.""" - # We want consistent alignment when logging stack traces, so ensure the - # logging context has a fixed width name. - with LoggingContext(name=f"request-{request_number:<2}"): - logger.info(message) - def _log_await_stack( - self, - stack: List[inspect.FrameInfo], - previous_stack: List[inspect.FrameInfo], - request_number: int, - note: str, - ) -> None: - """Logs the stack for an `await` in `_test_cancellation_at_every_await`. +def _log_for_request(request_number: int, message: str) -> None: + """Logs a message for an iteration of `_test_cancellation_at_every_await`.""" + # We want consistent alignment when logging stack traces, so ensure the logging + # context has a fixed width name. + with LoggingContext(name=f"request-{request_number:<2}"): + logger.info(message) - Only logs the part of the stack that has changed since the previous call. - Example output looks like: - ``` - delay_cancellation:750 (synapse/util/async_helpers.py:750) - DatabasePool._runInteraction:768 (synapse/storage/database.py:768) - > *blocked on await* at DatabasePool.runWithConnection:891 (synapse/storage/database.py:891) - ``` +def _log_await_stack( + stack: List[inspect.FrameInfo], + previous_stack: List[inspect.FrameInfo], + request_number: int, + note: str, +) -> None: + """Logs the stack for an `await` in `_test_cancellation_at_every_await`. - Args: - stack: The stack to log, as returned by `_get_stack()`. - previous_stack: The previous stack logged, with callers appearing before - callees. - request_number: The request number to log against. - note: A note to attach to the last stack frame, eg. "blocked on await". - """ - for i, frame_info in enumerate(stack[:-1]): - # Skip any frames in common with the previous logging. 
- if i < len(previous_stack) and frame_info == previous_stack[i]: - continue - - frame = self._format_stack_frame(frame_info) - message = f"{' ' * i}{frame}" - self._log_for_request(request_number, message) - - # Always print the final frame with the `await`. - # If the frame with the `await` started another coroutine run, we may have - # already printed a deeper stack which includes our final frame. We want to - # log where all `await`s happen, so we reprint the frame in this case. - i = len(stack) - 1 - frame_info = stack[i] - frame = self._format_stack_frame(frame_info) - message = f"{' ' * i}> *{note}* at {frame}" - self._log_for_request(request_number, message) - - def _format_stack_frame(self, frame_info: inspect.FrameInfo) -> str: - """Returns a string representation of a stack frame. - - Used for debug logging. + Only logs the part of the stack that has changed since the previous call. - Returns: - A string, formatted like - "JsonResource._async_render:559 (synapse/http/server.py:559)". - """ - method_name = self._get_stack_frame_method_name(frame_info) + Example output looks like: + ``` + delay_cancellation:750 (synapse/util/async_helpers.py:750) + DatabasePool._runInteraction:768 (synapse/storage/database.py:768) + > *blocked on await* at DatabasePool.runWithConnection:891 (synapse/storage/database.py:891) + ``` - return ( - f"{method_name}:{frame_info.lineno} " - f"({frame_info.filename}:{frame_info.lineno})" - ) + Args: + stack: The stack to log, as returned by `_get_stack()`. + previous_stack: The previous stack logged, with callers appearing before + callees. + request_number: The request number to log against. + note: A note to attach to the last stack frame, eg. "blocked on await". + """ + for i, frame_info in enumerate(stack[:-1]): + # Skip any frames in common with the previous logging. + if i < len(previous_stack) and frame_info == previous_stack[i]: + continue - def _get_stack(self, skip_frames: int) -> List[inspect.FrameInfo]: - """Captures the stack for a request. + frame = _format_stack_frame(frame_info) + message = f"{' ' * i}{frame}" + _log_for_request(request_number, message) - Skips any twisted frames and stops at `JsonResource.wrapped_async_request_handler`. + # Always print the final frame with the `await`. + # If the frame with the `await` started another coroutine run, we may have already + # printed a deeper stack which includes our final frame. We want to log where all + # `await`s happen, so we reprint the frame in this case. + i = len(stack) - 1 + frame_info = stack[i] + frame = _format_stack_frame(frame_info) + message = f"{' ' * i}> *{note}* at {frame}" + _log_for_request(request_number, message) - Used for debug logging. - Returns: - A list of `inspect.FrameInfo`s, with callers appearing before callees. - """ - stack = [] +def _format_stack_frame(frame_info: inspect.FrameInfo) -> str: + """Returns a string representation of a stack frame. - skip_frames += 1 # Also skip `get_stack` itself. + Used for debug logging. - for frame_info in inspect.stack()[skip_frames:]: - # Skip any twisted `inlineCallbacks` gunk. - if "/twisted/" in frame_info.filename: - continue + Returns: + A string, formatted like + "JsonResource._async_render:559 (synapse/http/server.py:559)". + """ + method_name = _get_stack_frame_method_name(frame_info) - # Exclude the reactor frame, upwards. 
- method_name = self._get_stack_frame_method_name(frame_info) - if method_name == "ThreadedMemoryReactorClock.advance": - break + return ( + f"{method_name}:{frame_info.lineno} ({frame_info.filename}:{frame_info.lineno})" + ) - stack.append(frame_info) - # Stop at `JsonResource`'s `wrapped_async_request_handler`, which is the - # entry point for request handling. - if frame_info.function == "wrapped_async_request_handler": - break +def _get_stack(skip_frames: int) -> List[inspect.FrameInfo]: + """Captures the stack for a request. - return stack[::-1] + Skips any twisted frames and stops at `JsonResource.wrapped_async_request_handler`. - def _get_stack_frame_method_name(self, frame_info: inspect.FrameInfo) -> str: - """Returns the name of a stack frame's method. + Used for debug logging. - eg. "JsonResource._async_render". - """ - method_name = frame_info.function - - # Prefix the class name for instance methods. - frame_self = frame_info.frame.f_locals.get("self") - if frame_self: - method = getattr(frame_self, method_name, None) - if method: - method_name = method.__qualname__ - else: - # We couldn't find the method on `self`. - # Make something up. It's useful to know which class "contains" a - # function anyway. - method_name = f"{type(frame_self).__name__} {method_name}" + Returns: + A list of `inspect.FrameInfo`s, with callers appearing before callees. + """ + stack = [] + + skip_frames += 1 # Also skip `get_stack` itself. + + for frame_info in inspect.stack()[skip_frames:]: + # Skip any twisted `inlineCallbacks` gunk. + if "/twisted/" in frame_info.filename: + continue + + # Exclude the reactor frame, upwards. + method_name = _get_stack_frame_method_name(frame_info) + if method_name == "ThreadedMemoryReactorClock.advance": + break + + stack.append(frame_info) + + # Stop at `JsonResource`'s `wrapped_async_request_handler`, which is the entry + # point for request handling. + if frame_info.function == "wrapped_async_request_handler": + break + + return stack[::-1] + + +def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str: + """Returns the name of a stack frame's method. + + eg. "JsonResource._async_render". + """ + method_name = frame_info.function + + # Prefix the class name for instance methods. + frame_self = frame_info.frame.f_locals.get("self") + if frame_self: + method = getattr(frame_self, method_name, None) + if method: + method_name = method.__qualname__ + else: + # We couldn't find the method on `self`. + # Make something up. It's useful to know which class "contains" a + # function anyway. 
+ method_name = f"{type(frame_self).__name__} {method_name}" + + return method_name - return method_name - def _hash_stack(self, stack: List[inspect.FrameInfo]): - """Turns a stack into a hashable value that can be put into a set.""" - return tuple(self._format_stack_frame(frame) for frame in stack) +def _hash_stack(stack: List[inspect.FrameInfo]): + """Turns a stack into a hashable value that can be put into a set.""" + return tuple(_format_stack_frame(frame) for frame in stack) From 74a6bc7df75b9e27a3df39ae02a18f395f24e863 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 16:23:57 +0100 Subject: [PATCH 12/25] Return a `FakeChannel` from `_test_cancellation_at_every_await` and let the caller test the status code --- tests/http/server/_base.py | 12 ++++-------- tests/rest/client/test_rooms.py | 24 ++++++++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index e7046975dcc8..3918e2102438 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -118,8 +118,7 @@ def _test_cancellation_at_every_await( reactor: ThreadedMemoryReactorClock, make_request: Callable[[], FakeChannel], test_name: str, - expected_code: int = HTTPStatus.OK, - ) -> JsonDict: + ) -> FakeChannel: """Performs a request repeatedly, disconnecting at successive `await`s, until one completes. @@ -137,7 +136,6 @@ def _test_cancellation_at_every_await( * A request gets stuck, possibly due to a previous cancellation. * The request does not return a 499 when the client disconnects. This implies that a `CancelledError` was swallowed somewhere. - * The request does not return a 200 when allowed to run to completion. It is up to the caller to verify that the request returns the correct data when it finally runs to completion. @@ -155,11 +153,10 @@ def _test_cancellation_at_every_await( make_request: A function that initiates the request and returns a `FakeChannel`. test_name: The name of the test, which will be logged. - expected_code: The expected status code for the final request that runs to - completion. Defaults to `200`. Returns: - The JSON response of the final request that runs to completion. + The `FakeChannel` object which stores the result of the final request that + runs to completion. """ # To process a request, a coroutine run is created for the async method handling # the request. That method may then start other coroutine runs, wrapped in @@ -236,8 +233,7 @@ def _test_cancellation_at_every_await( # the response. reactor.advance(1.0) - self.assertEqual(channel.code, expected_code) - return channel.json_body + return channel # Disconnect the client and wait for the response. 
request.connectionLost(reason=ConnectionDone()) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 16ed9356a689..734d992a4b6f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -485,7 +485,7 @@ class RoomStateTestCase(RoomBase, EndpointCancellationTestHelperMixin): def test_get_state_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/state` request.""" room_id = self.helper.create_room_as(self.user_id) - body = self._test_cancellation_at_every_await( + channel = self._test_cancellation_at_every_await( self.reactor, lambda: self.make_request( "GET", "/rooms/%s/state" % room_id, await_result=False @@ -493,8 +493,9 @@ def test_get_state_cancellation(self) -> None: test_name="test_state_cancellation", ) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertCountEqual( - [state_event["type"] for state_event in body], + [state_event["type"] for state_event in channel.json_body], { "m.room.create", "m.room.power_levels", @@ -507,7 +508,7 @@ def test_get_state_cancellation(self) -> None: def test_get_state_event_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/state/$event_type` request.""" room_id = self.helper.create_room_as(self.user_id) - body = self._test_cancellation_at_every_await( + channel = self._test_cancellation_at_every_await( self.reactor, lambda: self.make_request( "GET", @@ -517,7 +518,8 @@ def test_get_state_event_cancellation(self) -> None: test_name="test_state_cancellation", ) - self.assertEqual(body, {"membership": "join"}) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(channel.json_body, {"membership": "join"}) class RoomsMemberListTestCase(RoomBase, EndpointCancellationTestHelperMixin): @@ -643,7 +645,7 @@ def test_get_member_list_mixed_memberships(self) -> None: def test_get_member_list_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/members` request.""" room_id = self.helper.create_room_as(self.user_id) - body = self._test_cancellation_at_every_await( + channel = self._test_cancellation_at_every_await( self.reactor, lambda: self.make_request( "GET", "/rooms/%s/members" % room_id, await_result=False @@ -651,7 +653,8 @@ def test_get_member_list_cancellation(self) -> None: test_name="test_get_member_list_cancellation", ) - self.assertEqual(len(body["chunk"]), 1) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertLessEqual( { "content": {"membership": "join"}, @@ -661,7 +664,7 @@ def test_get_member_list_cancellation(self) -> None: "type": "m.room.member", "user_id": self.user_id, }.items(), - body["chunk"][0].items(), + channel.json_body["chunk"][0].items(), ) def test_get_member_list_with_at_token_cancellation(self) -> None: @@ -673,7 +676,7 @@ def test_get_member_list_with_at_token_cancellation(self) -> None: self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] - body = self._test_cancellation_at_every_await( + channel = self._test_cancellation_at_every_await( self.reactor, lambda: self.make_request( "GET", @@ -683,7 +686,8 @@ def test_get_member_list_with_at_token_cancellation(self) -> None: test_name="test_get_member_list_with_at_token_cancellation", ) - self.assertEqual(len(body["chunk"]), 1) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertLessEqual( { "content": {"membership": "join"}, @@ 
-693,7 +697,7 @@ def test_get_member_list_with_at_token_cancellation(self) -> None: "type": "m.room.member", "user_id": self.user_id, }.items(), - body["chunk"][0].items(), + channel.json_body["chunk"][0].items(), ) From 18dc8c0797177a7fd8e3457b65aadc348390a22d Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 17:20:10 +0100 Subject: [PATCH 13/25] Turn `create_deferred___next___patch` and its return values into an object --- tests/http/server/_base.py | 174 +++++++++++++++++++------------------ 1 file changed, 91 insertions(+), 83 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 3918e2102438..22765337591f 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -16,8 +16,20 @@ import itertools import logging from http import HTTPStatus -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + ContextManager, + Dict, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from unittest import mock +from unittest.mock import Mock from twisted.internet.defer import Deferred from twisted.internet.error import ConnectionDone @@ -176,22 +188,13 @@ def _test_cancellation_at_every_await( ) for request_number in itertools.count(1): - ( - Deferred___next__, - unblock_awaits, - get_awaits_seen, - has_seen_new_await, - ) = self.create_deferred___next___patch(seen_awaits, request_number) + deferred_patch = Deferred__next__Patch(seen_awaits, request_number) try: with mock.patch( "synapse.http.server.respond_with_json", wraps=respond_with_json ) as respond_mock: - with mock.patch.object( - Deferred, - "__next__", - new=Deferred___next__, - ): + with deferred_patch.patch(): # Start the request. channel = make_request() request = channel.request @@ -209,17 +212,21 @@ def _test_cancellation_at_every_await( # Run the request until we see a new `await` which we have not # yet cancelled at, or it completes. - while not respond_mock.called and not has_seen_new_await(): - previous_awaits_seen = get_awaits_seen() + while ( + not respond_mock.called + and not deferred_patch.new_await_seen + ): + previous_awaits_seen = deferred_patch.awaits_seen reactor.advance(0.0) - if get_awaits_seen() == previous_awaits_seen: + if deferred_patch.awaits_seen == previous_awaits_seen: # We didn't see any progress. Try advancing the clock. reactor.advance(1.0) - if get_awaits_seen() == previous_awaits_seen: - # We still didn't see any progress. The request might be stuck. + if deferred_patch.awaits_seen == previous_awaits_seen: + # We still didn't see any progress. The request might be + # stuck. self.fail( "Request appears to be stuck, possibly due to " "a previous cancelled request" @@ -266,69 +273,61 @@ def _test_cancellation_at_every_await( self.assertEqual(code, HTTP_STATUS_REQUEST_CANCELLED) finally: # Unblock any processing that might be shared between requests. - unblock_awaits() + deferred_patch.unblock_awaits() assert False, "unreachable" # noqa: B011 - def create_deferred___next___patch( - self, seen_awaits: Set[Tuple[str, ...]], request_number: int - ) -> Tuple[ - Callable[["Deferred[T]"], "Deferred[T]"], - Callable[[], None], - Callable[[], int], - Callable[[], bool], - ]: - """Creates a function to patch `Deferred.__next__` with, for - `_test_cancellation_at_every_await`. - Produces a `Deferred.__next__` patch that will intercept `await`s and force them - to block once it sees a new `await`. 
+class Deferred__next__Patch: + """A `Deferred.__next__` patch that will intercept `await`s and force them + to block once it sees a new `await`. + + When done with the patch, `unblock_awaits()` must be called to clean up after any + `await`s that were forced to block, otherwise processing shared between multiple + requests, such as database queries started by `@cached`, will become permanently + stuck. + + Usage: + seen_awaits = set() + deferred_patch = Deferred__next__Patch(seen_awaits, 1) + try: + with deferred_patch.patch(): + # do things + ... + finally: + deferred_patch.unblock_awaits() + """ + def __init__(self, seen_awaits: Set[Tuple[str, ...]], request_number: int): + """ Args: seen_awaits: The set of stack traces of `await`s that have been previously seen. When the `Deferred.__next__` patch sees a new `await`, it will add it to the set. request_number: The request number to log against. - - Returns: - A tuple containing: - * The method to replace `Deferred.__next__` with. - * A method that will clean up after any `await`s that were forced to block. - This method must be called when done, otherwise processing shared between - multiple requests, such as database queries started by `@cached`, will - become permanently stuck. - * A method returning the running total of intercepted `await`s on - `Deferred`s. - * A method returning `True` once a new `await` has been seen. """ - original_Deferred___next__ = Deferred.__next__ + self._request_number = request_number + self._seen_awaits = seen_awaits + + self._original_Deferred___next__ = Deferred.__next__ # The number of `await`s on `Deferred`s we have seen so far. - awaits_seen = 0 + self.awaits_seen = 0 # Whether we have seen a new `await` not in `seen_awaits`. - new_await_seen = False + self.new_await_seen = False # To force `await`s on resolved `Deferred`s to block, we make up a new # unresolved `Deferred` and return it out of `Deferred.__next__` / - # `coroutine.send()`. We have to resolve it later, in case the - # `await`ing coroutine is part of some shared processing, such as - # `@cached`. - to_unblock: Dict[Deferred, Union[object, Failure]] = {} + # `coroutine.send()`. We have to resolve it later, in case the `await`ing + # coroutine is part of some shared processing, such as `@cached`. + self._to_unblock: Dict[Deferred, Union[object, Failure]] = {} # The last stack we logged. - previous_stack: List[inspect.FrameInfo] = [] + self._previous_stack: List[inspect.FrameInfo] = [] - def unblock_awaits() -> None: - """Unblocks any shared processing that we forced to block.""" - for deferred, result in to_unblock.items(): - deferred.callback(result) - - def get_awaits_seen() -> int: - return awaits_seen - - def has_seen_new_await() -> bool: - return new_await_seen + def patch(self) -> ContextManager[Mock]: + """Returns a context manager which patches `Deferred.__next__`.""" def Deferred___next__( deferred: "Deferred[T]", value: object = None @@ -339,59 +338,68 @@ def Deferred___next__( `Deferred.__next__` will normally: * return `self` if unresolved, which will come out of `coroutine.send()`. - * raise a `StopIteration(result)`, containing the result of the - `await`. + * raise a `StopIteration(result)`, containing the result of the `await`. * raise another exception, which will come out of the `await`. 
""" - nonlocal awaits_seen - nonlocal new_await_seen - nonlocal previous_stack - - awaits_seen += 1 + self.awaits_seen += 1 stack = _get_stack(skip_frames=1) stack_hash = _hash_stack(stack) - if stack_hash not in seen_awaits: + if stack_hash not in self._seen_awaits: # Block at the current `await` onwards. - seen_awaits.add(stack_hash) - new_await_seen = True + self._seen_awaits.add(stack_hash) + self.new_await_seen = True - if not new_await_seen: + if not self.new_await_seen: # This `await` isn't interesting. Let it proceed normally. # Don't log the stack. It's been seen before in a previous run. - previous_stack = stack + self._previous_stack = stack - return original_Deferred___next__(deferred, value) + return self._original_Deferred___next__(deferred, value) else: # We want to block at the current `await`. if deferred.called and not deferred.paused: # This `Deferred` already has a result. - # We return a new, unresolved, `Deferred` for `_inlineCallbacks` - # to wait on. This blocks the coroutine that did this `await`. + # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to + # wait on. This blocks the coroutine that did this `await`. # We queue it up for unblocking later. new_deferred: "Deferred[T]" = Deferred() - to_unblock[new_deferred] = deferred.result + self._to_unblock[new_deferred] = deferred.result _log_await_stack( - stack, previous_stack, request_number, "force-blocked await" + stack, + self._previous_stack, + self._request_number, + "force-blocked await", ) - previous_stack = stack + self._previous_stack = stack return make_deferred_yieldable(new_deferred) else: # This `Deferred` does not have a result yet. - # The `await` will block normally, so we don't have to do - # anything. + # The `await` will block normally, so we don't have to do anything. _log_await_stack( - stack, previous_stack, request_number, "blocking await" + stack, + self._previous_stack, + self._request_number, + "blocking await", ) - previous_stack = stack + self._previous_stack = stack - return original_Deferred___next__(deferred, value) + return self._original_Deferred___next__(deferred, value) - return Deferred___next__, unblock_awaits, get_awaits_seen, has_seen_new_await + return mock.patch.object(Deferred, "__next__", new=Deferred___next__) + + def unblock_awaits(self) -> None: + """Unblocks any shared processing that we forced to block. + + Must be called when done, otherwise processing shared between multiple requests, + such as database queries started by `@cached`, will become permanently stuck. 
+ """ + for deferred, result in self._to_unblock.items(): + deferred.callback(result) def _log_for_request(request_number: int, message: str) -> None: From 9955d92c068a6c913a5cea08d5673211ffbf473d Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 18:05:12 +0100 Subject: [PATCH 14/25] Raise `AssertionError`s ourselves --- tests/http/server/_base.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 22765337591f..342c12d24437 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -199,9 +199,8 @@ def _test_cancellation_at_every_await( channel = make_request() request = channel.request - if request_number == 1: - self.assertFalse( - respond_mock.called, + if request_number == 1 and respond_mock.called: + raise AssertionError( "Request finished before we could disconnect - ensure " "`await_result=False` is passed to `make_request`.", ) @@ -227,9 +226,9 @@ def _test_cancellation_at_every_await( if deferred_patch.awaits_seen == previous_awaits_seen: # We still didn't see any progress. The request might be # stuck. - self.fail( - "Request appears to be stuck, possibly due to " - "a previous cancelled request" + raise AssertionError( + "Request appears to be stuck, possibly due to a " + "previous cancelled request" ) if respond_mock.called: @@ -270,7 +269,12 @@ def _test_cancellation_at_every_await( respond_mock.assert_called_once() args, _kwargs = respond_mock.call_args code = args[1] - self.assertEqual(code, HTTP_STATUS_REQUEST_CANCELLED) + + if code != HTTP_STATUS_REQUEST_CANCELLED: + raise AssertionError( + f"{code} != {HTTP_STATUS_REQUEST_CANCELLED} : Cancelled " + "request did not finish with the correct status code." + ) finally: # Unblock any processing that might be shared between requests. deferred_patch.unblock_awaits() From 32dc9338abbf26e01a580b71263bf9f5f49174f4 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 18:19:22 +0100 Subject: [PATCH 15/25] Un-instance method `_test_cancellation_at_every_await` --- tests/http/server/_base.py | 293 ++++++++++++++++---------------- tests/rest/client/test_rooms.py | 14 +- 2 files changed, 151 insertions(+), 156 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 342c12d24437..7caebaf25ec5 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -124,162 +124,157 @@ def _test_disconnect( self.assertEqual(request.code, expected_code) self.assertEqual(body, expected_body) - @logcontext_clean - def _test_cancellation_at_every_await( - self, - reactor: ThreadedMemoryReactorClock, - make_request: Callable[[], FakeChannel], - test_name: str, - ) -> FakeChannel: - """Performs a request repeatedly, disconnecting at successive `await`s, until - one completes. - - Fails if: - * A logging context is lost during cancellation. - * A logging context get restarted after it is marked as finished, eg. if - a request's logging context is used by some processing started by the - request, but the request neglects to cancel that processing or wait for it - to complete. - - Note that "Re-starting finished log context" errors get caught by twisted - and will manifest in a different logging context error at a later point. - When debugging logging context failures, setting a breakpoint in - `logcontext_error` can prove useful. - * A request gets stuck, possibly due to a previous cancellation. - * The request does not return a 499 when the client disconnects. 
- This implies that a `CancelledError` was swallowed somewhere. - - It is up to the caller to verify that the request returns the correct data when - it finally runs to completion. - - Note that this function can only cover a single code path and does not guarantee - that an endpoint is compatible with cancellation on every code path. - To allow inspection of the code path that is being tested, this function will - log the stack trace at every `await` that gets cancelled. To view these log - lines, `trial` can be run with the `SYNAPSE_TEST_LOG_LEVEL=INFO` environment - variable, which will include the log lines in `_trial_temp/test.log`. - Alternatively, `_log_for_request` can be modified to write to `sys.stdout`. - Args: - reactor: The twisted reactor running the request handler. - make_request: A function that initiates the request and returns a - `FakeChannel`. - test_name: The name of the test, which will be logged. +@logcontext_clean +def test_cancellation_at_every_await( + reactor: ThreadedMemoryReactorClock, + make_request: Callable[[], FakeChannel], + test_name: str, +) -> FakeChannel: + """Performs a request repeatedly, disconnecting at successive `await`s, until + one completes. + + Fails if: + * A logging context is lost during cancellation. + * A logging context get restarted after it is marked as finished, eg. if + a request's logging context is used by some processing started by the + request, but the request neglects to cancel that processing or wait for it + to complete. + + Note that "Re-starting finished log context" errors get caught by twisted + and will manifest in a different logging context error at a later point. + When debugging logging context failures, setting a breakpoint in + `logcontext_error` can prove useful. + * A request gets stuck, possibly due to a previous cancellation. + * The request does not return a 499 when the client disconnects. + This implies that a `CancelledError` was swallowed somewhere. + + It is up to the caller to verify that the request returns the correct data when + it finally runs to completion. + + Note that this function can only cover a single code path and does not guarantee + that an endpoint is compatible with cancellation on every code path. + To allow inspection of the code path that is being tested, this function will + log the stack trace at every `await` that gets cancelled. To view these log + lines, `trial` can be run with the `SYNAPSE_TEST_LOG_LEVEL=INFO` environment + variable, which will include the log lines in `_trial_temp/test.log`. + Alternatively, `_log_for_request` can be modified to write to `sys.stdout`. - Returns: - The `FakeChannel` object which stores the result of the final request that - runs to completion. - """ - # To process a request, a coroutine run is created for the async method handling - # the request. That method may then start other coroutine runs, wrapped in - # `Deferred`s. - # - # We would like to trigger a cancellation at the first `await`, re-run the - # request and cancel at the second `await`, and so on. By patching - # `Deferred.__next__`, we can intercept `await`s, track which ones we have or - # have not seen, and force them to block when they wouldn't have. - - # The set of previously seen `await`s. - # Each element is a stringified stack trace. - seen_awaits: Set[Tuple[str, ...]] = set() - - _log_for_request( - 0, f"Running _test_cancellation_at_every_await for {test_name}..." - ) + Args: + reactor: The twisted reactor running the request handler. 
+ make_request: A function that initiates the request and returns a + `FakeChannel`. + test_name: The name of the test, which will be logged. + + Returns: + The `FakeChannel` object which stores the result of the final request that + runs to completion. + """ + # To process a request, a coroutine run is created for the async method handling + # the request. That method may then start other coroutine runs, wrapped in + # `Deferred`s. + # + # We would like to trigger a cancellation at the first `await`, re-run the + # request and cancel at the second `await`, and so on. By patching + # `Deferred.__next__`, we can intercept `await`s, track which ones we have or + # have not seen, and force them to block when they wouldn't have. - for request_number in itertools.count(1): - deferred_patch = Deferred__next__Patch(seen_awaits, request_number) + # The set of previously seen `await`s. + # Each element is a stringified stack trace. + seen_awaits: Set[Tuple[str, ...]] = set() - try: - with mock.patch( - "synapse.http.server.respond_with_json", wraps=respond_with_json - ) as respond_mock: - with deferred_patch.patch(): - # Start the request. - channel = make_request() - request = channel.request + _log_for_request(0, f"Running test_cancellation_at_every_await for {test_name}...") - if request_number == 1 and respond_mock.called: - raise AssertionError( - "Request finished before we could disconnect - ensure " - "`await_result=False` is passed to `make_request`.", - ) - else: - # Requests after the first may be lucky enough to hit caches - # all the way through and never have to block. - pass - - # Run the request until we see a new `await` which we have not - # yet cancelled at, or it completes. - while ( - not respond_mock.called - and not deferred_patch.new_await_seen - ): - previous_awaits_seen = deferred_patch.awaits_seen - - reactor.advance(0.0) - - if deferred_patch.awaits_seen == previous_awaits_seen: - # We didn't see any progress. Try advancing the clock. - reactor.advance(1.0) - - if deferred_patch.awaits_seen == previous_awaits_seen: - # We still didn't see any progress. The request might be - # stuck. - raise AssertionError( - "Request appears to be stuck, possibly due to a " - "previous cancelled request" - ) - - if respond_mock.called: - # The request ran to completion and we are done with testing it. - - # `respond_with_json` writes the response asynchronously, so we - # might have to give the reactor a kick before the channel gets - # the response. - reactor.advance(1.0) - - return channel - - # Disconnect the client and wait for the response. - request.connectionLost(reason=ConnectionDone()) - - _log_for_request(request_number, "--- disconnected ---") - - # We may need to pump the reactor to allow `delay_cancellation`s to - # finish. - if not respond_mock.called: - reactor.advance(0.0) + for request_number in itertools.count(1): + deferred_patch = Deferred__next__Patch(seen_awaits, request_number) - # Try advancing the clock if that didn't work. - if not respond_mock.called: - reactor.advance(1.0) - - # Mark the request's logging context as finished. If it gets - # activated again, an `AssertionError` will be raised. This - # `AssertionError` will likely be caught by twisted and turned into - # a `Failure`. Instead, a different `AssertionError` will be - # observed when the logging context is deactivated, as it wouldn't - # have tracked resource usage correctly. 
- if isinstance(request, SynapseRequest) and request.logcontext: - request.logcontext.finished = True - - # Check that the request finished with a 499, - # ie. the `CancelledError` wasn't swallowed. - respond_mock.assert_called_once() - args, _kwargs = respond_mock.call_args - code = args[1] - - if code != HTTP_STATUS_REQUEST_CANCELLED: + try: + with mock.patch( + "synapse.http.server.respond_with_json", wraps=respond_with_json + ) as respond_mock: + with deferred_patch.patch(): + # Start the request. + channel = make_request() + request = channel.request + + if request_number == 1 and respond_mock.called: raise AssertionError( - f"{code} != {HTTP_STATUS_REQUEST_CANCELLED} : Cancelled " - "request did not finish with the correct status code." + "Request finished before we could disconnect - ensure " + "`await_result=False` is passed to `make_request`.", ) - finally: - # Unblock any processing that might be shared between requests. - deferred_patch.unblock_awaits() + else: + # Requests after the first may be lucky enough to hit caches + # all the way through and never have to block. + pass + + # Run the request until we see a new `await` which we have not + # yet cancelled at, or it completes. + while not respond_mock.called and not deferred_patch.new_await_seen: + previous_awaits_seen = deferred_patch.awaits_seen + + reactor.advance(0.0) + + if deferred_patch.awaits_seen == previous_awaits_seen: + # We didn't see any progress. Try advancing the clock. + reactor.advance(1.0) + + if deferred_patch.awaits_seen == previous_awaits_seen: + # We still didn't see any progress. The request might be + # stuck. + raise AssertionError( + "Request appears to be stuck, possibly due to a " + "previous cancelled request" + ) + + if respond_mock.called: + # The request ran to completion and we are done with testing it. + + # `respond_with_json` writes the response asynchronously, so we + # might have to give the reactor a kick before the channel gets + # the response. + reactor.advance(1.0) + + return channel + + # Disconnect the client and wait for the response. + request.connectionLost(reason=ConnectionDone()) + + _log_for_request(request_number, "--- disconnected ---") + + # We may need to pump the reactor to allow `delay_cancellation`s to + # finish. + if not respond_mock.called: + reactor.advance(0.0) + + # Try advancing the clock if that didn't work. + if not respond_mock.called: + reactor.advance(1.0) + + # Mark the request's logging context as finished. If it gets + # activated again, an `AssertionError` will be raised. This + # `AssertionError` will likely be caught by twisted and turned into + # a `Failure`. Instead, a different `AssertionError` will be + # observed when the logging context is deactivated, as it wouldn't + # have tracked resource usage correctly. + if isinstance(request, SynapseRequest) and request.logcontext: + request.logcontext.finished = True + + # Check that the request finished with a 499, + # ie. the `CancelledError` wasn't swallowed. + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code = args[1] + + if code != HTTP_STATUS_REQUEST_CANCELLED: + raise AssertionError( + f"{code} != {HTTP_STATUS_REQUEST_CANCELLED} : Cancelled " + "request did not finish with the correct status code." + ) + finally: + # Unblock any processing that might be shared between requests. 
+ deferred_patch.unblock_awaits() - assert False, "unreachable" # noqa: B011 + assert False, "unreachable" # noqa: B011 class Deferred__next__Patch: @@ -407,7 +402,7 @@ def unblock_awaits(self) -> None: def _log_for_request(request_number: int, message: str) -> None: - """Logs a message for an iteration of `_test_cancellation_at_every_await`.""" + """Logs a message for an iteration of `test_cancellation_at_every_await`.""" # We want consistent alignment when logging stack traces, so ensure the logging # context has a fixed width name. with LoggingContext(name=f"request-{request_number:<2}"): @@ -420,7 +415,7 @@ def _log_await_stack( request_number: int, note: str, ) -> None: - """Logs the stack for an `await` in `_test_cancellation_at_every_await`. + """Logs the stack for an `await` in `test_cancellation_at_every_await`. Only logs the part of the stack that has changed since the previous call. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 734d992a4b6f..6cab6e3dab66 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -42,7 +42,7 @@ from synapse.util.stringutils import random_string from tests import unittest -from tests.http.server._base import EndpointCancellationTestHelperMixin +from tests.http.server._base import test_cancellation_at_every_await from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -477,7 +477,7 @@ def test_member_event_from_ban(self) -> None: ) -class RoomStateTestCase(RoomBase, EndpointCancellationTestHelperMixin): +class RoomStateTestCase(RoomBase): """Tests /rooms/$room_id/state.""" user_id = "@sid1:red" @@ -485,7 +485,7 @@ class RoomStateTestCase(RoomBase, EndpointCancellationTestHelperMixin): def test_get_state_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/state` request.""" room_id = self.helper.create_room_as(self.user_id) - channel = self._test_cancellation_at_every_await( + channel = test_cancellation_at_every_await( self.reactor, lambda: self.make_request( "GET", "/rooms/%s/state" % room_id, await_result=False @@ -508,7 +508,7 @@ def test_get_state_cancellation(self) -> None: def test_get_state_event_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/state/$event_type` request.""" room_id = self.helper.create_room_as(self.user_id) - channel = self._test_cancellation_at_every_await( + channel = test_cancellation_at_every_await( self.reactor, lambda: self.make_request( "GET", @@ -522,7 +522,7 @@ def test_get_state_event_cancellation(self) -> None: self.assertEqual(channel.json_body, {"membership": "join"}) -class RoomsMemberListTestCase(RoomBase, EndpointCancellationTestHelperMixin): +class RoomsMemberListTestCase(RoomBase): """Tests /rooms/$room_id/members/list REST events.""" servlets = RoomBase.servlets + [sync.register_servlets] @@ -645,7 +645,7 @@ def test_get_member_list_mixed_memberships(self) -> None: def test_get_member_list_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/members` request.""" room_id = self.helper.create_room_as(self.user_id) - channel = self._test_cancellation_at_every_await( + channel = test_cancellation_at_every_await( self.reactor, lambda: self.make_request( "GET", "/rooms/%s/members" % room_id, await_result=False @@ -676,7 +676,7 @@ def test_get_member_list_with_at_token_cancellation(self) -> None: self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] - channel = self._test_cancellation_at_every_await( + channel = 
test_cancellation_at_every_await( self.reactor, lambda: self.make_request( "GET", From 633c48f6b90c0080751aaa3aec2b9ad9dc325d6a Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 18:20:02 +0100 Subject: [PATCH 16/25] Revert "Fix mypy thinking `RoomBase.servlets` is a `List[function]`" This reverts commit 92045c8a62ee9eee02e946fa412a464ea62bec42. --- tests/rest/client/test_rooms.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 6cab6e3dab66..5278e4e49df0 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,7 +18,7 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, ClassVar, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse @@ -34,7 +34,7 @@ ) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus -from synapse.rest import RegisterServletsFunc, admin +from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester @@ -51,12 +51,7 @@ class RoomBase(unittest.HomeserverTestCase): rmcreator_id: Optional[str] = None - # mypy: `room.register_servlets` has an extra parameter, so mypy thinks `servlets` - # is a `List[function]` without the hint. - servlets: ClassVar[List[RegisterServletsFunc]] = [ - room.register_servlets, - room.register_deprecated_servlets, - ] + servlets = [room.register_servlets, room.register_deprecated_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: From 2f4aeee6c0c025c7599e2af3594990f0edb9a4e0 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 31 May 2022 19:28:02 +0100 Subject: [PATCH 17/25] Rename `test_cancellation_at_every_await` to `make_request_with_cancellation_test` and have it make the request --- tests/http/server/_base.py | 34 ++++++++++++------------ tests/rest/client/test_rooms.py | 46 +++++++++++++++------------------ 2 files changed, 37 insertions(+), 43 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 7caebaf25ec5..f18ae67f8cf2 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -34,6 +34,7 @@ from twisted.internet.defer import Deferred from twisted.internet.error import ConnectionDone from twisted.python.failure import Failure +from twisted.web.server import Site from synapse.http.server import ( HTTP_STATUS_REQUEST_CANCELLED, @@ -45,7 +46,7 @@ from synapse.types import JsonDict from tests import unittest -from tests.server import FakeChannel, ThreadedMemoryReactorClock +from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request from tests.unittest import logcontext_clean logger = logging.getLogger(__name__) @@ -126,10 +127,13 @@ def _test_disconnect( @logcontext_clean -def test_cancellation_at_every_await( - reactor: ThreadedMemoryReactorClock, - make_request: Callable[[], FakeChannel], +def make_request_with_cancellation_test( test_name: str, + reactor: ThreadedMemoryReactorClock, + site: Site, + method: str, + path: str, + content: Union[bytes, str, JsonDict] = b"", ) -> FakeChannel: """Performs a request repeatedly, disconnecting at successive `await`s, until one completes. 
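
The core mechanism being reworked here, intercepting every `await` by patching `Deferred.__next__`, can be exercised in isolation. The standalone sketch below is illustrative only and not part of the patch (`demo_count_awaits` and `two_awaits` are made-up names): it merely counts the `await`s a coroutine performs on already-resolved `Deferred`s, whereas the real helper also hashes the stack at each `await` and force-blocks any it has not seen before.

    from typing import Any, List
    from unittest import mock

    from twisted.internet import defer
    from twisted.internet.defer import Deferred, ensureDeferred


    def demo_count_awaits() -> int:
        """Count the `await`s a coroutine performs by wrapping `Deferred.__next__`."""
        counter: List[int] = [0]
        original_next = Deferred.__next__

        def wrapped(deferred: "Deferred[Any]", value: object = None) -> Any:
            # Called once per `await` on a `Deferred`; delegate to the real method.
            counter[0] += 1
            return original_next(deferred, value)

        async def two_awaits() -> str:
            # Both `Deferred`s already have results, so the coroutine completes
            # synchronously when driven below.
            await defer.succeed(1)
            await defer.succeed(2)
            return "done"

        with mock.patch.object(Deferred, "__next__", new=wrapped):
            result_deferred = ensureDeferred(two_awaits())

        assert result_deferred.result == "done"
        return counter[0]  # Two `await`s were intercepted.

Because every `await` on a `Deferred` funnels through `__next__`, the real helper can decide at each interception point whether to let the `await` proceed or to hand back an unresolved `Deferred` and freeze the request at exactly that point.
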
@@ -183,7 +187,9 @@ def test_cancellation_at_every_await( # Each element is a stringified stack trace. seen_awaits: Set[Tuple[str, ...]] = set() - _log_for_request(0, f"Running test_cancellation_at_every_await for {test_name}...") + _log_for_request( + 0, f"Running make_request_with_cancellation_test for {test_name}..." + ) for request_number in itertools.count(1): deferred_patch = Deferred__next__Patch(seen_awaits, request_number) @@ -194,19 +200,11 @@ def test_cancellation_at_every_await( ) as respond_mock: with deferred_patch.patch(): # Start the request. - channel = make_request() + channel = make_request( + reactor, site, method, path, content, await_result=False + ) request = channel.request - if request_number == 1 and respond_mock.called: - raise AssertionError( - "Request finished before we could disconnect - ensure " - "`await_result=False` is passed to `make_request`.", - ) - else: - # Requests after the first may be lucky enough to hit caches - # all the way through and never have to block. - pass - # Run the request until we see a new `await` which we have not # yet cancelled at, or it completes. while not respond_mock.called and not deferred_patch.new_await_seen: @@ -402,7 +400,7 @@ def unblock_awaits(self) -> None: def _log_for_request(request_number: int, message: str) -> None: - """Logs a message for an iteration of `test_cancellation_at_every_await`.""" + """Logs a message for an iteration of `make_request_with_cancellation_test`.""" # We want consistent alignment when logging stack traces, so ensure the logging # context has a fixed width name. with LoggingContext(name=f"request-{request_number:<2}"): @@ -415,7 +413,7 @@ def _log_await_stack( request_number: int, note: str, ) -> None: - """Logs the stack for an `await` in `test_cancellation_at_every_await`. + """Logs the stack for an `await` in `make_request_with_cancellation_test`. Only logs the part of the stack that has changed since the previous call. 
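
The `tests/rest/client/test_rooms.py` hunks below update the call sites to this new convention. As a compact sketch of the change, assuming a test method on one of the `RoomBase` test cases where `room_id` has already been created with `self.helper.create_room_as(...)` (so `self` and `room_id` come from that surrounding fixture):

    from tests.http.server._base import make_request_with_cancellation_test

    # Old form (pre-patch): the test builds the request itself and must remember to
    # pass `await_result=False`, or the helper fails with "Request finished before
    # we could disconnect".
    #
    #     channel = test_cancellation_at_every_await(
    #         self.reactor,
    #         lambda: self.make_request(
    #             "GET", "/rooms/%s/state" % room_id, await_result=False
    #         ),
    #         test_name="test_state_cancellation",
    #     )

    # New form (this patch): the helper drives the request itself, so there is no
    # lambda and no `await_result=False` for the caller to forget.
    channel = make_request_with_cancellation_test(
        "test_state_cancellation",
        self.reactor,
        self.site,
        "GET",
        "/rooms/%s/state" % room_id,
    )
    self.assertEqual(200, channel.code, msg=channel.result["body"])
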
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 5278e4e49df0..4be83dfd6d8f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -42,7 +42,7 @@ from synapse.util.stringutils import random_string from tests import unittest -from tests.http.server._base import test_cancellation_at_every_await +from tests.http.server._base import make_request_with_cancellation_test from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -480,12 +480,12 @@ class RoomStateTestCase(RoomBase): def test_get_state_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/state` request.""" room_id = self.helper.create_room_as(self.user_id) - channel = test_cancellation_at_every_await( + channel = make_request_with_cancellation_test( + "test_state_cancellation", self.reactor, - lambda: self.make_request( - "GET", "/rooms/%s/state" % room_id, await_result=False - ), - test_name="test_state_cancellation", + self.site, + "GET", + "/rooms/%s/state" % room_id, ) self.assertEqual(200, channel.code, msg=channel.result["body"]) @@ -503,14 +503,12 @@ def test_get_state_cancellation(self) -> None: def test_get_state_event_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/state/$event_type` request.""" room_id = self.helper.create_room_as(self.user_id) - channel = test_cancellation_at_every_await( + channel = make_request_with_cancellation_test( + "test_state_cancellation", self.reactor, - lambda: self.make_request( - "GET", - "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), - await_result=False, - ), - test_name="test_state_cancellation", + self.site, + "GET", + "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), ) self.assertEqual(200, channel.code, msg=channel.result["body"]) @@ -640,12 +638,12 @@ def test_get_member_list_mixed_memberships(self) -> None: def test_get_member_list_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/members` request.""" room_id = self.helper.create_room_as(self.user_id) - channel = test_cancellation_at_every_await( + channel = make_request_with_cancellation_test( + "test_get_member_list_cancellation", self.reactor, - lambda: self.make_request( - "GET", "/rooms/%s/members" % room_id, await_result=False - ), - test_name="test_get_member_list_cancellation", + self.site, + "GET", + "/rooms/%s/members" % room_id, ) self.assertEqual(200, channel.code, msg=channel.result["body"]) @@ -671,14 +669,12 @@ def test_get_member_list_with_at_token_cancellation(self) -> None: self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] - channel = test_cancellation_at_every_await( + channel = make_request_with_cancellation_test( + "test_get_member_list_with_at_token_cancellation", self.reactor, - lambda: self.make_request( - "GET", - "/rooms/%s/members?at=%s" % (room_id, sync_token), - await_result=False, - ), - test_name="test_get_member_list_with_at_token_cancellation", + self.site, + "GET", + "/rooms/%s/members?at=%s" % (room_id, sync_token), ) self.assertEqual(200, channel.code, msg=channel.result["body"]) From fa6245a74c2cb809e4a800b0ebbe53000b759ea4 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 1 Jun 2022 15:35:37 +0100 Subject: [PATCH 18/25] Fix out-of-date docstring args --- tests/http/server/_base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index f18ae67f8cf2..a1f1a6ea73d6 100644 --- a/tests/http/server/_base.py 
+++ b/tests/http/server/_base.py @@ -165,10 +165,13 @@ def make_request_with_cancellation_test( Alternatively, `_log_for_request` can be modified to write to `sys.stdout`. Args: - reactor: The twisted reactor running the request handler. - make_request: A function that initiates the request and returns a - `FakeChannel`. test_name: The name of the test, which will be logged. + reactor: The twisted reactor running the request handler. + site: The twisted `Site` to use to render the request. + method: The HTTP request method ("verb"). + path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and + such). + content: The body of the request. Returns: The `FakeChannel` object which stores the result of the final request that From 3a73d6f9e0e5c2c28fc2e224ec43892862b7d112 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 1 Jun 2022 15:37:51 +0100 Subject: [PATCH 19/25] Replace `ThreadedMemoryReactorClock` with `MemoryReactorClock` --- tests/http/server/_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index a1f1a6ea73d6..f6d52697abdf 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -34,6 +34,7 @@ from twisted.internet.defer import Deferred from twisted.internet.error import ConnectionDone from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.server import Site from synapse.http.server import ( @@ -129,7 +130,7 @@ def _test_disconnect( @logcontext_clean def make_request_with_cancellation_test( test_name: str, - reactor: ThreadedMemoryReactorClock, + reactor: MemoryReactorClock, site: Site, method: str, path: str, From 39ed2b351fae1180090e753323ec3d4bd5f005b3 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 1 Jun 2022 15:39:49 +0100 Subject: [PATCH 20/25] Update docstring about `Deferred.__next__`'s behaviour --- tests/http/server/_base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index f6d52697abdf..827358ad5959 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -337,8 +337,10 @@ def Deferred___next__( seen enough of them. `Deferred.__next__` will normally: - * return `self` if unresolved, which will come out of - `coroutine.send()`. + * return `self` if the `Deferred` is unresolved, in which case + `coroutine.send()` will return the `Deferred`, and + `_defer.inlineCallbacks` will stop running the coroutine until the + `Deferred` is resolved. * raise a `StopIteration(result)`, containing the result of the `await`. * raise another exception, which will come out of the `await`. """ From 1ff93917d34f13f38cecc509b242d06129d81b5d Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 1 Jun 2022 15:42:20 +0100 Subject: [PATCH 21/25] Outdent the if blocks in the `Deferred.__next__` patch --- tests/http/server/_base.py | 60 +++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 827358ad5959..46816c88cd18 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -361,37 +361,37 @@ def Deferred___next__( self._previous_stack = stack return self._original_Deferred___next__(deferred, value) - else: - # We want to block at the current `await`. - if deferred.called and not deferred.paused: - # This `Deferred` already has a result. 
- # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to - # wait on. This blocks the coroutine that did this `await`. - # We queue it up for unblocking later. - new_deferred: "Deferred[T]" = Deferred() - self._to_unblock[new_deferred] = deferred.result - - _log_await_stack( - stack, - self._previous_stack, - self._request_number, - "force-blocked await", - ) - self._previous_stack = stack - - return make_deferred_yieldable(new_deferred) - else: - # This `Deferred` does not have a result yet. - # The `await` will block normally, so we don't have to do anything. - _log_await_stack( - stack, - self._previous_stack, - self._request_number, - "blocking await", - ) - self._previous_stack = stack - return self._original_Deferred___next__(deferred, value) + # We want to block at the current `await`. + if deferred.called and not deferred.paused: + # This `Deferred` already has a result. + # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait + # on. This blocks the coroutine that did this `await`. + # We queue it up for unblocking later. + new_deferred: "Deferred[T]" = Deferred() + self._to_unblock[new_deferred] = deferred.result + + _log_await_stack( + stack, + self._previous_stack, + self._request_number, + "force-blocked await", + ) + self._previous_stack = stack + + return make_deferred_yieldable(new_deferred) + + # This `Deferred` does not have a result yet. + # The `await` will block normally, so we don't have to do anything. + _log_await_stack( + stack, + self._previous_stack, + self._request_number, + "blocking await", + ) + self._previous_stack = stack + + return self._original_Deferred___next__(deferred, value) return mock.patch.object(Deferred, "__next__", new=Deferred___next__) From 9ebcf3898bafbc614da8fb472c7c806ceba663ae Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 1 Jun 2022 15:44:20 +0100 Subject: [PATCH 22/25] Use `channel.await_result()` instead of `reactor.advance()` on final request --- tests/http/server/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 46816c88cd18..8667cee88ceb 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -234,7 +234,7 @@ def make_request_with_cancellation_test( # `respond_with_json` writes the response asynchronously, so we # might have to give the reactor a kick before the channel gets # the response. - reactor.advance(1.0) + channel.await_result() return channel From 2a5b800610c2413b870245b0bdaf0ee78abe89ad Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 1 Jun 2022 15:47:24 +0100 Subject: [PATCH 23/25] Take the status code from the `Request` instead of the `respond_with_json` mock --- tests/http/server/_base.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 8667cee88ceb..6830b3273438 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -264,13 +264,11 @@ def make_request_with_cancellation_test( # Check that the request finished with a 499, # ie. the `CancelledError` wasn't swallowed. respond_mock.assert_called_once() - args, _kwargs = respond_mock.call_args - code = args[1] - if code != HTTP_STATUS_REQUEST_CANCELLED: + if request.code != HTTP_STATUS_REQUEST_CANCELLED: raise AssertionError( - f"{code} != {HTTP_STATUS_REQUEST_CANCELLED} : Cancelled " - "request did not finish with the correct status code." 
+ f"{request.code} != {HTTP_STATUS_REQUEST_CANCELLED} : " + "Cancelled request did not finish with the correct status code." ) finally: # Unblock any processing that might be shared between requests. From a3e9ccee6b5d8ce3b8112ad05403804f11be140e Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 1 Jun 2022 16:07:51 +0100 Subject: [PATCH 24/25] Reword comments about re-starting of finished logging contexts --- tests/http/server/_base.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 6830b3273438..ae3a56446b2d 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -146,9 +146,10 @@ def make_request_with_cancellation_test( request, but the request neglects to cancel that processing or wait for it to complete. - Note that "Re-starting finished log context" errors get caught by twisted - and will manifest in a different logging context error at a later point. - When debugging logging context failures, setting a breakpoint in + Note that "Re-starting finished log context" errors get raised within the + request handling code and may or may not get caught. These errors will + likely manifest as a different logging context error at a later point. When + debugging logging context failures, setting a breakpoint in `logcontext_error` can prove useful. * A request gets stuck, possibly due to a previous cancellation. * The request does not return a 499 when the client disconnects. @@ -253,11 +254,11 @@ def make_request_with_cancellation_test( reactor.advance(1.0) # Mark the request's logging context as finished. If it gets - # activated again, an `AssertionError` will be raised. This - # `AssertionError` will likely be caught by twisted and turned into - # a `Failure`. Instead, a different `AssertionError` will be - # observed when the logging context is deactivated, as it wouldn't - # have tracked resource usage correctly. + # activated again, an `AssertionError` will be raised and bubble up + # through request handling code. This `AssertionError` may or may not be + # caught. Eventually some other code will deactivate the logging + # context which will raise a different `AssertionError` because + # resource usage won't have been correctly tracked. if isinstance(request, SynapseRequest) and request.logcontext: request.logcontext.finished = True From 1223c2daf3094062bd78323a29e050fc3eb869ff Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 1 Jun 2022 19:11:36 +0100 Subject: [PATCH 25/25] When waiting for a response after a disconnect, try unblocking awaits that we forced to block --- tests/http/server/_base.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index ae3a56446b2d..57b92beb8721 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -235,6 +235,7 @@ def make_request_with_cancellation_test( # `respond_with_json` writes the response asynchronously, so we # might have to give the reactor a kick before the channel gets # the response. + deferred_patch.unblock_awaits() channel.await_result() return channel @@ -244,14 +245,25 @@ def make_request_with_cancellation_test( _log_for_request(request_number, "--- disconnected ---") - # We may need to pump the reactor to allow `delay_cancellation`s to - # finish. - if not respond_mock.called: - reactor.advance(0.0) + # Advance the reactor just enough to get a response. 
+ # We don't want to advance the reactor too far, because we can only + # detect re-starts of finished logging contexts after we set the + # finished flag below. + for _ in range(2): + # We may need to pump the reactor to allow `delay_cancellation`s to + # finish. + if not respond_mock.called: + reactor.advance(0.0) + + # Try advancing the clock if that didn't work. + if not respond_mock.called: + reactor.advance(1.0) - # Try advancing the clock if that didn't work. - if not respond_mock.called: - reactor.advance(1.0) + # `delay_cancellation`s may be waiting for processing that we've + # forced to block. Try unblocking them, followed by another round of + # pumping the reactor. + if not respond_mock.called: + deferred_patch.unblock_awaits() # Mark the request's logging context as finished. If it gets # activated again, an `AssertionError` will be raised and bubble up @@ -272,7 +284,8 @@ def make_request_with_cancellation_test( "Cancelled request did not finish with the correct status code." ) finally: - # Unblock any processing that might be shared between requests. + # Unblock any processing that might be shared between requests, if we + # haven't already done so. deferred_patch.unblock_awaits() assert False, "unreachable" # noqa: B011 @@ -400,7 +413,9 @@ def unblock_awaits(self) -> None: Must be called when done, otherwise processing shared between multiple requests, such as database queries started by `@cached`, will become permanently stuck. """ - for deferred, result in self._to_unblock.items(): + to_unblock = self._to_unblock + self._to_unblock = {} + for deferred, result in to_unblock.items(): deferred.callback(result)
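
With this last patch `unblock_awaits` may run more than once for a single request: early, while waiting for a response after the disconnect, and again in the `finally` clause. It therefore drains `_to_unblock` before firing anything. The standalone sketch below (the names `to_unblock` and `unblock` are illustrative, not part of the patch) shows why that matters: a `Deferred` may only be fired once.

    from twisted.internet.defer import AlreadyCalledError, Deferred

    d: "Deferred[str]" = Deferred()
    to_unblock = {d: "result"}


    def unblock() -> None:
        # Take ownership of the pending entries so that repeated calls are harmless,
        # mirroring how `unblock_awaits` swaps `self._to_unblock` for an empty dict.
        pending = dict(to_unblock)
        to_unblock.clear()
        for deferred, result in pending.items():
            deferred.callback(result)


    unblock()
    unblock()  # A no-op: nothing is left to fire.

    try:
        d.callback("again")
    except AlreadyCalledError:
        # Firing the same `Deferred` twice raises; draining the dict avoids this.
        pass

Without the drain, the early unblocking added here plus the existing `finally` clause could call back the same `Deferred`s twice.
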