From 537c8c9f4850899f0618d1a031641dfd3ed6986c Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Wed, 29 Mar 2017 11:18:45 -0400 Subject: [PATCH] Shield Pool.release() from task cancellation. Use asyncio.shield() to guarantee that task cancellation does not prevent the connection from being returned to the pool properly. Fixes: #97. --- asyncpg/_testbase.py | 20 ++++++++++++++++++-- asyncpg/connection.py | 19 ++++++++++++++----- asyncpg/pool.py | 23 ++++++++++++++++------- tests/test_pool.py | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 14 deletions(-) diff --git a/asyncpg/_testbase.py b/asyncpg/_testbase.py index daf9cdf1..71ea3053 100644 --- a/asyncpg/_testbase.py +++ b/asyncpg/_testbase.py @@ -128,6 +128,22 @@ def _shutdown_cluster(cluster): cluster.destroy() +def create_pool(dsn=None, *, + min_size=10, + max_size=10, + max_queries=50000, + setup=None, + init=None, + loop=None, + pool_class=pg_pool.Pool, + **connect_kwargs): + return pool_class( + dsn, + min_size=min_size, max_size=max_size, + max_queries=max_queries, loop=loop, setup=setup, init=init, + **connect_kwargs) + + class ClusterTestCase(TestCase): @classmethod def setUpClass(cls): @@ -136,10 +152,10 @@ def setUpClass(cls): 'log_connections': 'on' }) - def create_pool(self, **kwargs): + def create_pool(self, pool_class=pg_pool.Pool, **kwargs): conn_spec = self.cluster.get_connection_spec() conn_spec.update(kwargs) - return pg_pool.create_pool(loop=self.loop, **conn_spec) + return create_pool(loop=self.loop, pool_class=pool_class, **conn_spec) @classmethod def start_cluster(cls, ClusterCls, *, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 2d3a47cf..b5159956 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -526,6 +526,7 @@ async def connect(dsn=None, *, timeout=60, statement_cache_size=100, command_timeout=None, + connection_class=Connection, **opts): """A coroutine to establish a connection to a PostgreSQL server. @@ -558,12 +559,16 @@ async def connect(dsn=None, *, :param float timeout: connection timeout in seconds. + :param int statement_cache_size: the size of prepared statement LRU cache. + :param float command_timeout: the default timeout for operations on this connection (the default is no timeout). - :param int statement_cache_size: the size of prepared statement LRU cache. + :param builtins.type connection_class: A class used to represent + the connection. + Defaults to :class:`~asyncpg.connection.Connection`. - :return: A :class:`~asyncpg.connection.Connection` instance. + :return: A *connection_class* instance. Example: @@ -577,6 +582,10 @@ async def connect(dsn=None, *, ... print(types) >>> asyncio.get_event_loop().run_until_complete(run()) [