Skip to content

Commit

Permalink
Use loop.start_tls() to upgrade connections to SSL
Browse files Browse the repository at this point in the history
The old way of TLS upgrade (openining a connection, asking postgres
to do TLS and then duping the underlying socket) seems not to work
anymore on Windows with Python 3.8.
  • Loading branch information
1st1 committed Nov 20, 2019
1 parent d655a39 commit bdba7ce
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 80 deletions.
175 changes: 112 additions & 63 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,95 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
return addrs, params, config


class TLSUpgradeProto(asyncio.Protocol):
def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
self.on_data = _create_future(loop)
self.host = host
self.port = port
self.ssl_context = ssl_context
self.ssl_is_advisory = ssl_is_advisory

def data_received(self, data):
if data == b'S':
self.on_data.set_result(True)
elif (self.ssl_is_advisory and
self.ssl_context.verify_mode == ssl_module.CERT_NONE and
data == b'N'):
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
# since the only way to get ssl_is_advisory is from
# sslmode=prefer (or sslmode=allow). But be extra sure to
# disallow insecure connections when the ssl context asks for
# real security.
self.on_data.set_result(False)
else:
self.on_data.set_exception(
ConnectionError(
'PostgreSQL server at "{host}:{port}" '
'rejected SSL upgrade'.format(
host=self.host, port=self.port)))

def connection_lost(self, exc):
if not self.on_data.done():
if exc is None:
exc = ConnectionError('unexpected connection_lost() call')
self.on_data.set_exception(exc)


async def _create_ssl_connection(protocol_factory, host, port, *,
loop, ssl_context, ssl_is_advisory=False):

if ssl_context is True:
ssl_context = ssl_module.create_default_context()

tr, pr = await loop.create_connection(
lambda: TLSUpgradeProto(loop, host, port,
ssl_context, ssl_is_advisory),
host, port)

tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.

try:
do_ssl_upgrade = await pr.on_data
except (Exception, asyncio.CancelledError):
tr.close()
raise

if hasattr(loop, 'start_tls'):
if do_ssl_upgrade:
try:
new_tr = await loop.start_tls(
tr, pr, ssl_context, server_hostname=host)
except (Exception, asyncio.CancelledError):
tr.close()
raise
else:
new_tr = tr

pg_proto = protocol_factory()
pg_proto.connection_made(new_tr)
new_tr.set_protocol(pg_proto)

return new_tr, pg_proto
else:
conn_factory = functools.partial(
loop.create_connection, protocol_factory)

if do_ssl_upgrade:
conn_factory = functools.partial(
conn_factory, ssl=ssl_context, server_hostname=host)

sock = _get_socket(tr)
sock = sock.dup()
_set_nodelay(sock)
tr.close()

try:
return await conn_factory(sock=sock)
except (Exception, asyncio.CancelledError):
sock.close()
raise


async def _connect_addr(*, addr, loop, timeout, params, config,
connection_class):
assert loop is not None
Expand All @@ -526,8 +615,6 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
else:
connector = loop.create_connection(proto_factory, *addr)

connector = asyncio.ensure_future(connector)

before = time.monotonic()
try:
tr, pr = await asyncio.wait_for(
Expand Down Expand Up @@ -575,79 +662,41 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
raise last_error


async def _negotiate_ssl_connection(host, port, conn_factory, *, loop, ssl,
server_hostname, ssl_is_advisory=False):
# Note: ssl_is_advisory only affects behavior when the server does not
# accept SSLRequests. If the SSLRequest is accepted but either the SSL
# negotiation fails or the PostgreSQL user isn't permitted to use SSL,
# there's nothing that would attempt to reconnect with a non-SSL socket.
reader, writer = await asyncio.open_connection(host, port)

tr = writer.transport
try:
sock = _get_socket(tr)
_set_nodelay(sock)

writer.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
await writer.drain()
resp = await reader.readexactly(1)

if resp == b'S':
conn_factory = functools.partial(
conn_factory, ssl=ssl, server_hostname=server_hostname)
elif (ssl_is_advisory and
ssl.verify_mode == ssl_module.CERT_NONE and
resp == b'N'):
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
# since the only way to get ssl_is_advisory is from sslmode=prefer
# (or sslmode=allow). But be extra sure to disallow insecure
# connections when the ssl context asks for real security.
pass
else:
raise ConnectionError(
'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format(
host, port))

