Skip to content

Commit

Permalink
Add sslmode=allow support and fix =prefer retry (#720)
Browse files Browse the repository at this point in the history
We didn't really retry the connection without SSL if the first SSL
connection fails under sslmode=prefer, that led to an issue when the
server has SSL support but explicitly denies SSL connection through
pg_hba.conf. This commit adds a retry in a new connection, which
makes it easy to implement the sslmode=allow retry.

Fixes #716
  • Loading branch information
fantix authored Mar 24, 2021
1 parent 93a238c commit 075114c
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 57 deletions.
148 changes: 110 additions & 38 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import collections
import enum
import functools
import getpass
import os
Expand All @@ -28,14 +29,29 @@
from . import protocol


class SSLMode(enum.IntEnum):
disable = 0
allow = 1
prefer = 2
require = 3
verify_ca = 4
verify_full = 5

@classmethod
def parse(cls, sslmode):
if isinstance(sslmode, cls):
return sslmode
return getattr(cls, sslmode.replace('-', '_'))


_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
'user',
'password',
'database',
'ssl',
'ssl_is_advisory',
'sslmode',
'connect_timeout',
'server_settings',
])
Expand Down Expand Up @@ -402,46 +418,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None and have_tcp_addrs:
ssl = 'prefer'

# ssl_is_advisory is only allowed to come from the sslmode parameter.
ssl_is_advisory = None
if isinstance(ssl, str):
SSLMODES = {
'disable': 0,
'allow': 1,
'prefer': 2,
'require': 3,
'verify-ca': 4,
'verify-full': 5,
}
if isinstance(ssl, (str, SSLMode)):
try:
sslmode = SSLMODES[ssl]
except KeyError:
modes = ', '.join(SSLMODES.keys())
sslmode = SSLMode.parse(ssl)
except AttributeError:
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
raise exceptions.InterfaceError(
'`sslmode` parameter must be one of: {}'.format(modes))

# sslmode 'allow' is currently handled as 'prefer' because we're
# missing the "retry with SSL" behavior for 'allow', but do have the
# "retry without SSL" behavior for 'prefer'.
# Not changing 'allow' to 'prefer' here would be effectively the same
# as changing 'allow' to 'disable'.
if sslmode == SSLMODES['allow']:
sslmode = SSLMODES['prefer']

# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
if sslmode <= SSLMODES['allow']:
if sslmode < SSLMode.allow:
ssl = False
ssl_is_advisory = sslmode >= SSLMODES['allow']
else:
ssl = ssl_module.create_default_context()
ssl.check_hostname = sslmode >= SSLMODES['verify-full']
ssl.check_hostname = sslmode >= SSLMode.verify_full
ssl.verify_mode = ssl_module.CERT_REQUIRED
if sslmode <= SSLMODES['require']:
if sslmode <= SSLMode.require:
ssl.verify_mode = ssl_module.CERT_NONE
ssl_is_advisory = sslmode <= SSLMODES['prefer']
elif ssl is True:
ssl = ssl_module.create_default_context()
sslmode = SSLMode.verify_full
else:
sslmode = SSLMode.disable

if server_settings is not None and (
not isinstance(server_settings, dict) or
Expand All @@ -453,7 +452,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout,
sslmode=sslmode, connect_timeout=connect_timeout,
server_settings=server_settings)

