Skip to content

Commit

Permalink
pool: Track connections and prohibit using them after release.
Browse files Browse the repository at this point in the history
Connection pool now wraps all connections in `PooledConnectionProxy`
objects to raise `InterfaceError` if they are used after being
released back to the pool.  We also check if connection passed
to `pool.release` actually belong to the pool and correctly handle
multiple calls to `pool.release` with the same connection object.

`PooledConnectionProxy` transparently wraps Connection instances,
exposing all Connection public API.

`isinstance(asyncpg.connection.Connection)` is `True` for Instances
of `PooledConnectionProxy` class.
  • Loading branch information
1st1 committed Mar 29, 2017
1 parent 537c8c9 commit 3bf6103
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 22 deletions.
55 changes: 39 additions & 16 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,22 @@
import urllib.parse

from . import cursor
from . import exceptions
from . import introspection
from . import prepared_stmt
from . import protocol
from . import serverversion
from . import transaction


class Connection:
class ConnectionMeta(type):

def __instancecheck__(cls, instance):
mro = type(instance).__mro__
return Connection in mro or _ConnectionProxy in mro


class Connection(metaclass=ConnectionMeta):
"""A representation of a database session.
Connections are created by calling :func:`~asyncpg.connection.connect`.
Expand All @@ -32,7 +40,7 @@ class Connection:
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
'_addr', '_opts', '_command_timeout', '_listeners',
'_server_version', '_server_caps', '_intro_query',
'_reset_query')
'_reset_query', '_proxy')

def __init__(self, protocol, transport, loop, addr, opts, *,
statement_cache_size, command_timeout):
Expand Down Expand Up @@ -70,6 +78,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
self._intro_query = introspection.INTRO_LOOKUP_TYPES

self._reset_query = None
self._proxy = None

async def add_listener(self, channel, callback):
"""Add a listener for Postgres notifications.
Expand Down Expand Up @@ -478,9 +487,18 @@ def _notify(self, pid, channel, payload):
if channel not in self._listeners:
return

if self._proxy is None:
con_ref = self
else:
# `_proxy` is not None when the connection is a member
# of a connection pool. Which means that the user is working
# with a PooledConnectionProxy instance, and expects to see it
# (and not the actual Connection) in their event callbacks.
con_ref = self._proxy

