Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-79156: Add start_tls() method to streams API #91453

Merged
merged 2 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions Doc/library/asyncio-stream.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,24 @@ StreamWriter
be resumed. When there is nothing to wait for, the :meth:`drain`
returns immediately.

.. coroutinemethod:: start_tls(sslcontext, \*, server_hostname=None, \
ssl_handshake_timeout=None)

Upgrade an existing stream-based connection to TLS.

Parameters:

* *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.

* *server_hostname*: sets or overrides the host name that the target
server's certificate will be matched against.

* *ssl_handshake_timeout* is the time in seconds to wait for the TLS
handshake to complete before aborting the connection. ``60.0`` seconds
if ``None`` (default).

.. versionadded:: 3.8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be 3.11 isn't it ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, missed it. Good catch, @kumaraditya303. @arhadthedev, you could fix this while working on shutdown_tls.


.. method:: is_closing()

Return ``True`` if the stream is closed or in the process of
Expand Down
4 changes: 4 additions & 0 deletions Doc/whatsnew/3.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ asyncio
:meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.
(Contributed by Alex Grönholm in :issue:`46805`.)

* Add :meth:`~asyncio.streams.StreamWriter.start_tls` method for upgrading
existing stream-based connections to TLS. (Contributed by Ian Good in
:issue:`34975`.)

fractions
---------

Expand Down
21 changes: 21 additions & 0 deletions Lib/asyncio/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@ def _stream_reader(self):
return None
return self._stream_reader_wr()

def _replace_writer(self, writer):
loop = self._loop
transport = writer.transport
self._stream_writer = writer
self._transport = transport
self._over_ssl = transport.get_extra_info('sslcontext') is not None

def connection_made(self, transport):
if self._reject_connection:
context = {
Expand Down Expand Up @@ -371,6 +378,20 @@ async def drain(self):
await sleep(0)
await self._protocol._drain_helper()

async def start_tls(self, sslcontext, *,
server_hostname=None,
ssl_handshake_timeout=None):
"""Upgrade an existing stream-based connection to TLS."""
server_side = self._protocol._client_connected_cb is not None
protocol = self._protocol
await self.drain()
new_transport = await self._loop.start_tls( # type: ignore
self._transport, protocol, sslcontext,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
self._transport = new_transport
protocol._replace_writer(self)


class StreamReader:

Expand Down
63 changes: 63 additions & 0 deletions Lib/test/test_asyncio/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,69 @@ async def client(path):

self.assertEqual(messages, [])

@unittest.skipIf(ssl is None, 'No ssl module')
def test_start_tls(self):

class MyServer:

def __init__(self, loop):
self.server = None
self.loop = loop

async def handle_client(self, client_reader, client_writer):
data1 = await client_reader.readline()
client_writer.write(data1)
await client_writer.drain()
assert client_writer.get_extra_info('sslcontext') is None
await client_writer.start_tls(
test_utils.simple_server_sslcontext())
assert client_writer.get_extra_info('sslcontext') is not None
data2 = await client_reader.readline()
client_writer.write(data2)
await client_writer.drain()
client_writer.close()
await client_writer.wait_closed()

def start(self):
sock = socket.create_server(('127.0.0.1', 0))
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client,
sock=sock))
return sock.getsockname()

def stop(self):
if self.server is not None:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.server = None

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
writer.write(b"hello world 1!\n")
await writer.drain()
msgback1 = await reader.readline()
assert writer.get_extra_info('sslcontext') is None
await writer.start_tls(test_utils.simple_client_sslcontext())
assert writer.get_extra_info('sslcontext') is not None
writer.write(b"hello world 2!\n")
await writer.drain()
msgback2 = await reader.readline()
writer.close()
await writer.wait_closed()
return msgback1, msgback2

messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

server = MyServer(self.loop)
addr = server.start()
msg1, msg2 = self.loop.run_until_complete(client(addr))
server.stop()

self.assertEqual(messages, [])
self.assertEqual(msg1, b"hello world 1!\n")
self.assertEqual(msg2, b"hello world 2!\n")

@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Adds a ``start_tls()`` method to :class:`~asyncio.streams.StreamWriter`,
which upgrades the connection with TLS using the given
:class:`~ssl.SSLContext`.