Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-44011: New asyncio ssl implementation #31275

Merged
merged 3 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ async def restore(self):
class Server(events.AbstractServer):

def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
ssl_handshake_timeout):
ssl_handshake_timeout, ssl_shutdown_timeout=None):
self._loop = loop
self._sockets = sockets
self._active_count = 0
Expand All @@ -278,6 +278,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
self._backlog = backlog
self._ssl_context = ssl_context
self._ssl_handshake_timeout = ssl_handshake_timeout
self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None

Expand Down Expand Up @@ -309,7 +310,8 @@ def _start_serving(self):
sock.listen(self._backlog)
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
self, self._backlog, self._ssl_handshake_timeout)
self, self._backlog, self._ssl_handshake_timeout,
self._ssl_shutdown_timeout)

def get_loop(self):
return self._loop
Expand Down Expand Up @@ -463,6 +465,7 @@ def _make_ssl_transport(
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
call_connection_made=True):
"""Create SSL transport."""
raise NotImplementedError
Expand Down Expand Up @@ -965,6 +968,7 @@ async def create_connection(
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server.

Expand Down Expand Up @@ -1000,6 +1004,10 @@ async def create_connection(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
Expand Down Expand Up @@ -1075,7 +1083,8 @@ async def create_connection(

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
Expand All @@ -1087,7 +1096,8 @@ async def create_connection(
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):

sock.setblocking(False)

Expand All @@ -1098,7 +1108,8 @@ async def _create_connection_transport(
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)

Expand Down Expand Up @@ -1189,7 +1200,8 @@ async def _sendfile_fallback(self, transp, file, offset, count):
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Upgrade transport to TLS.

Return a new transport that *protocol* should start using
Expand All @@ -1212,6 +1224,7 @@ async def start_tls(self, transport, protocol, sslcontext, *,
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)

# Pause early so that "ssl_protocol.data_received()" doesn't
Expand Down Expand Up @@ -1397,6 +1410,7 @@ async def create_server(
reuse_address=None,
reuse_port=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""Create a TCP server.

Expand All @@ -1420,6 +1434,10 @@ async def create_server(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and ssl is None:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if host is not None or port is not None:
if sock is not None:
raise ValueError(
Expand Down Expand Up @@ -1492,7 +1510,8 @@ async def create_server(
sock.setblocking(False)

server = Server(self, sockets, protocol_factory,
ssl, backlog, ssl_handshake_timeout)
ssl, backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
Expand All @@ -1506,17 +1525,23 @@ async def create_server(
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}')

if ssl_handshake_timeout is not None and not ssl:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
Expand Down
7 changes: 7 additions & 0 deletions Lib/asyncio/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@
# The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0

# Number of seconds to wait for SSL shutdown to complete
# The default timeout mimics lingering_time
SSL_SHUTDOWN_TIMEOUT = 30.0

# Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections.
SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256

FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB
FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB

# The enum should be here to break circular dependencies between
# base_events and sslproto
class _SendfileMode(enum.Enum):
Expand Down
19 changes: 16 additions & 3 deletions Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ async def create_connection(
flags=0, sock=None, local_addr=None,
server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
raise NotImplementedError

Expand All @@ -312,6 +313,7 @@ async def create_server(
flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a TCP server bound to host and port.

Expand Down Expand Up @@ -352,6 +354,10 @@ async def create_server(
will wait for completion of the SSL handshake before aborting the
connection. Default is 60s.

ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for completion of the SSL shutdown procedure
before aborting the connection. Default is 30s.

start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
Expand All @@ -370,7 +376,8 @@ async def sendfile(self, transport, file, offset=0, count=None,
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Upgrade a transport to TLS.

Return a new transport that *protocol* should start using
Expand All @@ -382,13 +389,15 @@ async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
raise NotImplementedError

async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a UNIX Domain Socket server.

Expand All @@ -410,6 +419,9 @@ async def create_unix_server(
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 60s).

ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for the SSL shutdown to finish (defaults to 30s).

start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
Expand All @@ -420,7 +432,8 @@ async def create_unix_server(
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Handle an accepted connection.

This is used by servers that accept connections outside of
Expand Down
12 changes: 8 additions & 4 deletions Lib/asyncio/proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,11 +642,13 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
Expand Down Expand Up @@ -812,7 +814,8 @@ def _write_to_self(self):

def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):

def loop(f=None):
try:
Expand All @@ -826,7 +829,8 @@ def loop(f=None):
self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
self._make_socket_transport(
conn, protocol,
Expand Down
31 changes: 20 additions & 11 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,15 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout
)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
Expand Down Expand Up @@ -146,15 +150,17 @@ def _write_to_self(self):

def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog,
ssl_handshake_timeout)
ssl_handshake_timeout, ssl_shutdown_timeout)

def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
Expand Down Expand Up @@ -185,20 +191,22 @@ def _accept_connection(
self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving,
protocol_factory, sock, sslcontext, server,
backlog, ssl_handshake_timeout)
backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout)
ssl_handshake_timeout, ssl_shutdown_timeout)
self.create_task(accept)

async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
protocol = None
transport = None
try:
Expand All @@ -208,7 +216,8 @@ async def _accept_connection2(
transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
Expand Down
Loading