Skip to content

Commit

Permalink
Fix context in protocol callbacks (#348)
Browse files Browse the repository at this point in the history
This is a combined fix to correct contexts from which protocal callbacks
are invoked. In short, callbacks like data_received() should always be
invoked from consistent contexts which are copied from the context where
the underlying UVHandle is created or started.

The new test case covers also asyncio, but skipping the failing ones.
  • Loading branch information
fantix committed Feb 5, 2021
1 parent 7b202cc commit f691212
Show file tree
Hide file tree
Showing 19 changed files with 791 additions and 132 deletions.
623 changes: 601 additions & 22 deletions tests/test_context.py

Large diffs are not rendered by default.

13 changes: 6 additions & 7 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,10 @@ def test_socket_sync_remove_and_immediately_close(self):
self.loop.run_until_complete(asyncio.sleep(0.01))

def test_sock_cancel_add_reader_race(self):
if self.is_asyncio_loop():
if sys.version_info[:2] == (3, 8):
# asyncio 3.8.x has a regression; fixed in 3.9.0
# tracked in https://bugs.python.org/issue30064
raise unittest.SkipTest()
if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8):
# asyncio 3.8.x has a regression; fixed in 3.9.0
# tracked in https://bugs.python.org/issue30064
raise unittest.SkipTest()

srv_sock_conn = None

Expand Down Expand Up @@ -247,8 +246,8 @@ async def send_server_data():
self.loop.run_until_complete(server())

def test_sock_send_before_cancel(self):
if self.is_asyncio_loop() and sys.version_info[:3] == (3, 8, 0):
# asyncio 3.8.0 seems to have a regression;
if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8):
# asyncio 3.8.x has a regression; fixed in 3.9.0
# tracked in https://bugs.python.org/issue30064
raise unittest.SkipTest()

Expand Down
31 changes: 16 additions & 15 deletions uvloop/cbhandles.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -333,71 +333,72 @@ cdef new_Handle(Loop loop, object callback, object args, object context):
return handle


cdef new_MethodHandle(Loop loop, str name, method_t callback, object ctx):
cdef new_MethodHandle(Loop loop, str name, method_t callback, object context,
object bound_to):
cdef Handle handle
handle = Handle.__new__(Handle)
handle._set_loop(loop)
handle._set_context(None)
handle._set_context(context)

handle.cb_type = 2
handle.meth_name = name

handle.callback = <void*> callback
handle.arg1 = ctx
handle.arg1 = bound_to

return handle


cdef new_MethodHandle1(Loop loop, str name, method1_t callback,
object ctx, object arg):
cdef new_MethodHandle1(Loop loop, str name, method1_t callback, object context,
object bound_to, object arg):

cdef Handle handle
handle = Handle.__new__(Handle)
handle._set_loop(loop)
handle._set_context(None)
handle._set_context(context)

handle.cb_type = 3
handle.meth_name = name

handle.callback = <void*> callback
handle.arg1 = ctx
handle.arg1 = bound_to
handle.arg2 = arg

return handle


cdef new_MethodHandle2(Loop loop, str name, method2_t callback, object ctx,
object arg1, object arg2):
cdef new_MethodHandle2(Loop loop, str name, method2_t callback, object context,
object bound_to, object arg1, object arg2):

cdef Handle handle
handle = Handle.__new__(Handle)
handle._set_loop(loop)
handle._set_context(None)
handle._set_context(context)

handle.cb_type = 4
handle.meth_name = name

handle.callback = <void*> callback
handle.arg1 = ctx
handle.arg1 = bound_to
handle.arg2 = arg1
handle.arg3 = arg2

return handle


cdef new_MethodHandle3(Loop loop, str name, method3_t callback, object ctx,
object arg1, object arg2, object arg3):
cdef new_MethodHandle3(Loop loop, str name, method3_t callback, object context,
object bound_to, object arg1, object arg2, object arg3):

cdef Handle handle
handle = Handle.__new__(Handle)
handle._set_loop(loop)
handle._set_context(None)
handle._set_context(context)

handle.cb_type = 5
handle.meth_name = name

handle.callback = <void*> callback
handle.arg1 = ctx
handle.arg1 = bound_to
handle.arg2 = arg1
handle.arg3 = arg2
handle.arg4 = arg3
Expand Down
11 changes: 9 additions & 2 deletions uvloop/handles/basetransport.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ cdef class UVBaseTransport(UVSocketHandle):
new_MethodHandle(self._loop,
"UVTransport._call_connection_made",
<method_t>self._call_connection_made,
self.context,
self))

cdef inline _schedule_call_connection_lost(self, exc):
self._loop._call_soon_handle(
new_MethodHandle1(self._loop,
"UVTransport._call_connection_lost",
<method1_t>self._call_connection_lost,
self.context,
self, exc))