return addrs, params
Expand Down Expand Up @@ -520,9 +519,8 @@ def data_received(self, data):
data == b'N'):
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
# since the only way to get ssl_is_advisory is from
# sslmode=prefer (or sslmode=allow). But be extra sure to
# disallow insecure connections when the ssl context asks for
# real security.
# sslmode=prefer. But be extra sure to disallow insecure
# connections when the ssl context asks for real security.
self.on_data.set_result(False)
else:
self.on_data.set_exception(
Expand Down Expand Up @@ -566,6 +564,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
new_tr = tr

pg_proto = protocol_factory()
pg_proto.is_ssl = do_ssl_upgrade
pg_proto.connection_made(new_tr)
new_tr.set_protocol(pg_proto)

Expand All @@ -584,7 +583,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
tr.close()

try:
return await conn_factory(sock=sock)
new_tr, pg_proto = await conn_factory(sock=sock)
pg_proto.is_ssl = do_ssl_upgrade
return new_tr, pg_proto
except (Exception, asyncio.CancelledError):
sock.close()
raise
Expand All @@ -605,8 +606,6 @@ async def _connect_addr(
if timeout <= 0:
raise asyncio.TimeoutError

connected = _create_future(loop)

params_input = params
if callable(params.password):
if inspect.iscoroutinefunction(params.password):
Expand All @@ -615,6 +614,49 @@ async def _connect_addr(
password = params.password()

params = params._replace(password=password)
args = (addr, loop, config, connection_class, record_class, params_input)

# prepare the params (which attempt has ssl) for the 2 attempts
if params.sslmode == SSLMode.allow:
params_retry = params
params = params._replace(ssl=None)
elif params.sslmode == SSLMode.prefer:
params_retry = params._replace(ssl=None)
else:
# skip retry if we don't have to
return await __connect_addr(params, timeout, False, *args)

# first attempt
before = time.monotonic()
try:
return await __connect_addr(params, timeout, True, *args)
except _Retry:
pass

# second attempt
timeout -= time.monotonic() - before
if timeout <= 0:
raise asyncio.TimeoutError
else:
return await __connect_addr(params_retry, timeout, False, *args)


class _Retry(Exception):
pass


async def __connect_addr(
params,
timeout,
retry,
addr,
loop,
config,
connection_class,
record_class,
params_input,
):
connected = _create_future(loop)

proto_factory = lambda: protocol.Protocol(
addr, connected, params, record_class, loop)
Expand All @@ -625,7 +667,7 @@ async def _connect_addr(
elif params.ssl:
connector = _create_ssl_connection(
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
ssl_is_advisory=params.ssl_is_advisory)
ssl_is_advisory=params.sslmode == SSLMode.prefer)
else:
connector = loop.create_connection(proto_factory, *addr)

Expand All @@ -638,6 +680,35 @@ async def _connect_addr(
if timeout <= 0:
raise asyncio.TimeoutError
await compat.wait_for(connected, timeout=timeout)
except (
exceptions.InvalidAuthorizationSpecificationError,
exceptions.ConnectionDoesNotExistError, # seen on Windows
):
tr.close()

# retry=True here is a redundant check because we don't want to
# accidentally raise the internal _Retry to the outer world
if retry and (
params.sslmode == SSLMode.allow and not pr.is_ssl or
params.sslmode == SSLMode.prefer and pr.is_ssl
):
# Trigger retry when:
# 1. First attempt with sslmode=allow, ssl=None failed
# 2. First attempt with sslmode=prefer, ssl=ctx failed while the
# server claimed to support SSL (returning "S" for SSLRequest)
# (likely because pg_hba.conf rejected the connection)
raise _Retry()

else:
# but will NOT retry if:
# 1. First attempt with sslmode=prefer failed but the server
# doesn't support SSL (returning 'N' for SSLRequest), because
# we already tried to connect without SSL thru ssl_is_advisory
# 2. Second attempt with sslmode=prefer, ssl=None failed
# 3. Second attempt with sslmode=allow, ssl=ctx failed
# 4. Any other sslmode
raise

except (Exception, asyncio.CancelledError):
tr.close()
raise
Expand Down Expand Up @@ -684,6 +755,7 @@ class CancelProto(asyncio.Protocol):

def __init__(self):
self.on_disconnect = _create_future(loop)
self.is_ssl = False

def connection_lost(self, exc):
if not self.on_disconnect.done():
Expand All @@ -692,13 +764,13 @@ def connection_lost(self, exc):
if isinstance(addr, str):
tr, pr = await loop.create_unix_connection(CancelProto, addr)
else:
if params.ssl:
if params.ssl and params.sslmode != SSLMode.allow:
tr, pr = await _create_ssl_connection(
CancelProto,
*addr,
loop=loop,
ssl_context=params.ssl,
ssl_is_advisory=params.ssl_is_advisory)
ssl_is_advisory=params.sslmode == SSLMode.prefer)
else:
tr, pr = await loop.create_connection(
CancelProto, *addr)
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,7 +1879,8 @@ async def connect(dsn=None, *,
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
if SSL connection fails
- ``'allow'`` - currently equivalent to ``'prefer'``
- ``'allow'`` - try without SSL first, then retry with SSL if the first
attempt fails.
- ``'require'`` - only try an SSL connection. Certificate
verification errors are ignored
- ``'verify-ca'`` - only try an SSL connection, and verify
Expand Down
2 changes: 2 additions & 0 deletions asyncpg/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ cdef class BaseProtocol(CoreProtocol):

readonly uint64_t queries_count

bint _is_ssl

PreparedStatementState statement

cdef get_connection(self)
Expand Down
10 changes: 10 additions & 0 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ cdef class BaseProtocol(CoreProtocol):

self.queries_count = 0

self._is_ssl = False

try:
self.create_future = loop.create_future
except AttributeError:
Expand Down Expand Up @@ -943,6 +945,14 @@ cdef class BaseProtocol(CoreProtocol):
def resume_writing(self):
self.writing_allowed.set()

@property
def is_ssl(self):
return self._is_ssl

@is_ssl.setter
def is_ssl(self, value):
self._is_ssl = value


class Timer:
def __init__(self, budget):
Expand Down
Loading

0 comments on commit 075114c

Please sign in to comment.