sock = sock.dup() # Must come before tr.close()
finally:
writer.close()
await compat.wait_closed(writer)

try:
return await conn_factory(sock=sock) # Must come after tr.close()
except (Exception, asyncio.CancelledError):
sock.close()
raise
async def _cancel(*, loop, addr, params: _ConnectionParameters,
backend_pid, backend_secret):

class CancelProto(asyncio.Protocol):

async def _create_ssl_connection(protocol_factory, host, port, *,
loop, ssl_context, ssl_is_advisory=False):
return await _negotiate_ssl_connection(
host, port,
functools.partial(loop.create_connection, protocol_factory),
loop=loop,
ssl=ssl_context,
server_hostname=host,
ssl_is_advisory=ssl_is_advisory)
def __init__(self):
self.on_disconnect = _create_future(loop)

def connection_lost(self, exc):
if not self.on_disconnect.done():
self.on_disconnect.set_result(True)

async def _open_connection(*, loop, addr, params: _ConnectionParameters):
if isinstance(addr, str):
r, w = await asyncio.open_unix_connection(addr)
tr, pr = await loop.create_unix_connection(CancelProto, addr)
else:
if params.ssl:
r, w = await _negotiate_ssl_connection(
tr, pr = await _create_ssl_connection(
CancelProto,
*addr,
asyncio.open_connection,
loop=loop,
ssl=params.ssl,
server_hostname=addr[0],
ssl_context=params.ssl,
ssl_is_advisory=params.ssl_is_advisory)
else:
r, w = await asyncio.open_connection(*addr)
_set_nodelay(_get_socket(w.transport))
tr, pr = await loop.create_connection(
CancelProto, *addr)
_set_nodelay(_get_socket(tr))

# Pack a CancelRequest message
msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

return r, w
try:
tr.write(msg)
await pr.on_disconnect
finally:
tr.close()


def _get_socket(transport):
Expand Down
22 changes: 5 additions & 17 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import collections
import collections.abc
import itertools
import struct
import sys
import time
import traceback
Expand Down Expand Up @@ -1186,24 +1185,16 @@ async def _cleanup_stmts(self):
await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT)

async def _cancel(self, waiter):
r = w = None

try:
# Open new connection to the server
r, w = await connect_utils._open_connection(
loop=self._loop, addr=self._addr, params=self._params)

# Pack CancelRequest message
msg = struct.pack('!llll', 16, 80877102,
self._protocol.backend_pid,
self._protocol.backend_secret)

w.write(msg)
await r.read() # Wait until EOF
await connect_utils._cancel(
loop=self._loop, addr=self._addr, params=self._params,
backend_pid=self._protocol.backend_pid,
backend_secret=self._protocol.backend_secret)
except ConnectionResetError as ex:
# On some systems Postgres will reset the connection
# after processing the cancellation command.
if r is None and not waiter.done():
if not waiter.done():
waiter.set_exception(ex)
except asyncio.CancelledError:
# There are two scenarios in which the cancellation
Expand All @@ -1221,9 +1212,6 @@ async def _cancel(self, waiter):
compat.current_asyncio_task(self._loop))
if not waiter.done():
waiter.set_result(None)
if w is not None:
w.close()
await compat.wait_closed(w)

def _cancel_current_command(self, waiter):
self._cancellations.add(self._loop.create_task(self._cancel(waiter)))
Expand Down

0 comments on commit bdba7ce

Please sign in to comment.