diff --git a/Doc/library/asyncio-stream.rst b/Doc/library/asyncio-stream.rst index ba534f9903fb49..72355d356f2052 100644 --- a/Doc/library/asyncio-stream.rst +++ b/Doc/library/asyncio-stream.rst @@ -295,6 +295,24 @@ StreamWriter be resumed. When there is nothing to wait for, the :meth:`drain` returns immediately. + .. coroutinemethod:: start_tls(sslcontext, \*, server_hostname=None, \ + ssl_handshake_timeout=None) + + Upgrade an existing stream-based connection to TLS. + + Parameters: + + * *sslcontext*: a configured instance of :class:`~ssl.SSLContext`. + + * *server_hostname*: sets or overrides the host name that the target + server's certificate will be matched against. + + * *ssl_handshake_timeout* is the time in seconds to wait for the TLS + handshake to complete before aborting the connection. ``60.0`` seconds + if ``None`` (default). + + .. versionadded:: 3.8 + .. method:: is_closing() Return ``True`` if the stream is closed or in the process of diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst index df0b0a7fbebec9..3fa64df1b8bfad 100644 --- a/Doc/whatsnew/3.11.rst +++ b/Doc/whatsnew/3.11.rst @@ -246,6 +246,10 @@ asyncio :meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`. (Contributed by Alex Grönholm in :issue:`46805`.) +* Add :meth:`~asyncio.streams.StreamWriter.start_tls` method for upgrading + existing stream-based connections to TLS. (Contributed by Ian Good in + :issue:`34975`.) + fractions --------- diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 080d8a62cde1e2..a568c4e4b295f0 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -217,6 +217,13 @@ def _stream_reader(self): return None return self._stream_reader_wr() + def _replace_writer(self, writer): + loop = self._loop + transport = writer.transport + self._stream_writer = writer + self._transport = transport + self._over_ssl = transport.get_extra_info('sslcontext') is not None + def connection_made(self, transport): if self._reject_connection: context = { @@ -371,6 +378,20 @@ async def drain(self): await sleep(0) await self._protocol._drain_helper() + async def start_tls(self, sslcontext, *, + server_hostname=None, + ssl_handshake_timeout=None): + """Upgrade an existing stream-based connection to TLS.""" + server_side = self._protocol._client_connected_cb is not None + protocol = self._protocol + await self.drain() + new_transport = await self._loop.start_tls( # type: ignore + self._transport, protocol, sslcontext, + server_side=server_side, server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) + self._transport = new_transport + protocol._replace_writer(self) + class StreamReader: diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 227b2279e172c8..a7d17894e1c526 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -706,6 +706,69 @@ async def client(path): self.assertEqual(messages, []) + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_tls(self): + + class MyServer: + + def __init__(self, loop): + self.server = None + self.loop = loop + + async def handle_client(self, client_reader, client_writer): + data1 = await client_reader.readline() + client_writer.write(data1) + await client_writer.drain() + assert client_writer.get_extra_info('sslcontext') is None + await client_writer.start_tls( + test_utils.simple_server_sslcontext()) + assert client_writer.get_extra_info('sslcontext') is not None + data2 = await client_reader.readline() + client_writer.write(data2) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() + + def start(self): + sock = socket.create_server(('127.0.0.1', 0)) + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client, + sock=sock)) + return sock.getsockname() + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + async def client(addr): + reader, writer = await asyncio.open_connection(*addr) + writer.write(b"hello world 1!\n") + await writer.drain() + msgback1 = await reader.readline() + assert writer.get_extra_info('sslcontext') is None + await writer.start_tls(test_utils.simple_client_sslcontext()) + assert writer.get_extra_info('sslcontext') is not None + writer.write(b"hello world 2!\n") + await writer.drain() + msgback2 = await reader.readline() + writer.close() + await writer.wait_closed() + return msgback1, msgback2 + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + server = MyServer(self.loop) + addr = server.start() + msg1, msg2 = self.loop.run_until_complete(client(addr)) + server.stop() + + self.assertEqual(messages, []) + self.assertEqual(msg1, b"hello world 1!\n") + self.assertEqual(msg2, b"hello world 2!\n") + @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example diff --git a/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst b/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst new file mode 100644 index 00000000000000..1576269da99ee0 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst @@ -0,0 +1,3 @@ +Adds a ``start_tls()`` method to :class:`~asyncio.streams.StreamWriter`, +which upgrades the connection with TLS using the given +:class:`~ssl.SSLContext`.