Skip to content

Commit

Permalink
Experimental fix for #169
Browse files Browse the repository at this point in the history
Co-authored-by: Andrey Egorov <andr06@gmail.com>
  • Loading branch information
1st1 and andr-04 committed Jun 22, 2018
1 parent a332533 commit cb0a65a
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 39 deletions.
85 changes: 85 additions & 0 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,91 @@ def test_socket_sync_remove_and_immediately_close(self):
self.assertEqual(sock.fileno(), -1)
self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

def test_sock_cancel_add_reader_race(self):
srv_sock_conn = None

async def server():
nonlocal srv_sock_conn
sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_server.setblocking(False)
with sock_server:
sock_server.bind(('127.0.0.1', 0))
sock_server.listen()
fut = asyncio.ensure_future(
client(sock_server.getsockname()), loop=self.loop)
srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
srv_sock_conn.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
with srv_sock_conn:
await fut

async def client(addr):
sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_client.setblocking(False)
with sock_client:
await self.loop.sock_connect(sock_client, addr)
_, pending_read_futs = await asyncio.wait(
[self.loop.sock_recv(sock_client, 1)],
timeout=1, loop=self.loop)

async def send_server_data():
# Wait a little bit to let reader future cancel and
# schedule the removal of the reader callback. Right after
# "rfut.cancel()" we will call "loop.sock_recv()", which
# will add a reader. This will make a race between
# remove- and add-reader.
await asyncio.sleep(0.1, loop=self.loop)
await self.loop.sock_sendall(srv_sock_conn, b'1')
self.loop.create_task(send_server_data())

for rfut in pending_read_futs:
rfut.cancel()

data = await self.loop.sock_recv(sock_client, 1)

self.assertEqual(data, b'1')

self.loop.run_until_complete(server())

def test_sock_send_before_cancel(self):
srv_sock_conn = None

async def server():
nonlocal srv_sock_conn
sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_server.setblocking(False)
with sock_server:
sock_server.bind(('127.0.0.1', 0))
sock_server.listen()
fut = asyncio.ensure_future(
client(sock_server.getsockname()), loop=self.loop)
srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
with srv_sock_conn:
await fut

async def client(addr):
await asyncio.sleep(0.01, loop=self.loop)
sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_client.setblocking(False)
with sock_client:
await self.loop.sock_connect(sock_client, addr)
_, pending_read_futs = await asyncio.wait(
[self.loop.sock_recv(sock_client, 1)],
timeout=1, loop=self.loop)

# server can send the data in a random time, even before
# the previous result future has cancelled.
await self.loop.sock_sendall(srv_sock_conn, b'1')

for rfut in pending_read_futs:
rfut.cancel()

data = await self.loop.sock_recv(sock_client, 1)

self.assertEqual(data, b'1')

self.loop.run_until_complete(server())


class TestUVSockets(_TestSockets, tb.UVTestCase):

Expand Down
2 changes: 2 additions & 0 deletions uvloop/handles/poll.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ cdef class UVPoll(UVHandle):
cdef int is_active(self)

cdef is_reading(self)
cdef is_writing(self)

cdef start_reading(self, Handle callback)
cdef start_writing(self, Handle callback)
cdef stop_reading(self)
Expand Down
3 changes: 3 additions & 0 deletions uvloop/handles/poll.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ cdef class UVPoll(UVHandle):
cdef is_reading(self):
return self._is_alive() and self.reading_handle is not None

cdef is_writing(self):
return self._is_alive() and self.writing_handle is not None

cdef start_reading(self, Handle callback):
cdef:
int mask = 0
Expand Down
4 changes: 2 additions & 2 deletions uvloop/loop.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ cdef class Loop:
cdef _track_process(self, UVProcess proc)
cdef _untrack_process(self, UVProcess proc)

cdef _new_reader_future(self, sock)
cdef _new_writer_future(self, sock)
cdef _add_reader(self, fd, Handle handle)
cdef _has_reader(self, fd)
cdef _remove_reader(self, fd)

cdef _add_writer(self, fd, Handle handle)
cdef _has_writer(self, fd)
cdef _remove_writer(self, fd)

cdef _sock_recv(self, fut, sock, n)
Expand Down
130 changes: 93 additions & 37 deletions uvloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,20 @@ cdef class Loop:

return result

cdef _has_reader(self, fileobj):
cdef:
UVPoll poll

self._check_closed()
fd = self._fileobj_to_fd(fileobj)

try:
poll = <UVPoll>(self._polls[fd])
except KeyError:
return False

return poll.is_reading()

cdef _add_writer(self, fileobj, Handle handle):
cdef:
UVPoll poll
Expand Down Expand Up @@ -791,6 +805,20 @@ cdef class Loop:

return result

cdef _has_writer(self, fileobj):
cdef:
UVPoll poll

self._check_closed()
fd = self._fileobj_to_fd(fileobj)

