Skip to content

Commit

Permalink
Add support for SSL connections.
Browse files Browse the repository at this point in the history
Closes: #25.
  • Loading branch information
1st1 committed Apr 4, 2017
1 parent c550388 commit 5836a8f
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 50 deletions.
6 changes: 5 additions & 1 deletion asyncpg/_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,14 @@ def get_server_settings(cls):
'log_connections': 'on'
}

@classmethod
def setup_cluster(cls):
cls.cluster = _start_default_cluster(cls.get_server_settings())

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.cluster = _start_default_cluster(cls.get_server_settings())
cls.setup_cluster()

def create_pool(self, pool_class=pg_pool.Pool, **kwargs):
conn_spec = self.cluster.get_connection_spec()
Expand Down
153 changes: 121 additions & 32 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ class Connection(metaclass=ConnectionMeta):
'_stmt_cache', '_stmts_to_close',
'_addr', '_opts', '_command_timeout', '_listeners',
'_server_version', '_server_caps', '_intro_query',
'_reset_query', '_proxy', '_stmt_exclusive_section')
'_reset_query', '_proxy', '_stmt_exclusive_section',
'_ssl_context')

def __init__(self, protocol, transport, loop, addr, opts, *,
statement_cache_size, command_timeout,
max_cached_statement_lifetime):
max_cached_statement_lifetime, ssl_context):
self._protocol = protocol
self._transport = transport
self._loop = loop
Expand All @@ -58,6 +59,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,

self._addr = addr
self._opts = opts
self._ssl_context = ssl_context

self._stmt_cache = _StatementCache(
loop=loop,
Expand Down Expand Up @@ -521,12 +523,24 @@ async def cancel():
r, w = await asyncio.open_unix_connection(
self._addr, loop=self._loop)
else:
r, w = await asyncio.open_connection(
*self._addr, loop=self._loop)

sock = w.transport.get_extra_info('socket')
sock.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY, 1)
if self._ssl_context:
sock = await _get_ssl_ready_socket(
*self._addr, loop=self._loop)

try:
r, w = await asyncio.open_connection(
sock=sock,
loop=self._loop,
ssl=self._ssl_context,
server_hostname=self._addr[0])
except Exception:
sock.close()
raise

else:
r, w = await asyncio.open_connection(
*self._addr, loop=self._loop)
_set_nodelay(_get_socket(w.transport))

# Pack CancelRequest message
msg = struct.pack('!llll', 16, 80877102,
Expand Down Expand Up @@ -708,9 +722,10 @@ async def connect(dsn=None, *,
statement_cache_size=100,
max_cached_statement_lifetime=300,
command_timeout=None,
ssl=None,
__connection_class__=Connection,
**opts):
"""A coroutine to establish a connection to a PostgreSQL server.
r"""A coroutine to establish a connection to a PostgreSQL server.
Returns a new :class:`~asyncpg.connection.Connection` object.
Expand Down Expand Up @@ -761,6 +776,12 @@ async def connect(dsn=None, *,
the default timeout for operations on this connection
(the default is no timeout).
:param ssl:
pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
require an SSL connection. If ``True``, a default SSL context
returned by `ssl.create_default_context() <create_default_context_>`_
will be used.
:return: A :class:`~asyncpg.connection.Connection` instance.
Example:
Expand All @@ -778,42 +799,51 @@ async def connect(dsn=None, *,
.. versionchanged:: 0.10.0
Added ``max_cached_statement_use_count`` parameter.
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context: https://docs.python.org/3/library/ssl.html#\
ssl.create_default_context
"""
if loop is None:
loop = asyncio.get_event_loop()

host, port, opts = _parse_connect_params(
addrs, opts = _parse_connect_params(
dsn=dsn, host=host, port=port, user=user, password=password,
database=database, opts=opts)

last_ex = None
if ssl:
for addr in addrs:
if isinstance(addr, str):
# UNIX socket
raise exceptions.InterfaceError(
'`ssl` parameter can only be enabled for TCP addresses, '
'got a UNIX socket path: {!r}'.format(addr))

last_error = None
addr = None
for h in host:
for addr in addrs:
connected = _create_future(loop)
unix = h.startswith('/')

if unix:
# UNIX socket name
addr = h
if '.s.PGSQL.' not in addr:
addr = os.path.join(addr, '.s.PGSQL.{}'.format(port))
conn = loop.create_unix_connection(
lambda: protocol.Protocol(addr, connected, opts, loop),
addr)
proto_factory = lambda: protocol.Protocol(addr, connected, opts, loop)

if isinstance(addr, str):
# UNIX socket
assert ssl is None
connector = loop.create_unix_connection(proto_factory, addr)
elif ssl:
connector = _create_ssl_connection(
proto_factory, *addr, loop=loop, ssl_context=ssl)
else:
addr = (h, port)
conn = loop.create_connection(
lambda: protocol.Protocol(addr, connected, opts, loop),
h, port)
connector = loop.create_connection(proto_factory, *addr)

try:
tr, pr = await asyncio.wait_for(conn, timeout=timeout, loop=loop)
except (OSError, asyncio.TimeoutError) as ex:
last_ex = ex
tr, pr = await asyncio.wait_for(
connector, timeout=timeout, loop=loop)
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
last_error = ex
else:
break
else:
raise last_ex
raise last_error

try:
await connected
Expand All @@ -825,12 +855,60 @@ async def connect(dsn=None, *,
pr, tr, loop, addr, opts,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
command_timeout=command_timeout)
command_timeout=command_timeout, ssl_context=ssl)

pr.set_connection(con)
return con


async def _get_ssl_ready_socket(host, port, *, loop):
reader, writer = await asyncio.open_connection(host, port, loop=loop)

tr = writer.transport
try:
sock = _get_socket(tr)
_set_nodelay(sock)

writer.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
await writer.drain()
resp = await reader.readexactly(1)

if resp == b'S':
return sock.dup()
else:
raise ConnectionError(
'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format(
host, port))
finally:
tr.close()


async def _create_ssl_connection(protocol_factory, host, port, *,
loop, ssl_context):
sock = await _get_ssl_ready_socket(host, port, loop=loop)
try:
return await loop.create_connection(
protocol_factory, sock=sock, ssl=ssl_context,
server_hostname=host)
except Exception:
sock.close()
raise


def _get_socket(transport):
sock = transport.get_extra_info('socket')
if sock is None:
# Shouldn't happen with any asyncio-complaint event loop.
raise ConnectionError(
'could not get the socket for transport {!r}'.format(transport))
return sock


def _set_nodelay(sock):
if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)


class _StatementCacheEntry:

__slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')
Expand Down Expand Up @@ -1116,7 +1194,18 @@ def _parse_connect_params(*, dsn, host, port, user,
'invalid connection parameter {!r}: {!r} (str expected)'
.format(param, opts[param]))

return host, port, opts
addrs = []
for h in host:
if h.startswith('/'):
# UNIX socket name
if '.s.PGSQL.' not in h:
h = os.path.join(h, '.s.PGSQL.{}'.format(port))
addrs.append(h)
else:
# TCP host/port
addrs.append((h, port))

return addrs, opts


def _create_future(loop):
Expand Down
Loading

0 comments on commit 5836a8f

Please sign in to comment.