From 2ff6425ce065f4550df5287e96e51235c8f8aa5f Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 19 Apr 2022 20:11:37 +0100 Subject: [PATCH 1/9] Support `DoneAwaitable` and `make_awaitable` in `run_in_background` Signed-off-by: Sean Quah --- synapse/logging/context.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 88cd8a9e1c39..38cb1cc0e124 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -808,10 +808,21 @@ def run_in_background( # type: ignore[misc] # At this point we should have a Deferred, if not then f was a synchronous # function, wrap it in a Deferred for consistency. if not isinstance(res, defer.Deferred): - # `res` is not a `Deferred` and not a `Coroutine`. - # There are no other types of `Awaitable`s we expect to encounter in Synapse. - assert not isinstance(res, Awaitable) + if isinstance(res, Awaitable): + # `f` returned some kind of awaitable that is not a coroutine or `Deferred`. + # We assume that it is a completed awaitable, such as a `DoneAwaitable` or + # `Future` from `make_awaitable`, and await it manually. + iterator = res.__await__() # `__await__` returns an iterator... + try: + next(iterator) + raise ValueError( + f"Function {f} returned an unresolved awaitable: {res}" + ) + except StopIteration as e: + # ...which raises a `StopIteration` once the awaitable is complete. + res = e.value + # res is now a plain value here. return defer.succeed(res) if res.called and not res.paused: From d56d118bdfe8d0712c1d8971d433a82ffe4d28cf Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 19 Apr 2022 19:50:38 +0100 Subject: [PATCH 2/9] Prefer `make_awaitable` over `defer.succeed` in tests When configuring the return values of mocks, prefer awaitables from `make_awaitable` over `defer.succeed`. `Deferred`s are only awaitable once, so it is inappropriate for a mock to return the same `Deferred` multiple times. Signed-off-by: Sean Quah --- tests/federation/test_federation_client.py | 2 +- tests/federation/test_federation_sender.py | 2 +- tests/handlers/test_e2e_keys.py | 7 ++-- tests/handlers/test_password_providers.py | 34 +++++++++---------- tests/handlers/test_typing.py | 6 ++-- tests/handlers/test_user_directory.py | 6 ++-- tests/rest/client/test_presence.py | 4 +-- tests/rest/client/test_rooms.py | 7 ++-- .../test_resource_limits_server_notices.py | 28 +++++++-------- tests/storage/test_monthly_active_users.py | 9 +++-- tests/test_federation.py | 2 +- 11 files changed, 50 insertions(+), 57 deletions(-) diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index ec8864dafe37..268a48d7ba5f 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -83,7 +83,7 @@ def test_get_room_state(self): ) # mock up the response, and have the agent return it - self._mock_agent.request.return_value = defer.succeed( + self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( _mock_response( { "pdus": [ diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 91f982518e69..6b26353d5e93 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -226,7 +226,7 @@ def test_dont_send_device_updates_for_remote_users(self): # Send the server a device list EDU for the other user, this will cause # it to try and resync the device lists. self.hs.get_federation_transport_client().query_user_devices.return_value = ( - defer.succeed( + make_awaitable( { "stream_id": "1", "user_id": "@user2:host2", diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8c74ed1fcffc..1e6ad4b663e9 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -19,7 +19,6 @@ from parameterized import parameterized from signedjson import key as key, sign as sign -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms @@ -704,7 +703,7 @@ def test_query_devices_remote_no_sync(self) -> None: remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" self.hs.get_federation_client().query_client_keys = mock.Mock( - return_value=defer.succeed( + return_value=make_awaitable( { "device_keys": {remote_user_id: {}}, "master_keys": { @@ -777,14 +776,14 @@ def test_query_devices_remote_sync(self) -> None: # Pretend we're sharing a room with the user we're querying. If not, # `_query_devices_for_destination` will return early. self.store.get_rooms_for_user = mock.Mock( - return_value=defer.succeed({"some_room_id"}) + return_value=make_awaitable({"some_room_id"}) ) remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" self.hs.get_federation_client().query_user_devices = mock.Mock( - return_value=defer.succeed( + return_value=make_awaitable( { "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index d401fda93855..addf14fa2ba0 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -17,8 +17,6 @@ from typing import Any, Type, Union from unittest.mock import Mock -from twisted.internet import defer - import synapse from synapse.api.constants import LoginType from synapse.api.errors import Codes @@ -190,7 +188,7 @@ def password_only_auth_provider_login_test_body(self): self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) @@ -226,13 +224,13 @@ def password_only_auth_provider_ui_auth_test_body(self): self.get_success(module_api.register_user("u")) # log in twice, to get two devices - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) tok1 = self.login("u", "p") self.login("u", "p", device_id="dev2") mock_password_provider.reset_mock() # have the auth provider deny the request to start with - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) # make the initial request which returns a 401 session = self._start_delete_device_session(tok1, "dev2") @@ -246,7 +244,7 @@ def password_only_auth_provider_ui_auth_test_body(self): mock_password_provider.reset_mock() # Finally, check the request goes through when we allow it - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @@ -260,7 +258,7 @@ def local_user_fallback_login_test_body(self): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, 403, channel.result) @@ -277,7 +275,7 @@ def local_user_fallback_ui_auth_test_body(self): self.register_user("localuser", "localpass") # have the auth provider deny the request - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) # log in twice, to get two devices tok1 = self.login("localuser", "localpass") @@ -320,7 +318,7 @@ def no_local_user_fallback_login_test_body(self): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -342,7 +340,7 @@ def no_local_user_fallback_ui_auth_test_body(self): self.register_user("localuser", "localpass") # allow login via the auth provider - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) # log in twice, to get two devices tok1 = self.login("localuser", "p") @@ -359,7 +357,7 @@ def no_local_user_fallback_ui_auth_test_body(self): mock_password_provider.check_password.assert_not_called() # now try deleting with the local password - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._authed_delete_device( tok1, "dev2", session, "localuser", "localpass" ) @@ -413,7 +411,7 @@ def custom_auth_provider_login_test_body(self): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) channel = self._send_login("test.login_type", "u", test_field="y") @@ -427,7 +425,7 @@ def custom_auth_provider_login_test_body(self): # try a weird username. Again, it's unclear what we *expect* to happen # in these cases, but at least we can guard against the API changing # unexpectedly - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") @@ -477,7 +475,7 @@ def custom_auth_provider_ui_auth_test_body(self): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) body["auth"]["test_field"] = "foo" @@ -490,7 +488,7 @@ def custom_auth_provider_ui_auth_test_body(self): mock_password_provider.reset_mock() # and finally, succeed - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) @@ -508,9 +506,9 @@ def test_custom_auth_provider_callback(self): self.custom_auth_provider_callback_test_body() def custom_auth_provider_callback_test_body(self): - callback = Mock(return_value=defer.succeed(None)) + callback = Mock(return_value=make_awaitable(None)) - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") @@ -646,7 +644,7 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self): login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ffd5c4cb938c..5f2e26a5fce7 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -65,11 +65,11 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) - mock_keyring.verify_json_for_server.return_value = defer.succeed(True) + mock_keyring.verify_json_for_server.return_value = make_awaitable(True) # we mock out the federation client too mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) @@ -98,7 +98,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( - return_value=defer.succeed(None) + return_value=make_awaitable(None) ) self.datastore.get_device_updates_by_remote = Mock( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c6e501c7be56..96e2e3039ba8 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -15,7 +15,6 @@ from unittest.mock import Mock, patch from urllib.parse import quote -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,6 +29,7 @@ from tests import unittest from tests.storage.test_user_directory import GetUserDirectoryTables +from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config @@ -439,7 +439,7 @@ def test_handle_user_deactivated_support_user(self) -> None: ) ) - mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): @@ -454,7 +454,7 @@ def test_handle_user_deactivated_regular_user(self) -> None: self.store.register_user(user_id=r_user_id, password_hash=None) ) - mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 0abe378fe4d6..b3738a03046a 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -14,7 +14,6 @@ from http import HTTPStatus from unittest.mock import Mock -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler @@ -24,6 +23,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class PresenceTestCase(unittest.HomeserverTestCase): @@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = defer.succeed(None) + presence_handler.set_state.return_value = make_awaitable(None) hs = self.setup_test_homeserver( "red", diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 6ff79b9e2eed..9443daa0560a 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -22,7 +22,6 @@ from unittest.mock import Mock, call from urllib import parse as urlparse -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -1426,9 +1425,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined] - {} - ) + self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] search_filter = {"generic_search_term": "foobar"} @@ -1456,7 +1453,7 @@ def test_fallback(self) -> None: # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(404, "Not Found", b""), - defer.succeed({}), + make_awaitable({}), ) search_filter = {"generic_search_term": "foobar"} diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 02b96c9e6ecf..9ee9509d3a96 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -14,8 +14,6 @@ from unittest.mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin @@ -68,16 +66,16 @@ def prepare(self, reactor, clock, hs): return_value=make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( - return_value=defer.succeed(Mock()) + return_value=make_awaitable(Mock()) ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.user_id = "@user_id:test" self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - return_value=defer.succeed("!something:localhost") + return_value=make_awaitable("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) + self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @override_config({"hs_disabled": True}) @@ -95,7 +93,7 @@ def test_maybe_send_server_notice_to_user_flag_off(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) @@ -111,7 +109,8 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") + return_value=make_awaitable(None), + side_effect=ResourceLimitError(403, "foo"), ) mock_event = Mock( @@ -130,7 +129,8 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice(self): Test when user does not have blocked notice, but should have one """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") + return_value=make_awaitable(None), + side_effect=ResourceLimitError(403, "foo"), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -141,7 +141,7 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -152,7 +152,7 @@ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=make_awaitable(None) ) @@ -167,7 +167,7 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): an alert message is not sent into the room """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), @@ -182,7 +182,7 @@ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): Test that when a server is disabled, that MAU limit alerting is ignored. """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED ), @@ -199,14 +199,14 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): is suppressed that the room is returned to an unblocked state. """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( - return_value=defer.succeed((True, [])) + return_value=make_awaitable((True, [])) ) mock_event = Mock( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 60c8d3759481..0fbf46567091 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -14,7 +14,6 @@ from typing import Any, Dict, List from unittest.mock import Mock -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes @@ -259,10 +258,10 @@ def test_populate_monthly_users_is_guest(self): def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(None) + return_value=make_awaitable(None) ) d = self.store.populate_monthly_active_users("user_id") self.get_success(d) @@ -272,9 +271,9 @@ def test_populate_monthly_users_should_update(self): def test_populate_monthly_users_should_not_update(self): self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) d = self.store.populate_monthly_active_users("user_id") diff --git a/tests/test_federation.py b/tests/test_federation.py index c39816de855d..0cbef70bfa57 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -233,7 +233,7 @@ def test_cross_signing_keys_retry(self): # Register mock device list retrieval on the federation client. federation_client = self.homeserver.get_federation_client() federation_client.query_user_devices = Mock( - return_value=succeed( + return_value=make_awaitable( { "user_id": remote_user_id, "stream_id": 1, From 301bfe365b0089f4a36aaed7fc13316c21b6b219 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 19 Apr 2022 19:53:01 +0100 Subject: [PATCH 3/9] Prefer `make_awaitable` over `defer.succeed` in `tests/rest/client/test_transactions.py` Signed-off-by: Sean Quah --- tests/rest/client/test_transactions.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 8d8251b2ac99..21a1ca2a6885 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -22,6 +22,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import MockClock @@ -38,7 +39,7 @@ def setUp(self) -> None: @defer.inlineCallbacks def test_executes_given_function(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" ) @@ -47,7 +48,7 @@ def test_executes_given_function(self): @defer.inlineCallbacks def test_deduplicates_based_on_key(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg", changing_args=i @@ -130,7 +131,7 @@ def cb(): @defer.inlineCallbacks def test_cleans_up(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) From 9718ba703d498704ed40f2b3687f9dffb561d235 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 19 Apr 2022 20:21:39 +0100 Subject: [PATCH 4/9] Add newsfile Signed-off-by: Sean Quah --- changelog.d/12505.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12505.misc diff --git a/changelog.d/12505.misc b/changelog.d/12505.misc new file mode 100644 index 000000000000..a691d7962f89 --- /dev/null +++ b/changelog.d/12505.misc @@ -0,0 +1 @@ +Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests. From 180ab40c86d19c05b9a7017899c1051ff8ea186d Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 19 Apr 2022 20:57:48 +0100 Subject: [PATCH 5/9] Appease mypy --- synapse/logging/context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 38cb1cc0e124..18bfdff40039 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -820,10 +820,10 @@ def run_in_background( # type: ignore[misc] ) except StopIteration as e: # ...which raises a `StopIteration` once the awaitable is complete. - res = e.value - - # res is now a plain value here. - return defer.succeed(res) + return defer.succeed(e.value) + else: + # `f` returned a plain value. + return defer.succeed(res) if res.called and not res.paused: # The function should have maintained the logcontext, so we can From 629ce1b7bd0acd582a28c7f278645b8c025e4b76 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Thu, 21 Apr 2022 14:27:55 +0100 Subject: [PATCH 6/9] Reword comments and reduce indent --- synapse/logging/context.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 18bfdff40039..3482fe278b7a 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -802,29 +802,30 @@ def run_in_background( # type: ignore[misc] # by synchronous exceptions, so let's turn them into Failures. return defer.fail() + # First we handle coroutines by wrapping them in a `Deferred`. if isinstance(res, typing.Coroutine): res = defer.ensureDeferred(res) - # At this point we should have a Deferred, if not then f was a synchronous - # function, wrap it in a Deferred for consistency. + # At this point, `res` may be a plain value, `Deferred`, or some other kind of + # non-coroutine awaitable. if not isinstance(res, defer.Deferred): - if isinstance(res, Awaitable): - # `f` returned some kind of awaitable that is not a coroutine or `Deferred`. - # We assume that it is a completed awaitable, such as a `DoneAwaitable` or - # `Future` from `make_awaitable`, and await it manually. - iterator = res.__await__() # `__await__` returns an iterator... - try: - next(iterator) - raise ValueError( - f"Function {f} returned an unresolved awaitable: {res}" - ) - except StopIteration as e: - # ...which raises a `StopIteration` once the awaitable is complete. - return defer.succeed(e.value) - else: - # `f` returned a plain value. + # Wrap plain values in a `Deferred`. + if not isinstance(res, Awaitable): return defer.succeed(res) + # `res` is some kind of awaitable that is not a coroutine or `Deferred`. + # We assume that it is a completed awaitable, such as a `DoneAwaitable` or + # `Future` from `make_awaitable`, and await it manually. + iterator = res.__await__() # `__await__` returns an iterator... + try: + next(iterator) + raise ValueError( + f"Function {f} returned an unresolved awaitable: {res}" + ) + except StopIteration as e: + # ...which raises a `StopIteration` once the awaitable is complete. + return defer.succeed(e.value) + if res.called and not res.paused: # The function should have maintained the logcontext, so we can # optimise out the messing about From ba28da545f3d117da0f0c6ac63b4dbac75de57b8 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Thu, 21 Apr 2022 14:28:44 +0100 Subject: [PATCH 7/9] run linter --- synapse/logging/context.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 3482fe278b7a..3347f0ec2613 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -819,9 +819,7 @@ def run_in_background( # type: ignore[misc] iterator = res.__await__() # `__await__` returns an iterator... try: next(iterator) - raise ValueError( - f"Function {f} returned an unresolved awaitable: {res}" - ) + raise ValueError(f"Function {f} returned an unresolved awaitable: {res}") except StopIteration as e: # ...which raises a `StopIteration` once the awaitable is complete. return defer.succeed(e.value) From daa0e9a19523cd4cb3f3da31609c099f0b6a6a75 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 22 Apr 2022 10:57:10 +0100 Subject: [PATCH 8/9] refactor run_in_background --- synapse/logging/context.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 3347f0ec2613..c62ffacc94c7 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -802,27 +802,23 @@ def run_in_background( # type: ignore[misc] # by synchronous exceptions, so let's turn them into Failures. return defer.fail() - # First we handle coroutines by wrapping them in a `Deferred`. + # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain + # value. Convert it to a `Deferred`. if isinstance(res, typing.Coroutine): + # Wrap the coroutine in a `Deferred`. res = defer.ensureDeferred(res) + elif isinstance(res, defer.Deferred): + pass + elif isinstance(res, Awaitable): + # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` + # or `Future` from `make_awaitable`. + async def awaiter(awaitable: Awaitable[R]) -> R: + return await awaitable - # At this point, `res` may be a plain value, `Deferred`, or some other kind of - # non-coroutine awaitable. - if not isinstance(res, defer.Deferred): - # Wrap plain values in a `Deferred`. - if not isinstance(res, Awaitable): - return defer.succeed(res) - - # `res` is some kind of awaitable that is not a coroutine or `Deferred`. - # We assume that it is a completed awaitable, such as a `DoneAwaitable` or - # `Future` from `make_awaitable`, and await it manually. - iterator = res.__await__() # `__await__` returns an iterator... - try: - next(iterator) - raise ValueError(f"Function {f} returned an unresolved awaitable: {res}") - except StopIteration as e: - # ...which raises a `StopIteration` once the awaitable is complete. - return defer.succeed(e.value) + res = defer.ensureDeferred(awaiter(res)) + else: + # `res` is a plain value. Wrap it in a `Deferred`. + res = defer.succeed(res) if res.called and not res.paused: # The function should have maintained the logcontext, so we can From 8ffb4917896141d45d25a1e0f6c6870234999e27 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 27 Apr 2022 14:31:01 +0100 Subject: [PATCH 9/9] Move `awaiter` function to top-level --- synapse/logging/context.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index c62ffacc94c7..fd9cb979208a 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -722,6 +722,11 @@ def nested_logging_context(suffix: str) -> LoggingContext: R = TypeVar("R") +async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R: + """Unwraps an arbitrary awaitable by awaiting it.""" + return await awaitable + + @overload def preserve_fn( # type: ignore[misc] f: Callable[P, Awaitable[R]], @@ -812,10 +817,7 @@ def run_in_background( # type: ignore[misc] elif isinstance(res, Awaitable): # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` # or `Future` from `make_awaitable`. - async def awaiter(awaitable: Awaitable[R]) -> R: - return await awaitable - - res = defer.ensureDeferred(awaiter(res)) + res = defer.ensureDeferred(_unwrap_awaitable(res)) else: # `res` is a plain value. Wrap it in a `Deferred`. res = defer.succeed(res)