diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index c70eee649c57..d5ea76b96baa 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -937,14 +937,19 @@ def err_back(result: R) -> R: else: if inspect.isawaitable(result): - logger.error( - "@trace may not have wrapped %s correctly! " - "The function is not async but returned a %s.", - func.__qualname__, - type(result).__name__, - ) - - scope.__exit__(None, None, None) + + async def await_coroutine(): + try: + return await result + finally: + scope.__exit__(None, None, None) + + # The original method returned a coroutine, so we create another + # coroutine wrapping it, that calls __exit__. + return await_coroutine() + else: + # Just a simple sync function + scope.__exit__(None, None, None) return result diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index e28ba84cc2b7..29e30993543b 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast +from typing import Awaitable, cast from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactorClock @@ -277,3 +277,34 @@ async def fixture_async_func() -> str: [span.operation_name for span in self._reporter.get_spans()], ["fixture_async_func"], ) + + def test_trace_decorator_awaitable_return(self) -> None: + """ + Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args` + with functions that return an awaitable (e.g. a coroutine) + """ + reactor = MemoryReactorClock() + + with LoggingContext("root context"): + # Something we can return without `await` to get a coroutine + async def fixture_async_func() -> str: + return "foo" + + # The actual kind of function we want to test that returns an awaitable + @trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer) + @tag_args + def fixture_awaitable_return_func() -> Awaitable[str]: + return fixture_async_func() + + d1 = defer.ensureDeferred(fixture_awaitable_return_func()) + + # let the tasks complete + reactor.pump((2,) * 8) + + self.assertEqual(self.successResultOf(d1), "foo") + + # the span should have been reported + self.assertEqual( + [span.operation_name for span in self._reporter.get_spans()], + ["fixture_awaitable_return_func"], + )