cdef _fatal_error(self, exc, throw, reason=None):
Expand Down Expand Up @@ -66,7 +68,9 @@ cdef class UVBaseTransport(UVSocketHandle):
if not self._protocol_paused:
self._protocol_paused = 1
try:
self._protocol.pause_writing()
# _maybe_pause_protocol() is always triggered from user-calls,
# so we must copy the context to avoid entering context twice
self.context.copy().run(self._protocol.pause_writing)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
Expand All @@ -84,7 +88,10 @@ cdef class UVBaseTransport(UVSocketHandle):
if self._protocol_paused and size <= self._low_water:
self._protocol_paused = 0
try:
self._protocol.resume_writing()
# We're copying the context to avoid entering context twice,
# even though it's not always necessary to copy - it's easier
# to copy here than passing down a copied context.
self.context.copy().run(self._protocol.resume_writing)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
Expand Down
1 change: 1 addition & 0 deletions uvloop/handles/handle.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cdef class UVHandle:
readonly _source_traceback
bint _closed
bint _inited
object context

# Added to enable current UDPTransport implementation,
# which doesn't use libuv handles.
Expand Down
2 changes: 1 addition & 1 deletion uvloop/handles/pipe.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ cdef class UnixTransport(UVStream):

@staticmethod
cdef UnixTransport new(Loop loop, object protocol, Server server,
object waiter)
object waiter, object context)

cdef connect(self, char* addr)

Expand Down
18 changes: 12 additions & 6 deletions uvloop/handles/pipe.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ cdef class UnixServer(UVStreamServer):

self._mark_as_open()

cdef UVStream _make_new_transport(self, object protocol, object waiter):
cdef UVStream _make_new_transport(self, object protocol, object waiter,
object context):
cdef UnixTransport tr
tr = UnixTransport.new(self._loop, protocol, self._server, waiter)
tr = UnixTransport.new(self._loop, protocol, self._server, waiter,
context)
return <UVStream>tr


Expand All @@ -84,11 +86,11 @@ cdef class UnixTransport(UVStream):

@staticmethod
cdef UnixTransport new(Loop loop, object protocol, Server server,
object waiter):
object waiter, object context):

cdef UnixTransport handle
handle = UnixTransport.__new__(UnixTransport)
handle._init(loop, protocol, server, waiter)
handle._init(loop, protocol, server, waiter, context)
__pipe_init_uv_handle(<UVStream>handle, loop)
return handle

Expand All @@ -112,7 +114,9 @@ cdef class ReadUnixTransport(UVStream):
object waiter):
cdef ReadUnixTransport handle
handle = ReadUnixTransport.__new__(ReadUnixTransport)
handle._init(loop, protocol, server, waiter)
# This is only used in connect_read_pipe() and subprocess_shell/exec()
# directly, we could simply copy the current context.
handle._init(loop, protocol, server, waiter, Context_CopyCurrent())
__pipe_init_uv_handle(<UVStream>handle, loop)
return handle

Expand Down Expand Up @@ -162,7 +166,9 @@ cdef class WriteUnixTransport(UVStream):
# close the transport.
handle._close_on_read_error()

handle._init(loop, protocol, server, waiter)
# This is only used in connect_write_pipe() and subprocess_shell/exec()
# directly, we could simply copy the current context.
handle._init(loop, protocol, server, waiter, Context_CopyCurrent())
__pipe_init_uv_handle(<UVStream>handle, loop)
return handle

Expand Down
19 changes: 15 additions & 4 deletions uvloop/handles/process.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ cdef class UVProcess(UVHandle):
self._fds_to_close = set()
self._preexec_fn = None
self._restore_signals = True
self.context = Context_CopyCurrent()

cdef _close_process_handle(self):
# XXX: This is a workaround for a libuv bug:
Expand Down Expand Up @@ -364,7 +365,8 @@ cdef class UVProcessTransport(UVProcess):
UVProcess._on_exit(self, exit_status, term_signal)

if self._stdio_ready:
self._loop.call_soon(self._protocol.process_exited)
self._loop.call_soon(self._protocol.process_exited,
context=self.context)
else:
self._pending_calls.append((_CALL_PROCESS_EXITED, None, None))

Expand All @@ -383,14 +385,16 @@ cdef class UVProcessTransport(UVProcess):

cdef _pipe_connection_lost(self, int fd, exc):
if self._stdio_ready:
self._loop.call_soon(self._protocol.pipe_connection_lost, fd, exc)
self._loop.call_soon(self._protocol.pipe_connection_lost, fd, exc,
context=self.context)
self._try_finish()
else:
self._pending_calls.append((_CALL_PIPE_CONNECTION_LOST, fd, exc))

