From f0b9e65c007491fabf13b4abf4697cb2c213c4ce Mon Sep 17 00:00:00 2001 From: Fantix King Date: Sun, 14 Feb 2021 21:20:34 -0500 Subject: [PATCH] Fix ref issue when protocol is in Cython Because `context.run()` doesn't hold reference to the callable, when e.g. the protocol is written in Cython, the callbacks were not guaranteed to hold the protocol reference. This PR fixes the issue by explicitly add a reference before `context.run()` calls. Refs edgedb/edgedb#2222 --- tests/test_tcp.py | 37 ++++++++++++++++++++++++++++++++ uvloop/handles/basetransport.pyx | 8 +++++-- uvloop/handles/stream.pyx | 16 ++++++++++---- uvloop/handles/streamserver.pyx | 2 +- uvloop/handles/udp.pyx | 8 ++++--- uvloop/loop.pyx | 28 ++++++++++++++++++++++++ uvloop/sslproto.pyx | 8 ++++--- 7 files changed, 94 insertions(+), 13 deletions(-) diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 5ba35719..2a8d8f29 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -652,6 +652,43 @@ async def runner(): self.assertIsNone( self.loop.run_until_complete(connection_lost_called)) + def test_context_run_segfault(self): + is_new = False + done = self.loop.create_future() + + def server(sock): + sock.sendall(b'hello') + + class Protocol(asyncio.Protocol): + def __init__(self): + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + try: + self = weakref.ref(self) + nonlocal is_new + if is_new: + done.set_result(data) + else: + is_new = True + new_proto = Protocol() + self().transport.set_protocol(new_proto) + new_proto.connection_made(self().transport) + new_proto.data_received(data) + except Exception as e: + done.set_exception(e) + + async def test(addr): + await self.loop.create_connection(Protocol, *addr) + data = await done + self.assertEqual(data, b'hello') + + with self.tcp_server(server) as srv: + self.loop.run_until_complete(test(srv.addr)) + class Test_UV_TCP(_TestTCP, tb.UVTestCase): diff --git a/uvloop/handles/basetransport.pyx b/uvloop/handles/basetransport.pyx index 6ddecc68..28b30794 100644 --- a/uvloop/handles/basetransport.pyx +++ b/uvloop/handles/basetransport.pyx @@ -70,7 +70,9 @@ cdef class UVBaseTransport(UVSocketHandle): try: # _maybe_pause_protocol() is always triggered from user-calls, # so we must copy the context to avoid entering context twice - self.context.copy().run(self._protocol.pause_writing) + run_in_context( + self.context.copy(), self._protocol.pause_writing, + ) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -91,7 +93,9 @@ cdef class UVBaseTransport(UVSocketHandle): # We're copying the context to avoid entering context twice, # even though it's not always necessary to copy - it's easier # to copy here than passing down a copied context. - self.context.copy().run(self._protocol.resume_writing) + run_in_context( + self.context.copy(), self._protocol.resume_writing, + ) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: diff --git a/uvloop/handles/stream.pyx b/uvloop/handles/stream.pyx index fe828bde..4757ce7a 100644 --- a/uvloop/handles/stream.pyx +++ b/uvloop/handles/stream.pyx @@ -612,7 +612,7 @@ cdef class UVStream(UVBaseTransport): except AttributeError: keep_open = False else: - keep_open = self.context.run(meth) + keep_open = run_in_context(self.context, meth) if keep_open: # We're keeping the connection open so the @@ -826,7 +826,11 @@ cdef inline void __uv_stream_on_read_impl(uv.uv_stream_t* stream, if UVLOOP_DEBUG: loop._debug_stream_read_cb_total += 1 - sc.context.run(sc._protocol_data_received, loop._recv_buffer[:nread]) + run_in_context1( + sc.context, + sc._protocol_data_received, + loop._recv_buffer[:nread], + ) except BaseException as exc: if UVLOOP_DEBUG: loop._debug_stream_read_cb_errors_total += 1 @@ -911,7 +915,11 @@ cdef void __uv_stream_buffered_alloc(uv.uv_handle_t* stream, sc._read_pybuf_acquired = 0 try: - buf = sc.context.run(sc._protocol_get_buffer, suggested_size) + buf = run_in_context1( + sc.context, + sc._protocol_get_buffer, + suggested_size, + ) PyObject_GetBuffer(buf, pybuf, PyBUF_WRITABLE) got_buf = 1 except BaseException as exc: @@ -976,7 +984,7 @@ cdef void __uv_stream_buffered_on_read(uv.uv_stream_t* stream, if UVLOOP_DEBUG: loop._debug_stream_read_cb_total += 1 - sc.context.run(sc._protocol_buffer_updated, nread) + run_in_context1(sc.context, sc._protocol_buffer_updated, nread) except BaseException as exc: if UVLOOP_DEBUG: loop._debug_stream_read_cb_errors_total += 1 diff --git a/uvloop/handles/streamserver.pyx b/uvloop/handles/streamserver.pyx index 921c3565..6e0f3576 100644 --- a/uvloop/handles/streamserver.pyx +++ b/uvloop/handles/streamserver.pyx @@ -66,7 +66,7 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline _on_listen(self): cdef UVStream client - protocol = self.context.run(self.protocol_factory) + protocol = run_in_context(self.context, self.protocol_factory) if self.ssl is None: client = self._make_new_transport(protocol, None, self.context) diff --git a/uvloop/handles/udp.pyx b/uvloop/handles/udp.pyx index 82dabbf4..b92fdfd7 100644 --- a/uvloop/handles/udp.pyx +++ b/uvloop/handles/udp.pyx @@ -257,16 +257,18 @@ cdef class UDPTransport(UVBaseTransport): cdef _on_receive(self, bytes data, object exc, object addr): if exc is None: - self.context.run(self._protocol.datagram_received, data, addr) + run_in_context2( + self.context, self._protocol.datagram_received, data, addr, + ) else: - self.context.run(self._protocol.error_received, exc) + run_in_context1(self.context, self._protocol.error_received, exc) cdef _on_sent(self, object exc, object context=None): if exc is not None: if isinstance(exc, OSError): if context is None: context = self.context - context.run(self._protocol.error_received, exc) + run_in_context1(context, self._protocol.error_received, exc) else: self._fatal_error( exc, False, 'Fatal write error on datagram transport') diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 6a2bbe0c..4d96ffa6 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -89,6 +89,34 @@ cdef inline socket_dec_io_ref(sock): sock._decref_socketios() +cdef inline run_in_context(context, method): + # This method is internally used to workaround a reference issue that in + # certain circumstances, inlined context.run() will not hold a reference to + # the given method instance, which - if deallocated - will cause segault. + # See also: edgedb/edgedb#2222 + Py_INCREF(method) + try: + return context.run(method) + finally: + Py_DECREF(method) + + +cdef inline run_in_context1(context, method, arg): + Py_INCREF(method) + try: + return context.run(method, arg) + finally: + Py_DECREF(method) + + +cdef inline run_in_context2(context, method, arg1, arg2): + Py_INCREF(method) + try: + return context.run(method, arg1, arg2) + finally: + Py_DECREF(method) + + # Used for deprecation and removal of `loop.create_datagram_endpoint()`'s # *reuse_address* parameter _unset = object() diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index d1f976e3..3cc1df31 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -794,7 +794,9 @@ cdef class SSLProtocol: # inside the upstream callbacks like buffer_updated() keep_open = self._app_protocol.eof_received() else: - keep_open = context.run(self._app_protocol.eof_received) + keep_open = run_in_context( + context, self._app_protocol.eof_received, + ) except (KeyboardInterrupt, SystemExit): raise except BaseException as ex: @@ -817,7 +819,7 @@ cdef class SSLProtocol: # inside the upstream callbacks like buffer_updated() self._app_protocol.pause_writing() else: - context.run(self._app_protocol.pause_writing) + run_in_context(context, self._app_protocol.pause_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -836,7 +838,7 @@ cdef class SSLProtocol: # inside the upstream callbacks like resume_writing() self._app_protocol.resume_writing() else: - context.run(self._app_protocol.resume_writing) + run_in_context(context, self._app_protocol.resume_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: