Skip to content

Commit

Permalink
fix missing data on EOF in flushing
Browse files Browse the repository at this point in the history
* when EOF is received and data is still pending in incoming buffer,
  the data will be lost before this fix
* also removed sleep from a recent-written test
  • Loading branch information
fantix authored and 1st1 committed Oct 28, 2019
1 parent 695a520 commit 6476aad
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 71 deletions.
70 changes: 29 additions & 41 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2606,35 +2606,6 @@ def server(sock):
self.assertEqual(len(data), CHUNK * SIZE)
sock.close()

def openssl_server(sock):
conn = openssl_ssl.Connection(sslctx_openssl, sock)
conn.set_accept_state()

while True:
try:
data = conn.recv(16384)
self.assertEqual(data, b'ping')
break
except openssl_ssl.WantReadError:
pass

# use renegotiation to queue data in peer _write_backlog
conn.renegotiate()
conn.send(b'pong')

data_size = 0
while True:
try:
chunk = conn.recv(16384)
if not chunk:
break
data_size += len(chunk)
except openssl_ssl.WantReadError:
pass
except openssl_ssl.ZeroReturnError:
break
self.assertEqual(data_size, CHUNK * SIZE)

def run(meth):
def wrapper(sock):
try:
Expand All @@ -2652,12 +2623,18 @@ async def client(addr):
*addr,
ssl=client_sslctx,
server_hostname='')
sslprotocol = writer.get_extra_info('uvloop.sslproto')
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')

sslprotocol.pause_writing()
for _ in range(SIZE):
writer.write(b'x' * CHUNK)

writer.close()
sslprotocol.resume_writing()

await self.wait_closed(writer)
try:
data = await reader.read()
Expand All @@ -2669,9 +2646,6 @@ async def client(addr):
with self.tcp_server(run(server)) as srv:
self.loop.run_until_complete(client(srv.addr))

with self.tcp_server(run(openssl_server)) as srv:
self.loop.run_until_complete(client(srv.addr))

def test_remote_shutdown_receives_trailing_data(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()
Expand Down Expand Up @@ -2892,20 +2866,26 @@ async def client(addr, ctx):
self.assertIsNone(ctx())

def test_shutdown_timeout_handler_not_set(self):
if self.implementation == 'asyncio':
# asyncio cannot receive EOF after resume_reading()
raise unittest.SkipTest()

loop = self.loop
eof = asyncio.Event()
extra = None

def server(sock):
sslctx = self._create_server_ssl_context(self.ONLYCERT,
self.ONLYKEY)
sock = sslctx.wrap_socket(sock, server_side=True)
sock.send(b'hello')
assert sock.recv(1024) == b'world'
time.sleep(0.1)
sock.send(b'extra bytes' * 1)
sock.send(b'extra bytes')
# sending EOF here
sock.shutdown(socket.SHUT_WR)
loop.call_soon_threadsafe(eof.set)
# make sure we have enough time to reproduce the issue
time.sleep(0.1)
assert sock.recv(1024) == b''
sock.close()

class Protocol(asyncio.Protocol):
Expand All @@ -2917,20 +2897,28 @@ def connection_made(self, transport):
self.transport = transport

def data_received(self, data):
self.transport.write(b'world')
# pause reading would make incoming data stay in the sslobj
self.transport.pause_reading()
# resume for AIO to pass
loop.call_later(0.2, self.transport.resume_reading)
if data == b'hello':
self.transport.write(b'world')
# pause reading would make incoming data stay in the sslobj
self.transport.pause_reading()
else:
nonlocal extra
extra = data

def connection_lost(self, exc):
self.fut.set_result(None)
if exc is None:
self.fut.set_result(None)
else:
self.fut.set_exception(exc)

async def client(addr):
ctx = self._create_client_ssl_context()
tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
await eof.wait()
tr.resume_reading()
await pr.fut
tr.close()
assert extra == b'extra bytes'

with self.tcp_server(server) as srv:
loop.run_until_complete(client(srv.addr))
Expand Down
1 change: 1 addition & 0 deletions uvloop/sslproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ cdef class SSLProtocol:

bint _ssl_writing_paused
bint _app_reading_paused
bint _eof_received

size_t _incoming_high_water
size_t _incoming_low_water
Expand Down
47 changes: 17 additions & 30 deletions uvloop/sslproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ cdef class SSLProtocol:
self._incoming_high_water = 0
self._incoming_low_water = 0
self._set_read_buffer_limits()
self._eof_received = False

self._app_writing_paused = False
self._outgoing_high_water = 0
Expand Down Expand Up @@ -391,6 +392,7 @@ cdef class SSLProtocol:
will close itself. If it returns a true value, closing the
transport is up to the protocol.
"""
self._eof_received = True
try:
if self._loop.get_debug():
aio_logger.debug("%r received EOF", self)
Expand All @@ -400,9 +402,10 @@ cdef class SSLProtocol:

elif self._state == WRAPPED:
self._set_state(FLUSHING)
self._do_write()
self._set_state(SHUTDOWN)
self._do_shutdown()
if self._app_reading_paused:
return True
else:
self._do_flush()

elif self._state == FLUSHING:
self._do_write()
Expand All @@ -412,11 +415,14 @@ cdef class SSLProtocol:
elif self._state == SHUTDOWN:
self._do_shutdown()

finally:
except Exception:
self._transport.close()
raise

cdef _get_extra_info(self, name, default=None):
if name in self._extra:
if name == 'uvloop.sslproto':
return self
elif name in self._extra:
return self._extra[name]
elif self._transport is not None:
return self._transport.get_extra_info(name, default)
Expand Down Expand Up @@ -555,33 +561,14 @@ cdef class SSLProtocol:
aio_TimeoutError('SSL shutdown timed out'))

cdef _do_flush(self):
if self._write_backlog:
try:
while True:
# data is discarded when FLUSHING
chunk_size = len(self._sslobj_read(SSL_READ_MAX_SIZE))
if not chunk_size:
# close_notify
break
except ssl_SSLAgainErrors as exc:
pass
except ssl_SSLError as exc:
self._on_shutdown_complete(exc)
return

try:
self._do_write()
except Exception as exc:
self._on_shutdown_complete(exc)
return

if not self._write_backlog:
self._set_state(SHUTDOWN)
self._do_shutdown()
self._do_read()
self._set_state(SHUTDOWN)
self._do_shutdown()

cdef _do_shutdown(self):
try:
self._sslobj.unwrap()
if not self._eof_received:
self._sslobj.unwrap()
except ssl_SSLAgainErrors as exc:
self._process_outgoing()
except ssl_SSLError as exc:
Expand Down Expand Up @@ -655,7 +642,7 @@ cdef class SSLProtocol:
# Incoming flow

cdef _do_read(self):
if self._state != WRAPPED:
if self._state != WRAPPED and self._state != FLUSHING:
return
try:
if not self._app_reading_paused:
Expand Down

0 comments on commit 6476aad

Please sign in to comment.