cdef _pipe_data_received(self, int fd, data):
if self._stdio_ready:
self._loop.call_soon(self._protocol.pipe_data_received, fd, data)
self._loop.call_soon(self._protocol.pipe_data_received, fd, data,
context=self.context)
else:
self._pending_calls.append((_CALL_PIPE_DATA_RECEIVED, fd, data))

Expand Down Expand Up @@ -517,6 +521,7 @@ cdef class UVProcessTransport(UVProcess):

cdef _call_connection_made(self, waiter):
try:
# we're always called in the right context, so just call the user's
self._protocol.connection_made(self)
except (KeyboardInterrupt, SystemExit):
raise
Expand Down Expand Up @@ -556,7 +561,9 @@ cdef class UVProcessTransport(UVProcess):
self._finished = 1

if self._stdio_ready:
self._loop.call_soon(self._protocol.connection_lost, None)
# copy self.context for simplicity
self._loop.call_soon(self._protocol.connection_lost, None,
context=self.context)
else:
self._pending_calls.append((_CALL_CONNECTION_LOST, None, None))

Expand All @@ -572,6 +579,7 @@ cdef class UVProcessTransport(UVProcess):
new_MethodHandle1(self._loop,
"UVProcessTransport._call_connection_made",
<method1_t>self._call_connection_made,
None, # means to copy the current context
self, waiter))

@staticmethod
Expand All @@ -598,6 +606,8 @@ cdef class UVProcessTransport(UVProcess):
if handle._init_futs:
handle._stdio_ready = 0
init_fut = aio_gather(*handle._init_futs)
# add_done_callback will copy the current context and run the
# callback within the context
init_fut.add_done_callback(
ft_partial(handle.__stdio_inited, waiter))
else:
Expand All @@ -606,6 +616,7 @@ cdef class UVProcessTransport(UVProcess):
new_MethodHandle1(loop,
"UVProcessTransport._call_connection_made",
<method1_t>handle._call_connection_made,
None, # means to copy the current context
handle, waiter))

return handle
Expand Down
2 changes: 1 addition & 1 deletion uvloop/handles/stream.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ cdef class UVStream(UVBaseTransport):
# All "inline" methods are final

cdef inline _init(self, Loop loop, object protocol, Server server,
object waiter)
object waiter, object context)

cdef inline _exec_write(self)

Expand Down
12 changes: 6 additions & 6 deletions uvloop/handles/stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ cdef class UVStream(UVBaseTransport):
except AttributeError:
keep_open = False
else:
keep_open = meth()
keep_open = self.context.run(meth)

if keep_open:
# We're keeping the connection open so the
Expand All @@ -631,8 +631,8 @@ cdef class UVStream(UVBaseTransport):
self._shutdown()

cdef inline _init(self, Loop loop, object protocol, Server server,
object waiter):

object waiter, object context):
self.context = context
self._set_protocol(protocol)
self._start_init(loop)

Expand Down Expand Up @@ -826,7 +826,7 @@ cdef inline void __uv_stream_on_read_impl(uv.uv_stream_t* stream,
if UVLOOP_DEBUG:
loop._debug_stream_read_cb_total += 1

sc._protocol_data_received(loop._recv_buffer[:nread])
sc.context.run(sc._protocol_data_received, loop._recv_buffer[:nread])
except BaseException as exc:
if UVLOOP_DEBUG:
loop._debug_stream_read_cb_errors_total += 1
Expand Down Expand Up @@ -911,7 +911,7 @@ cdef void __uv_stream_buffered_alloc(uv.uv_handle_t* stream,

sc._read_pybuf_acquired = 0
try:
buf = sc._protocol_get_buffer(suggested_size)
buf = sc.context.run(sc._protocol_get_buffer, suggested_size)
PyObject_GetBuffer(buf, pybuf, PyBUF_WRITABLE)
got_buf = 1
except BaseException as exc:
Expand Down Expand Up @@ -976,7 +976,7 @@ cdef void __uv_stream_buffered_on_read(uv.uv_stream_t* stream,
if UVLOOP_DEBUG:
loop._debug_stream_read_cb_total += 1

sc._protocol_buffer_updated(nread)
sc.context.run(sc._protocol_buffer_updated, nread)
except BaseException as exc:
if UVLOOP_DEBUG:
loop._debug_stream_read_cb_errors_total += 1
Expand Down
4 changes: 2 additions & 2 deletions uvloop/handles/streamserver.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ cdef class UVStreamServer(UVSocketHandle):
object protocol_factory
bint opened
Server _server
object listen_context

# All "inline" methods are final

Expand All @@ -23,4 +22,5 @@ cdef class UVStreamServer(UVSocketHandle):
cdef inline listen(self)
cdef inline _on_listen(self)

cdef UVStream _make_new_transport(self, object protocol, object waiter)
cdef UVStream _make_new_transport(self, object protocol, object waiter,
object context)
Loading

0 comments on commit f691212

Please sign in to comment.