Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Clean up the test code for client disconnections #12929

Merged
merged 7 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12929.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Clean up the test code for client disconnection.
10 changes: 4 additions & 6 deletions tests/federation/transport/server/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from synapse.util.ratelimitutils import FederationRateLimiter

from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect


class CancellableFederationServlet(BaseFederationServlet):
Expand Down Expand Up @@ -54,9 +54,7 @@ async def on_POST(
return HTTPStatus.OK, {"result": True}


class BaseFederationServletCancellationTests(
unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
):
class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
"""Tests for `BaseFederationServlet` cancellation."""

skip = "`BaseFederationServlet` does not support cancellation yet."
Expand Down Expand Up @@ -86,7 +84,7 @@ def test_cancellable_disconnect(self) -> None:
# request won't be processed.
self.pump()

self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -106,7 +104,7 @@ def test_uncancellable_disconnect(self) -> None:
# request won't be processed.
self.pump()

self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
Expand Down
132 changes: 69 additions & 63 deletions tests/http/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.types import JsonDict

from tests import unittest
from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request
from tests.server import FakeChannel, make_request
from tests.unittest import logcontext_clean

logger = logging.getLogger(__name__)
Expand All @@ -56,75 +55,82 @@
T = TypeVar("T")


class EndpointCancellationTestHelperMixin(unittest.TestCase):
"""Provides helper methods for testing cancellation of endpoints."""
def test_disconnect(
reactor: MemoryReactorClock,
channel: FakeChannel,
expect_cancellation: bool,
expected_body: Union[bytes, JsonDict],
expected_code: Optional[int] = None,
) -> None:
"""Disconnects an in-flight request and checks the response.

def _test_disconnect(
self,
reactor: ThreadedMemoryReactorClock,
channel: FakeChannel,
expect_cancellation: bool,
expected_body: Union[bytes, JsonDict],
expected_code: Optional[int] = None,
) -> None:
"""Disconnects an in-flight request and checks the response.
Args:
reactor: The twisted reactor running the request handler.
channel: The `FakeChannel` for the request.
expect_cancellation: `True` if request processing is expected to be cancelled,
`False` if the request should run to completion.
expected_body: The expected response for the request.
expected_code: The expected status code for the request. Defaults to `200` or
`499` depending on `expect_cancellation`.
"""
# Determine the expected status code.
if expected_code is None:
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
expected_code = HTTPStatus.OK

Args:
reactor: The twisted reactor running the request handler.
channel: The `FakeChannel` for the request.
expect_cancellation: `True` if request processing is expected to be
cancelled, `False` if the request should run to completion.
expected_body: The expected response for the request.
expected_code: The expected status code for the request. Defaults to `200`
or `499` depending on `expect_cancellation`.
"""
# Determine the expected status code.
if expected_code is None:
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
expected_code = HTTPStatus.OK

request = channel.request
self.assertFalse(
channel.is_finished(),
request = channel.request
if channel.is_finished():
raise AssertionError(
"Request finished before we could disconnect - "
"was `await_result=False` passed to `make_request`?",
"ensure `await_result=False` is passed to `make_request`.",
)

# We're about to disconnect the request. This also disconnects the channel, so
# we have to rely on mocks to extract the response.
respond_method: Callable[..., Any]
if isinstance(expected_body, bytes):
respond_method = respond_with_html_bytes
# We're about to disconnect the request. This also disconnects the channel, so we
# have to rely on mocks to extract the response.
respond_method: Callable[..., Any]
if isinstance(expected_body, bytes):
respond_method = respond_with_html_bytes
else:
respond_method = respond_with_json

with mock.patch(
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
) as respond_mock:
# Disconnect the request.
request.connectionLost(reason=ConnectionDone())

if expect_cancellation:
# An immediate cancellation is expected.
respond_mock.assert_called_once()
else:
respond_method = respond_with_json
respond_mock.assert_not_called()

with mock.patch(
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
) as respond_mock:
# Disconnect the request.
request.connectionLost(reason=ConnectionDone())
# The handler is expected to run to completion.
reactor.advance(1.0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we can't use channel.await_result() here because we've already disconnected the channel above.

respond_mock.assert_called_once()

if expect_cancellation:
# An immediate cancellation is expected.
respond_mock.assert_called_once()
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]
self.assertEqual(code, expected_code)
self.assertEqual(request.code, expected_code)
self.assertEqual(body, expected_body)
else:
respond_mock.assert_not_called()

# The handler is expected to run to completion.
reactor.pump([1.0])
respond_mock.assert_called_once()
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]
self.assertEqual(code, expected_code)
self.assertEqual(request.code, expected_code)
self.assertEqual(body, expected_body)
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]

if code != expected_code:
raise AssertionError(
f"{code} != {expected_code} : "
"Request did not finish with the expected status code."
)

if request.code != expected_code:
raise AssertionError(
f"{request.code} != {expected_code} : "
"Request did not finish with the expected status code."
)

if body != expected_body:
raise AssertionError(
f"{body!r} != {expected_body!r} : "
"Request did not finish with the expected status code."
)


@logcontext_clean
Expand Down
10 changes: 4 additions & 6 deletions tests/http/test_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from synapse.types import JsonDict

from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect


def make_request(content):
Expand Down Expand Up @@ -108,9 +108,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return HTTPStatus.OK, {"result": True}


class TestRestServletCancellation(
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
class TestRestServletCancellation(unittest.HomeserverTestCase):
"""Tests for `RestServlet` cancellation."""

servlets = [
Expand All @@ -120,7 +118,7 @@ class TestRestServletCancellation(
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -130,7 +128,7 @@ def test_cancellable_disconnect(self) -> None:
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
Expand Down
10 changes: 4 additions & 6 deletions tests/replication/http/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from synapse.types import JsonDict

from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect


class CancellableReplicationEndpoint(ReplicationEndpoint):
Expand Down Expand Up @@ -69,9 +69,7 @@ async def _handle_request( # type: ignore[override]
return HTTPStatus.OK, {"result": True}


class ReplicationEndpointCancellationTestCase(
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""

def create_test_resource(self):
Expand All @@ -87,7 +85,7 @@ def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -98,7 +96,7 @@ def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
Expand Down
14 changes: 7 additions & 7 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from synapse.util import Clock

from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
Expand Down Expand Up @@ -407,7 +407,7 @@ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]
return HTTPStatus.OK, b"ok"


class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation."""

def setUp(self):
Expand All @@ -421,7 +421,7 @@ def test_cancellable_disconnect(self) -> None:
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -433,15 +433,15 @@ def test_uncancellable_disconnect(self) -> None:
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
expected_body={"result": True},
)


class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation."""

def setUp(self):
Expand All @@ -455,7 +455,7 @@ def test_cancellable_disconnect(self) -> None:
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -467,6 +467,6 @@ def test_uncancellable_disconnect(self) -> None:
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
)