diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 5ba35719..2a8d8f29 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -652,6 +652,43 @@ async def runner(): self.assertIsNone( self.loop.run_until_complete(connection_lost_called)) + def test_context_run_segfault(self): + is_new = False + done = self.loop.create_future() + + def server(sock): + sock.sendall(b'hello') + + class Protocol(asyncio.Protocol): + def __init__(self): + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + try: + self = weakref.ref(self) + nonlocal is_new + if is_new: + done.set_result(data) + else: + is_new = True + new_proto = Protocol() + self().transport.set_protocol(new_proto) + new_proto.connection_made(self().transport) + new_proto.data_received(data) + except Exception as e: + done.set_exception(e) + + async def test(addr): + await self.loop.create_connection(Protocol, *addr) + data = await done + self.assertEqual(data, b'hello') + + with self.tcp_server(server) as srv: + self.loop.run_until_complete(test(srv.addr)) + class Test_UV_TCP(_TestTCP, tb.UVTestCase): diff --git a/uvloop/handles/basetransport.pyx b/uvloop/handles/basetransport.pyx index 6ddecc68..28b30794 100644 --- a/uvloop/handles/basetransport.pyx +++ b/uvloop/handles/basetransport.pyx @@ -70,7 +70,9 @@ cdef class UVBaseTransport(UVSocketHandle): try: # _maybe_pause_protocol() is always triggered from user-calls, # so we must copy the context to avoid entering context twice - self.context.copy().run(self._protocol.pause_writing) + run_in_context( + self.context.copy(), self._protocol.pause_writing, + ) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -91,7 +93,9 @@ cdef class UVBaseTransport(UVSocketHandle): # We're copying the context to avoid entering context twice, # even though it's not always necessary to copy - it's easier # to copy here than passing down a copied context. - self.context.copy().run(self._protocol.resume_writing) + run_in_context( + self.context.copy(), self._protocol.resume_writing, + ) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: diff --git a/uvloop/handles/stream.pyx b/uvloop/handles/stream.pyx index fe828bde..4757ce7a 100644 --- a/uvloop/handles/stream.pyx +++ b/uvloop/handles/stream.pyx @@ -612,7 +612,7 @@ cdef class UVStream(UVBaseTransport): except AttributeError: keep_open = False else: - keep_open = self.context.run(meth) + keep_open = run_in_context(self.context, meth) if keep_open: # We're keeping the connection open so the @@ -826,7 +826,11 @@ cdef inline void __uv_stream_on_read_impl(uv.uv_stream_t* stream, if UVLOOP_DEBUG: loop._debug_stream_read_cb_total += 1 - sc.context.run(sc._protocol_data_received, loop._recv_buffer[:nread]) + run_in_context1( + sc.context, + sc._protocol_data_received, + loop._recv_buffer[:nread], + ) except BaseException as exc: if UVLOOP_DEBUG: loop._debug_stream_read_cb_errors_total += 1 @@ -911,7 +915,11 @@ cdef void __uv_stream_buffered_alloc(uv.uv_handle_t* stream, sc._read_pybuf_acquired = 0 try: - buf = sc.context.run(sc._protocol_get_buffer, suggested_size) + buf = run_in_context1( + sc.context, + sc._protocol_get_buffer, + suggested_size, + ) PyObject_GetBuffer(buf, pybuf, PyBUF_WRITABLE) got_buf = 1 except BaseException as exc: @@ -976,7 +984,7 @@ cdef void __uv_stream_buffered_on_read(uv.uv_stream_t* stream, if UVLOOP_DEBUG: loop._debug_stream_read_cb_total += 1 - sc.context.run(sc._protocol_buffer_updated, nread) + run_in_context1(sc.context, sc._protocol_buffer_updated, nread) except BaseException as exc: if UVLOOP_DEBUG: loop._debug_stream_read_cb_errors_total += 1 diff --git a/uvloop/handles/streamserver.pyx b/uvloop/handles/streamserver.pyx index 921c3565..6e0f3576 100644 --- a/uvloop/handles/streamserver.pyx +++ b/uvloop/handles/streamserver.pyx @@ -66,7 +66,7 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline _on_listen(self): cdef UVStream client - protocol = self.context.run(self.protocol_factory) + protocol = run_in_context(self.context, self.protocol_factory) if self.ssl is None: client = self._make_new_transport(protocol, None, self.context) diff --git a/uvloop/handles/udp.pyx b/uvloop/handles/udp.pyx index 82dabbf4..b92fdfd7 100644 --- a/uvloop/handles/udp.pyx +++ b/uvloop/handles/udp.pyx @@ -257,16 +257,18 @@ cdef class UDPTransport(UVBaseTransport): cdef _on_receive(self, bytes data, object exc, object addr): if exc is None: - self.context.run(self._protocol.datagram_received, data, addr) + run_in_context2( + self.context, self._protocol.datagram_received, data, addr, + ) else: - self.context.run(self._protocol.error_received, exc) + run_in_context1(self.context, self._protocol.error_received, exc) cdef _on_sent(self, object exc, object context=None): if exc is not None: if isinstance(exc, OSError): if context is None: context = self.context - context.run(self._protocol.error_received, exc) + run_in_context1(context, self._protocol.error_received, exc) else: self._fatal_error( exc, False, 'Fatal write error on datagram transport') diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 6a2bbe0c..4d96ffa6 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -89,6 +89,34 @@ cdef inline socket_dec_io_ref(sock): sock._decref_socketios() +cdef inline run_in_context(context, method): + # This method is internally used to workaround a reference issue that in + # certain circumstances, inlined context.run() will not hold a reference to + # the given method instance, which - if deallocated - will cause segault. + # See also: edgedb/edgedb#2222 + Py_INCREF(method) + try: + return context.run(method) + finally: + Py_DECREF(method) + + +cdef inline run_in_context1(context, method, arg): + Py_INCREF(method) + try: + return context.run(method, arg) + finally: + Py_DECREF(method) + + +cdef inline run_in_context2(context, method, arg1, arg2): + Py_INCREF(method) + try: + return context.run(method, arg1, arg2) + finally: + Py_DECREF(method) + + # Used for deprecation and removal of `loop.create_datagram_endpoint()`'s # *reuse_address* parameter _unset = object() diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index d1f976e3..3cc1df31 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -794,7 +794,9 @@ cdef class SSLProtocol: # inside the upstream callbacks like buffer_updated() keep_open = self._app_protocol.eof_received() else: - keep_open = context.run(self._app_protocol.eof_received) + keep_open = run_in_context( + context, self._app_protocol.eof_received, + ) except (KeyboardInterrupt, SystemExit): raise except BaseException as ex: @@ -817,7 +819,7 @@ cdef class SSLProtocol: # inside the upstream callbacks like buffer_updated() self._app_protocol.pause_writing() else: - context.run(self._app_protocol.pause_writing) + run_in_context(context, self._app_protocol.pause_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -836,7 +838,7 @@ cdef class SSLProtocol: # inside the upstream callbacks like resume_writing() self._app_protocol.resume_writing() else: - context.run(self._app_protocol.resume_writing) + run_in_context(context, self._app_protocol.resume_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: