Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-34971: add support for TLS sessions from asyncio #9840

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ def _make_ssl_transport(
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
call_connection_made=True):
call_connection_made=True,
ssl_session=None):
"""Create SSL transport."""
raise NotImplementedError

Expand Down Expand Up @@ -866,7 +867,8 @@ async def create_connection(
*, ssl=None, family=0,
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_session=None):
"""Connect to a TCP server.

Create a streaming transport connection to a given Internet host and
Expand Down Expand Up @@ -901,6 +903,9 @@ async def create_connection(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_session is not None and not ssl:
raise ValueError('ssl_session is only meaningful with ssl')

if host is not None or port is not None:
if sock is not None:
raise ValueError(
Expand Down Expand Up @@ -984,7 +989,8 @@ async def create_connection(

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_session=ssl_session)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
Expand All @@ -996,7 +1002,8 @@ async def create_connection(
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_session=None):

sock.setblocking(False)

Expand All @@ -1007,7 +1014,8 @@ async def _create_connection_transport(
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_session=ssl_session)
else:
transport = self._make_socket_transport(sock, protocol, waiter)

Expand Down
6 changes: 4 additions & 2 deletions Lib/asyncio/proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,13 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_session=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_session=ssl_session)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
Expand Down
6 changes: 4 additions & 2 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_session=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_session=ssl_session)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
Expand Down
15 changes: 10 additions & 5 deletions Lib/asyncio/sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class _SSLPipe(object):

max_size = 256 * 1024 # Buffer size passed to read()

def __init__(self, context, server_side, server_hostname=None):
def __init__(self, context, server_side, server_hostname=None, session=None):
"""
The *context* argument specifies the ssl.SSLContext to use.

Expand All @@ -67,6 +67,7 @@ def __init__(self, context, server_side, server_hostname=None):
self._context = context
self._server_side = server_side
self._server_hostname = server_hostname
self._session = session
self._state = _UNWRAPPED
self._incoming = ssl.MemoryBIO()
self._outgoing = ssl.MemoryBIO()
Expand Down Expand Up @@ -117,7 +118,8 @@ def do_handshake(self, callback=None):
self._sslobj = self._context.wrap_bio(
self._incoming, self._outgoing,
server_side=self._server_side,
server_hostname=self._server_hostname)
server_hostname=self._server_hostname,
session=self._session)
self._state = _DO_HANDSHAKE
self._handshake_cb = callback
ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
Expand Down Expand Up @@ -412,7 +414,8 @@ class SSLProtocol(protocols.Protocol):
def __init__(self, loop, app_protocol, sslcontext, waiter,
server_side=False, server_hostname=None,
call_connection_made=True,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_session=None):
if ssl is None:
raise RuntimeError('stdlib ssl module not available')

Expand All @@ -433,9 +436,10 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
else:
self._server_hostname = None
self._sslcontext = sslcontext
self._ssl_session = ssl_session
# SSL-specific extra info. More info are set when the handshake
# completes.
self._extra = dict(sslcontext=sslcontext)
self._extra = dict(sslcontext=sslcontext, ssl_session=ssl_session)

# App data write buffering
self._write_backlog = collections.deque()
Expand Down Expand Up @@ -478,7 +482,8 @@ def connection_made(self, transport):
self._transport = transport
self._sslpipe = _SSLPipe(self._sslcontext,
self._server_side,
self._server_hostname)
self._server_hostname,
self._ssl_session)
self._start_handshake()

def connection_lost(self, exc):
Expand Down
27 changes: 21 additions & 6 deletions Lib/test/test_asyncio/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,44 +1418,51 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY
handshake_timeout = object()
session = object()
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_session=session)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_session=session)
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_session=session)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_session=session)
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_session=session)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_session=session)

def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
Expand Down Expand Up @@ -1486,6 +1493,14 @@ def test_create_connection_ssl_timeout_for_plain_socket(self):
'ssl_handshake_timeout is only meaningful with ssl'):
self.loop.run_until_complete(coro)

def test_create_connection_ssl_session_for_plain_socket(self):
coro = self.loop.create_connection(
MyProto, 'example.com', 80, ssl_session=object())
with self.assertRaisesRegex(
ValueError,
'ssl_session is only meaningful with ssl'):
self.loop.run_until_complete(coro)

def test_create_server_empty_host(self):
# if host is empty string use None instead
host = object()
Expand Down
63 changes: 63 additions & 0 deletions Lib/test/test_asyncio/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from asyncio import proactor_events
from asyncio import selector_events
from test.test_asyncio import utils as test_utils
from test.ssl_servers import make_https_server
from test import support


Expand Down Expand Up @@ -614,6 +615,68 @@ def test_create_ssl_connection(self):
self._test_create_ssl_connection(httpd, create_connection,
peername=httpd.address)

@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_ssl_connection_with_session(self):
server_context = test_utils.simple_server_sslcontext()
server = make_https_server(self, context=server_context)

client_context = test_utils.simple_client_sslcontext()
# TODO: sessions aren't compatible with TLSv1.3 yet
client_context.options |= ssl.OP_NO_TLSv1_3

def new_conn(*, session=None):
create_connection = functools.partial(
self.loop.create_connection,
lambda: MyProto(loop=self.loop),
'localhost', server.port)
conn_fut = create_connection(ssl=client_context, ssl_session=session)
tr, pr = self.loop.run_until_complete(conn_fut)
self.loop.run_until_complete(pr.done)
sslobj = tr.get_extra_info('ssl_object')
stats = {
'session': sslobj.session,
'session_reused': sslobj.session_reused,
}
tr.close()
return stats

# first connection without session
stats = new_conn()
session = stats['session']
self.assertTrue(session.id)
self.assertGreater(session.time, 0)
self.assertGreater(session.timeout, 0)
self.assertTrue(session.has_ticket)
if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
self.assertGreater(session.ticket_lifetime_hint, 0)
self.assertFalse(stats['session_reused'])

# reuse session
stats = new_conn(session=session)
self.assertTrue(stats['session_reused'])
session2 = stats['session']
self.assertEqual(session2.id, session.id)
self.assertEqual(session2, session)
self.assertIsNot(session2, session)
self.assertGreaterEqual(session2.time, session.time)
self.assertGreaterEqual(session2.timeout, session.timeout)

# another one without session
stats = new_conn()
self.assertFalse(stats['session_reused'])
session3 = stats['session']
self.assertNotEqual(session3.id, session.id)
self.assertNotEqual(session3, session)

# reuse session again
stats = new_conn(session=session)
self.assertTrue(stats['session_reused'])
session4 = stats['session']
self.assertEqual(session4.id, session.id)
self.assertEqual(session4, session)
self.assertGreaterEqual(session4.time, session.time)
self.assertGreaterEqual(session4.timeout, session.timeout)

@support.skip_unless_bind_unix_socket
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_ssl_unix_connection(self):
Expand Down