for cb in self._listeners[channel]:
try:
cb(self, pid, channel, payload)
cb(con_ref, pid, channel, payload)
except Exception as ex:
self._loop.call_exception_handler({
'message': 'Unhandled exception in asyncpg notification '
Expand Down Expand Up @@ -517,6 +535,14 @@ def _get_reset_query(self):

return _reset_query

def _set_proxy(self, proxy):
if self._proxy is not None and proxy is not None:
# Should not happen unless there is a bug in `Pool`.
raise exceptions.InterfaceError(
'internal asyncpg error: connection is already proxied')

self._proxy = proxy


async def connect(dsn=None, *,
host=None, port=None,
Expand All @@ -526,7 +552,7 @@ async def connect(dsn=None, *,
timeout=60,
statement_cache_size=100,
command_timeout=None,
connection_class=Connection,
__connection_class__=Connection,
**opts):
"""A coroutine to establish a connection to a PostgreSQL server.
Expand Down Expand Up @@ -564,11 +590,7 @@ async def connect(dsn=None, *,
:param float command_timeout: the default timeout for operations on
this connection (the default is no timeout).
:param builtins.type connection_class: A class used to represent
the connection.
Defaults to :class:`~asyncpg.connection.Connection`.
:return: A *connection_class* instance.
:return: A :class:`~asyncpg.connection.Connection` instance.
Example:
Expand All @@ -582,10 +604,6 @@ async def connect(dsn=None, *,
... print(types)
>>> asyncio.get_event_loop().run_until_complete(run())
[<Record typname='bool' typnamespace=11 ...
.. versionadded:: 0.10.0
*connection_class* argument.
"""
if loop is None:
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -629,13 +647,18 @@ async def connect(dsn=None, *,
tr.close()
raise

con = connection_class(pr, tr, loop, addr, opts,
statement_cache_size=statement_cache_size,
command_timeout=command_timeout)
con = __connection_class__(pr, tr, loop, addr, opts,
statement_cache_size=statement_cache_size,
command_timeout=command_timeout)
pr.set_connection(con)
return con


class _ConnectionProxy:
# Base class to enable `isinstance(Connection)` check.
__slots__ = ()


def _parse_connect_params(*, dsn, host, port, user,
password, database, opts):

Expand Down
89 changes: 89 additions & 0 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,80 @@


import asyncio
import functools

from . import connection
from . import exceptions


class PooledConnectionProxyMeta(type):

def __new__(mcls, name, bases, dct, *, wrap=False):
if wrap:
def get_wrapper(methname):
meth = getattr(connection.Connection, methname)

def wrapper(self, *args, **kwargs):
return self._dispatch(meth, args, kwargs)

return wrapper

for attrname in dir(connection.Connection):
if attrname.startswith('_') or attrname in dct:
continue
wrapper = get_wrapper(attrname)
wrapper = functools.update_wrapper(
wrapper, getattr(connection.Connection, attrname))
dct[attrname] = wrapper

if '__doc__' not in dct:
dct['__doc__'] = connection.Connection.__doc__

return super().__new__(mcls, name, bases, dct)

def __init__(cls, name, bases, dct, *, wrap=False):
# Needed for Python 3.5 to handle `wrap` class keyword argument.
super().__init__(name, bases, dct)


class PooledConnectionProxy(connection._ConnectionProxy,
metaclass=PooledConnectionProxyMeta,
wrap=True):

__slots__ = ('_con', '_owner')

def __init__(self, owner: 'Pool', con: connection.Connection):
self._con = con
self._owner = owner
con._set_proxy(self)

def _unwrap(self) -> connection.Connection:
if self._con is None:
raise exceptions.InterfaceError(
'internal asyncpg error: cannot unwrap pooled connection')

con, self._con = self._con, None
con._set_proxy(None)
return con

def _dispatch(self, meth, args, kwargs):
if self._con is None:
raise exceptions.InterfaceError(
'cannot call Connection.{}(): '
'connection has been released back to the pool'.format(
meth.__name__))

return meth(self._con, *args, **kwargs)

def __repr__(self):
if self._con is None:
return '<{classname} [released] {id:#x}>'.format(
classname=self.__class__.__name__, id=id(self))
else:
return '<{classname} {con!r} {id:#x}>'.format(
classname=self.__class__.__name__, con=self._con, id=id(self))


class Pool:
"""A connection pool.
Expand Down Expand Up @@ -168,6 +237,8 @@ async def _acquire_impl(self):
else:
con = await self._queue.get()

con = PooledConnectionProxy(self, con)

if self._setup is not None:
try:
await self._setup(con)
Expand All @@ -179,6 +250,20 @@ async def _acquire_impl(self):

async def release(self, connection):
"""Release a database connection back to the pool."""

if (connection.__class__ is not PooledConnectionProxy or
connection._owner is not self):
raise exceptions.InterfaceError(
'Pool.release() received invalid connection: '
'{connection!r} is not a member of this pool'.format(
connection=connection))

if connection._con is None:
# Already released, do nothing.
return

connection = connection._unwrap()

# Use asyncio.shield() to guarantee that task cancellation
# does not prevent the connection from being returned to the
# pool properly.
Expand Down Expand Up @@ -325,6 +410,10 @@ def create_pool(dsn=None, *,
:param loop: An asyncio event loop instance. If ``None``, the default
event loop will be used.
:return: An instance of :class:`~asyncpg.pool.Pool`.
.. versionchanged:: 0.10.0
An :exc:`~asyncpg.exceptions.InterfaceError` will be raised on any
attempted operation on a released connection.
"""
return Pool(dsn,
min_size=min_size, max_size=max_size,
Expand Down
12 changes: 10 additions & 2 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import asyncpg
from asyncpg import _testbase as tb
from asyncpg.connection import _parse_connect_params
from asyncpg import connection
from asyncpg.serverversion import split_server_version_string

_system = platform.uname().system
Expand Down Expand Up @@ -355,7 +355,7 @@ def run_testcase(self, testcase):
if expected_error:
es.enter_context(self.assertRaisesRegex(*expected_error))

result = _parse_connect_params(
result = connection._parse_connect_params(
dsn=dsn, host=host, port=port, user=user, password=password,
database=database, opts=opts)

Expand Down Expand Up @@ -411,3 +411,11 @@ def test_test_connect_params_run_testcase(self):
def test_connect_params(self):
for testcase in self.TESTS:
self.run_testcase(testcase)


class TestConnection(tb.ConnectedTestCase):

async def test_connection_isinstance(self):
self.assertTrue(isinstance(self.con, connection.Connection))
self.assertTrue(isinstance(self.con, object))
self.assertFalse(isinstance(self.con, list))
Loading

0 comments on commit 3bf6103

Please sign in to comment.