Skip to content

Commit

Permalink
Prefer SSL connections by default
Browse files Browse the repository at this point in the history
Switch the default SSL mode from 'disabled' to 'prefer'.  This matches
libpq's behavior and is a sensible thing to do.

Fixes: #654
  • Loading branch information
elprans committed Nov 26, 2020
1 parent 690048d commit 8b39a33
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 36 deletions.
17 changes: 5 additions & 12 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None:
ssl = os.getenv('PGSSLMODE')

if ssl is None:
ssl = 'prefer'

# ssl_is_advisory is only allowed to come from the sslmode parameter.
ssl_is_advisory = None
if isinstance(ssl, str):
Expand Down Expand Up @@ -435,14 +438,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if sslmode <= SSLMODES['require']:
ssl.verify_mode = ssl_module.CERT_NONE
ssl_is_advisory = sslmode <= SSLMODES['prefer']

if ssl:
for addr in addrs:
if isinstance(addr, str):
# UNIX socket
raise exceptions.InterfaceError(
'`ssl` parameter can only be enabled for TCP addresses, '
'got a UNIX socket path: {!r}'.format(addr))
elif ssl is True:
ssl = ssl_module.create_default_context()

if server_settings is not None and (
not isinstance(server_settings, dict) or
Expand Down Expand Up @@ -542,9 +539,6 @@ def connection_lost(self, exc):
async def _create_ssl_connection(protocol_factory, host, port, *,
loop, ssl_context, ssl_is_advisory=False):

if ssl_context is True:
ssl_context = ssl_module.create_default_context()

tr, pr = await loop.create_connection(
lambda: TLSUpgradeProto(loop, host, port,
ssl_context, ssl_is_advisory),
Expand Down Expand Up @@ -625,7 +619,6 @@ async def _connect_addr(

if isinstance(addr, str):
# UNIX socket
assert not params.ssl
connector = loop.create_unix_connection(proto_factory, addr)
elif params.ssl:
connector = _create_ssl_connection(
Expand Down
26 changes: 25 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,7 +1864,28 @@ async def connect(dsn=None, *,
Pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
require an SSL connection. If ``True``, a default SSL context
returned by `ssl.create_default_context() <create_default_context_>`_
will be used.
will be used. The value can also be one of the following strings:
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
if SSL connection fails
- ``'allow'`` - currently equivalent to ``'prefer'``
- ``'require'`` - only try an SSL connection. Certificate
verifiction errors are ignored
- ``'verify-ca'`` - only try an SSL connection, and verify
that the server certificate is issued by a trusted certificate
authority (CA)
- ``'verify-full'`` - only try an SSL connection, verify
that the server certificate is issued by a trusted CA and
that the requested server host name matches that in the
certificate.
The default is ``'prefer'``: try an SSL connection and fallback to
non-SSL connection if that fails.
.. note::
*ssl* is ignored for Unix domain socket communication.
:param dict server_settings:
An optional dict of server runtime parameters. Refer to
Expand Down Expand Up @@ -1921,6 +1942,9 @@ async def connect(dsn=None, *,
.. versionchanged:: 0.22.0
Added the *record_class* parameter.
.. versionchanged:: 0.22.0
The *ssl* argument now defaults to ``'prefer'``.
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context:
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
Expand Down
42 changes: 19 additions & 23 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase):
'result': ([('host', 123)], {
'user': 'user',
'password': 'passw',
'database': 'testdb'})
'database': 'testdb',
'ssl': True,
'ssl_is_advisory': True})
},

{
Expand Down Expand Up @@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase):
'user': 'user3',
'password': '123123',
'database': 'abcdef',
'ssl': ssl.SSLContext,
'ssl': True,
'ssl_is_advisory': True})
},

Expand Down Expand Up @@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase):
'user': 'me',
'password': 'ask',
'database': 'db',
'ssl': ssl.SSLContext,
'ssl': True,
'ssl_is_advisory': False})
},

Expand Down Expand Up @@ -617,7 +619,7 @@ def run_testcase(self, testcase):
password = testcase.get('password')
passfile = testcase.get('passfile')
database = testcase.get('database')
ssl = testcase.get('ssl')
sslmode = testcase.get('ssl')
server_settings = testcase.get('server_settings')

expected = testcase.get('result')
Expand All @@ -640,21 +642,25 @@ def run_testcase(self, testcase):

addrs, params = connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
passfile=passfile, database=database, ssl=ssl,
passfile=passfile, database=database, ssl=sslmode,
connect_timeout=None, server_settings=server_settings)

params = {k: v for k, v in params._asdict().items()
if v is not None}
params = {
k: v for k, v in params._asdict().items() if v is not None
}

if isinstance(params.get('ssl'), ssl.SSLContext):
params['ssl'] = True

result = (addrs, params)

if expected is not None:
for k, v in expected[1].items():
# If `expected` contains a type, allow that to "match" any
# instance of that type tyat `result` may contain. We need
# this because different SSLContexts don't compare equal.
if isinstance(v, type) and isinstance(result[1].get(k), v):
result[1][k] = v
if 'ssl' not in expected[1]:
# Avoid the hassle of specifying the default SSL mode
# unless explicitly tested for.
params.pop('ssl', None)
params.pop('ssl_is_advisory', None)

self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))

def test_test_connect_params_environ(self):
Expand Down Expand Up @@ -1063,16 +1069,6 @@ async def verify_fails(sslmode):
await verify_fails('verify-ca')
await verify_fails('verify-full')

async def test_connection_ssl_unix(self):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ssl_context.load_verify_locations(SSL_CA_CERT_FILE)

with self.assertRaisesRegex(asyncpg.InterfaceError,
'can only be enabled for TCP addresses'):
await self.connect(
host='/tmp',
ssl=ssl_context)

async def test_connection_implicit_host(self):
conn_spec = self.get_connection_spec()
con = await asyncpg.connect(
Expand Down

0 comments on commit 8b39a33

Please sign in to comment.