diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 2a8d8f29..5b213d2c 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -3090,6 +3090,97 @@ def wrapper(sock): with self.tcp_server(run(server)) as srv: self.loop.run_until_complete(client(srv.addr)) + def test_first_data_after_wakeup(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + server_context = self._create_server_ssl_context( + self.ONLYCERT, self.ONLYKEY) + client_context = self._create_client_ssl_context() + loop = self.loop + this = self + fut = self.loop.create_future() + + def client(sock, addr): + try: + sock.connect(addr) + + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = client_context.wrap_bio(incoming, outgoing) + + # Do handshake manually so that we could collect the last piece + while True: + try: + sslobj.do_handshake() + break + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(65536)) + + # Send the first data together with the last handshake payload + sslobj.write(b'hello') + sock.send(outgoing.read()) + + while True: + try: + incoming.write(sock.recv(65536)) + self.assertEqual(sslobj.read(1024), b'hello') + break + except ssl.SSLWantReadError: + pass + + sock.close() + + except Exception as ex: + loop.call_soon_threadsafe(fut.set_exception, ex) + sock.close() + else: + loop.call_soon_threadsafe(fut.set_result, None) + + class EchoProto(asyncio.Protocol): + def connection_made(self, tr): + self.tr = tr + # manually run the coroutine, in order to avoid accidental data + coro = loop.start_tls( + tr, self, server_context, + server_side=True, + ssl_handshake_timeout=this.TIMEOUT, + ) + waiter = coro.send(None) + + def tls_started(_): + try: + coro.send(None) + except StopIteration as e: + # update self.tr to SSL transport as soon as we know it + self.tr = e.value + + waiter.add_done_callback(tls_started) + + def data_received(self, data): + # This is a dumb protocol that writes back whatever it receives + # regardless of whether self.tr is SSL or not + self.tr.write(data) + + async def run_main(): + proto = EchoProto() + + server = await self.loop.create_server( + lambda: proto, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + + with self.tcp_client(lambda sock: client(sock, addr), + timeout=self.TIMEOUT): + await asyncio.wait_for(fut, timeout=self.TIMEOUT) + proto.tr.close() + + server.close() + await server.wait_closed() + + self.loop.run_until_complete(run_main()) + class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase): pass diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index 3cc1df31..42bb7644 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -540,7 +540,19 @@ cdef class SSLProtocol: self._app_state = STATE_CON_MADE self._app_protocol.connection_made(self._get_app_transport()) self._wakeup_waiter() - self._do_read() + + # We should wakeup user code before sending the first data below. In + # case of `start_tls()`, the user can only get the SSLTransport in the + # wakeup callback, because `connection_made()` is not called again. + # We should schedule the first data later than the wakeup callback so + # that the user get a chance to e.g. check ALPN with the transport + # before having to handle the first data. + self._loop._call_soon_handle( + new_MethodHandle(self._loop, + "SSLProtocol._do_read", + self._do_read, + None, # current context is good + self)) # Shutdown flow