Skip to content

Commit

Permalink
SSL: schedule first data after waiter wakeup
Browse files Browse the repository at this point in the history
The waiter given to SSLProtocol should be woke up before the first
data callback, especially for `start_tls()` where the user protocol's
`connection_made()` won't be called and the waiter wakeup is the only
time the user have access to the new SSLTransport for the first time.
The user may want to e.g. check ALPN before handling the first data,
it's better that uvloop doesn't force the user to check this by
themselves.
  • Loading branch information
fantix committed Jul 2, 2021
1 parent c808a66 commit 2081db8
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
91 changes: 91 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3090,6 +3090,97 @@ def wrapper(sock):
with self.tcp_server(run(server)) as srv:
self.loop.run_until_complete(client(srv.addr))

def test_first_data_after_wakeup(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

server_context = self._create_server_ssl_context(
self.ONLYCERT, self.ONLYKEY)
client_context = self._create_client_ssl_context()
loop = self.loop
this = self
fut = self.loop.create_future()

def client(sock, addr):
try:
sock.connect(addr)

incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
sslobj = client_context.wrap_bio(incoming, outgoing)

# Do handshake manually so that we could collect the last piece
while True:
try:
sslobj.do_handshake()
break
except ssl.SSLWantReadError:
if outgoing.pending:
sock.send(outgoing.read())
incoming.write(sock.recv(65536))

# Send the first data together with the last handshake payload
sslobj.write(b'hello')
sock.send(outgoing.read())

while True:
try:
incoming.write(sock.recv(65536))
self.assertEqual(sslobj.read(1024), b'hello')
break
except ssl.SSLWantReadError:
pass

sock.close()

except Exception as ex:
loop.call_soon_threadsafe(fut.set_exception, ex)
sock.close()
else:
loop.call_soon_threadsafe(fut.set_result, None)

class EchoProto(asyncio.Protocol):
def connection_made(self, tr):
self.tr = tr
# manually run the coroutine, in order to avoid accidental data
coro = loop.start_tls(
tr, self, server_context,
server_side=True,
ssl_handshake_timeout=this.TIMEOUT,
)
waiter = coro.send(None)

def tls_started(_):
try:
coro.send(None)
except StopIteration as e:
# update self.tr to SSL transport as soon as we know it
self.tr = e.value

waiter.add_done_callback(tls_started)

def data_received(self, data):
# This is a dumb protocol that writes back whatever it receives
# regardless of whether self.tr is SSL or not
self.tr.write(data)

async def run_main():
proto = EchoProto()

server = await self.loop.create_server(
lambda: proto, '127.0.0.1', 0)
addr = server.sockets[0].getsockname()

with self.tcp_client(lambda sock: client(sock, addr),
timeout=self.TIMEOUT):
await asyncio.wait_for(fut, timeout=self.TIMEOUT)
proto.tr.close()

server.close()
await server.wait_closed()

self.loop.run_until_complete(run_main())


class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
pass
Expand Down
14 changes: 13 additions & 1 deletion uvloop/sslproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,19 @@ cdef class SSLProtocol:
self._app_state = STATE_CON_MADE
self._app_protocol.connection_made(self._get_app_transport())
self._wakeup_waiter()
self._do_read()

# We should wakeup user code before sending the first data below. In
# case of `start_tls()`, the user can only get the SSLTransport in the
# wakeup callback, because `connection_made()` is not called again.
# We should schedule the first data later than the wakeup callback so
# that the user get a chance to e.g. check ALPN with the transport
# before having to handle the first data.
self._loop._call_soon_handle(
new_MethodHandle(self._loop,
"SSLProtocol._do_read",
<method_t> self._do_read,
None, # current context is good
self))

# Shutdown flow

Expand Down

0 comments on commit 2081db8

Please sign in to comment.