From 7df8b86d412ef1866e576ea3414b79d0a6125891 Mon Sep 17 00:00:00 2001 From: Ian Good Date: Sat, 4 May 2019 15:23:14 +0000 Subject: [PATCH 1/2] bpo-34975: Add start_tls() method to streams API The existing event loop `start_tls()` method is not sufficient for connections using the streams API. The existing StreamReader works because the new transport passes received data to the original protocol. The StreamWriter must then write data to the new transport, and the StreamReaderProtocol must be updated to close the new transport correctly. The new StreamWriter `start_tls()` updates itself and the reader protocol to the new SSL transport. --- Doc/library/asyncio-stream.rst | 18 +++++ Doc/whatsnew/3.11.rst | 4 ++ Lib/asyncio/streams.py | 21 ++++++ Lib/test/test_asyncio/test_streams.py | 66 +++++++++++++++++++ .../2019-05-06-23-36-34.bpo-34975.eb49jr.rst | 3 + 5 files changed, 112 insertions(+) create mode 100644 Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst diff --git a/Doc/library/asyncio-stream.rst b/Doc/library/asyncio-stream.rst index ba534f9903fb49..72355d356f2052 100644 --- a/Doc/library/asyncio-stream.rst +++ b/Doc/library/asyncio-stream.rst @@ -295,6 +295,24 @@ StreamWriter be resumed. When there is nothing to wait for, the :meth:`drain` returns immediately. + .. coroutinemethod:: start_tls(sslcontext, \*, server_hostname=None, \ + ssl_handshake_timeout=None) + + Upgrade an existing stream-based connection to TLS. + + Parameters: + + * *sslcontext*: a configured instance of :class:`~ssl.SSLContext`. + + * *server_hostname*: sets or overrides the host name that the target + server's certificate will be matched against. + + * *ssl_handshake_timeout* is the time in seconds to wait for the TLS + handshake to complete before aborting the connection. ``60.0`` seconds + if ``None`` (default). + + .. versionadded:: 3.8 + .. method:: is_closing() Return ``True`` if the stream is closed or in the process of diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst index df0b0a7fbebec9..3fa64df1b8bfad 100644 --- a/Doc/whatsnew/3.11.rst +++ b/Doc/whatsnew/3.11.rst @@ -246,6 +246,10 @@ asyncio :meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`. (Contributed by Alex Grönholm in :issue:`46805`.) +* Add :meth:`~asyncio.streams.StreamWriter.start_tls` method for upgrading + existing stream-based connections to TLS. (Contributed by Ian Good in + :issue:`34975`.) + fractions --------- diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 080d8a62cde1e2..a568c4e4b295f0 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -217,6 +217,13 @@ def _stream_reader(self): return None return self._stream_reader_wr() + def _replace_writer(self, writer): + loop = self._loop + transport = writer.transport + self._stream_writer = writer + self._transport = transport + self._over_ssl = transport.get_extra_info('sslcontext') is not None + def connection_made(self, transport): if self._reject_connection: context = { @@ -371,6 +378,20 @@ async def drain(self): await sleep(0) await self._protocol._drain_helper() + async def start_tls(self, sslcontext, *, + server_hostname=None, + ssl_handshake_timeout=None): + """Upgrade an existing stream-based connection to TLS.""" + server_side = self._protocol._client_connected_cb is not None + protocol = self._protocol + await self.drain() + new_transport = await self._loop.start_tls( # type: ignore + self._transport, protocol, sslcontext, + server_side=server_side, server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) + self._transport = new_transport + protocol._replace_writer(self) + class StreamReader: diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 227b2279e172c8..f1b0acb3450e3c 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -706,6 +706,72 @@ async def client(path): self.assertEqual(messages, []) + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_tls(self): + + class MyServer: + + def __init__(self, loop): + self.server = None + self.loop = loop + + async def handle_client(self, client_reader, client_writer): + data1 = await client_reader.readline() + client_writer.write(data1) + await client_writer.drain() + assert client_writer.get_extra_info('sslcontext') is None + await client_writer.start_tls( + test_utils.simple_server_sslcontext()) + assert client_writer.get_extra_info('sslcontext') is not None + data2 = await client_reader.readline() + client_writer.write(data2) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() + + def start(self): + sock = socket.create_server(('127.0.0.1', 0)) + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client, + sock=sock, + loop=self.loop)) + return sock.getsockname() + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, loop=self.loop) + writer.write(b"hello world 1!\n") + await writer.drain() + msgback1 = await reader.readline() + assert writer.get_extra_info('sslcontext') is None + await writer.start_tls(test_utils.simple_client_sslcontext()) + assert writer.get_extra_info('sslcontext') is not None + writer.write(b"hello world 2!\n") + await writer.drain() + msgback2 = await reader.readline() + writer.close() + await writer.wait_closed() + return msgback1, msgback2 + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + server = MyServer(self.loop) + addr = server.start() + msg1, msg2 = self.loop.run_until_complete( + asyncio.Task(client(addr), loop=self.loop)) + server.stop() + + self.assertEqual(messages, []) + self.assertEqual(msg1, b"hello world 1!\n") + self.assertEqual(msg2, b"hello world 2!\n") + @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example diff --git a/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst b/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst new file mode 100644 index 00000000000000..1576269da99ee0 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst @@ -0,0 +1,3 @@ +Adds a ``start_tls()`` method to :class:`~asyncio.streams.StreamWriter`, +which upgrades the connection with TLS using the given +:class:`~ssl.SSLContext`. From bfe3a0c47185017b58868f2abf85a1489c89bb41 Mon Sep 17 00:00:00 2001 From: Oleg Iarygin Date: Tue, 12 Apr 2022 08:06:20 +0300 Subject: [PATCH 2/2] Use a current loop inherited from run_until_complete like other tests do --- Lib/test/test_asyncio/test_streams.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index f1b0acb3450e3c..a7d17894e1c526 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -733,8 +733,7 @@ def start(self): sock = socket.create_server(('127.0.0.1', 0)) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, - sock=sock, - loop=self.loop)) + sock=sock)) return sock.getsockname() def stop(self): @@ -744,8 +743,7 @@ def stop(self): self.server = None async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, loop=self.loop) + reader, writer = await asyncio.open_connection(*addr) writer.write(b"hello world 1!\n") await writer.drain() msgback1 = await reader.readline() @@ -764,8 +762,7 @@ async def client(addr): server = MyServer(self.loop) addr = server.start() - msg1, msg2 = self.loop.run_until_complete( - asyncio.Task(client(addr), loop=self.loop)) + msg1, msg2 = self.loop.run_until_complete(client(addr)) server.stop() self.assertEqual(messages, [])