try:
poll = <UVPoll>(self._polls[fd])
except KeyError:
return False

return poll.is_writing()

cdef _getaddrinfo(self, object host, object port,
int family, int type,
int proto, int flags,
Expand Down Expand Up @@ -845,35 +873,17 @@ cdef class Loop:
nr.query(addr, flags)
return fut

cdef _new_reader_future(self, sock):
def _on_cancel(fut):
# Check if the future was cancelled and if the socket
# is still open, i.e.
#
# loop.remove_reader(sock)
# sock.close()
# fut.cancel()
#
# wasn't called by the user.
if fut.cancelled() and sock.fileno() != -1:
self._remove_reader(sock)

fut = self._new_future()
fut.add_done_callback(_on_cancel)
return fut

cdef _new_writer_future(self, sock):
def _on_cancel(fut):
if fut.cancelled() and sock.fileno() != -1:
self._remove_writer(sock)

fut = self._new_future()
fut.add_done_callback(_on_cancel)
return fut

cdef _sock_recv(self, fut, sock, n):
cdef:
Handle handle
if UVLOOP_DEBUG:
if fut.cancelled():
# Shouldn't happen with _SyncSocketReaderFuture.
raise RuntimeError(
f'_sock_recv is called on a cancelled Future')

if not self._has_reader(sock):
raise RuntimeError(
f'socket {sock!r} does not have a reader '
f'in the _sock_recv callback')

try:
data = sock.recv(n)
Expand All @@ -889,8 +899,16 @@ cdef class Loop:
self._remove_reader(sock)

cdef _sock_recv_into(self, fut, sock, buf):
cdef:
Handle handle
if UVLOOP_DEBUG:
if fut.cancelled():
# Shouldn't happen with _SyncSocketReaderFuture.
raise RuntimeError(
f'_sock_recv_into is called on a cancelled Future')

if not self._has_reader(sock):
raise RuntimeError(
f'socket {sock!r} does not have a reader '
f'in the _sock_recv_into callback')

try:
data = sock.recv_into(buf)
Expand All @@ -910,6 +928,17 @@ cdef class Loop:
Handle handle
int n

if UVLOOP_DEBUG:
if fut.cancelled():
# Shouldn't happen with _SyncSocketReaderFuture.
raise RuntimeError(
f'_sock_sendall is called on a cancelled Future')

if not self._has_writer(sock):
raise RuntimeError(
f'socket {sock!r} does not have a writer '
f'in the _sock_sendall callback')

try:
n = sock.send(data)
except (BlockingIOError, InterruptedError):
Expand Down Expand Up @@ -940,9 +969,6 @@ cdef class Loop:
self._add_writer(sock, handle)

cdef _sock_accept(self, fut, sock):
cdef:
Handle handle

try:
conn, address = sock.accept()
conn.setblocking(False)
Expand Down Expand Up @@ -2261,7 +2287,7 @@ cdef class Loop:
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")

fut = self._new_reader_future(sock)
fut = _SyncSocketReaderFuture(sock, self)
handle = new_MethodHandle3(
self,
"Loop._sock_recv",
Expand All @@ -2287,7 +2313,7 @@ cdef class Loop:
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")

fut = self._new_reader_future(sock)
fut = _SyncSocketReaderFuture(sock, self)
handle = new_MethodHandle3(
self,
"Loop._sock_recv_into",
Expand Down Expand Up @@ -2338,7 +2364,7 @@ cdef class Loop:
data = memoryview(data)
data = data[n:]

fut = self._new_writer_future(sock)
fut = _SyncSocketWriterFuture(sock, self)
handle = new_MethodHandle3(
self,
"Loop._sock_sendall",
Expand Down Expand Up @@ -2368,7 +2394,7 @@ cdef class Loop:
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")

fut = self._new_reader_future(sock)
fut = _SyncSocketReaderFuture(sock, self)
handle = new_MethodHandle2(
self,
"Loop._sock_accept",
Expand Down Expand Up @@ -2952,6 +2978,36 @@ cdef inline void __loop_free_buffer(Loop loop):
loop._recv_buffer_in_use = 0


class _SyncSocketReaderFuture(aio_Future):

def __init__(self, sock, loop):
aio_Future.__init__(self, loop=loop)
self.__sock = sock
self.__loop = loop

def cancel(self):
if self.__sock is not None and self.__sock.fileno() != -1:
self.__loop.remove_reader(self.__sock)
self.__sock = None

aio_Future.cancel(self)


class _SyncSocketWriterFuture(aio_Future):

def __init__(self, sock, loop):
aio_Future.__init__(self, loop=loop)
self.__sock = sock
self.__loop = loop

def cancel(self):
if self.__sock is not None and self.__sock.fileno() != -1:
self.__loop.remove_writer(self.__sock)
self.__sock = None

aio_Future.cancel(self)


include "cbhandles.pyx"
include "pseudosock.pyx"

Expand Down

0 comments on commit cb0a65a

Please sign in to comment.