Skip to content

Commit

Permalink
Shield Pool.release() from task cancellation.
Browse files Browse the repository at this point in the history
Use asyncio.shield() to guarantee that task cancellation
does not prevent the connection from being returned to the
pool properly.

Fixes: #97.
  • Loading branch information
elprans committed Mar 29, 2017
1 parent d42608f commit 537c8c9
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 14 deletions.
20 changes: 18 additions & 2 deletions asyncpg/_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,22 @@ def _shutdown_cluster(cluster):
cluster.destroy()


def create_pool(dsn=None, *,
min_size=10,
max_size=10,
max_queries=50000,
setup=None,
init=None,
loop=None,
pool_class=pg_pool.Pool,
**connect_kwargs):
return pool_class(
dsn,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
**connect_kwargs)


class ClusterTestCase(TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -136,10 +152,10 @@ def setUpClass(cls):
'log_connections': 'on'
})

def create_pool(self, **kwargs):
def create_pool(self, pool_class=pg_pool.Pool, **kwargs):
conn_spec = self.cluster.get_connection_spec()
conn_spec.update(kwargs)
return pg_pool.create_pool(loop=self.loop, **conn_spec)
return create_pool(loop=self.loop, pool_class=pool_class, **conn_spec)

@classmethod
def start_cluster(cls, ClusterCls, *,
Expand Down
19 changes: 14 additions & 5 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ async def connect(dsn=None, *,
timeout=60,
statement_cache_size=100,
command_timeout=None,
connection_class=Connection,
**opts):
"""A coroutine to establish a connection to a PostgreSQL server.
Expand Down Expand Up @@ -558,12 +559,16 @@ async def connect(dsn=None, *,
:param float timeout: connection timeout in seconds.
:param int statement_cache_size: the size of prepared statement LRU cache.
:param float command_timeout: the default timeout for operations on
this connection (the default is no timeout).
:param int statement_cache_size: the size of prepared statement LRU cache.
:param builtins.type connection_class: A class used to represent
the connection.
Defaults to :class:`~asyncpg.connection.Connection`.
:return: A :class:`~asyncpg.connection.Connection` instance.
:return: A *connection_class* instance.
Example:
Expand All @@ -577,6 +582,10 @@ async def connect(dsn=None, *,
... print(types)
>>> asyncio.get_event_loop().run_until_complete(run())
[<Record typname='bool' typnamespace=11 ...
.. versionadded:: 0.10.0
*connection_class* argument.
"""
if loop is None:
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -620,9 +629,9 @@ async def connect(dsn=None, *,
tr.close()
raise

con = Connection(pr, tr, loop, addr, opts,
statement_cache_size=statement_cache_size,
command_timeout=command_timeout)
con = connection_class(pr, tr, loop, addr, opts,
statement_cache_size=statement_cache_size,
command_timeout=command_timeout)
pr.set_connection(con)
return con

Expand Down
23 changes: 16 additions & 7 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ def __init__(self, *connect_args,

self._closed = False

async def _connect(self, *args, **kwargs):
return await connection.connect(*args, **kwargs)

async def _new_connection(self):
if self._working_addr is None:
con = await connection.connect(*self._connect_args,
loop=self._loop,
**self._connect_kwargs)

con = await self._connect(*self._connect_args,
loop=self._loop,
**self._connect_kwargs)
self._working_addr = con._addr
self._working_opts = con._opts

Expand All @@ -86,9 +88,9 @@ async def _new_connection(self):
else:
host, port = self._working_addr

con = await connection.connect(host=host, port=port,
loop=self._loop,
**self._working_opts)
con = await self._connect(host=host, port=port,
loop=self._loop,
**self._working_opts)

if self._init is not None:
await self._init(con)
Expand Down Expand Up @@ -177,6 +179,13 @@ async def _acquire_impl(self):

async def release(self, connection):
"""Release a database connection back to the pool."""
# Use asyncio.shield() to guarantee that task cancellation
# does not prevent the connection from being returned to the
# pool properly.
return await asyncio.shield(self._release_impl(connection),
loop=self._loop)

async def _release_impl(self, connection):
self._check_init()
if connection.is_closed():
self._con_count -= 1
Expand Down
35 changes: 35 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import unittest

from asyncpg import _testbase as tb
from asyncpg import connection as pg_connection
from asyncpg import cluster as pg_cluster
from asyncpg import pool as pg_pool

Expand All @@ -24,6 +25,19 @@
POOL_NOMINAL_TIMEOUT = 0.1


class SlowResetConnection(pg_connection.Connection):
"""Connection class to simulate races with Connection.reset()."""
async def reset(self):
await asyncio.sleep(0.2, loop=self._loop)
return await super().reset()


class SlowResetConnectionPool(pg_pool.Pool):
async def _connect(self, *args, **kwargs):
return await pg_connection.connect(
*args, connection_class=SlowResetConnection, **kwargs)


class TestPool(tb.ConnectedTestCase):

async def test_pool_01(self):
Expand Down Expand Up @@ -186,6 +200,27 @@ async def worker():
self.cluster.trust_local_connections()
self.cluster.reload()

async def test_pool_handles_cancel_in_release(self):
# Use SlowResetConnectionPool to simulate
# the Task.cancel() and __aexit__ race.
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1,
pool_class=SlowResetConnectionPool)

async def worker():
async with pool.acquire():
pass

task = self.loop.create_task(worker())
# Let the worker() run.
await asyncio.sleep(0.1, loop=self.loop)
# Cancel the worker.
task.cancel()
# Wait to make sure the cleanup has completed.
await asyncio.sleep(0.4, loop=self.loop)
# Check that the connection has been returned to the pool.
self.assertEqual(pool._queue.qsize(), 1)


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
class TestHostStandby(tb.ConnectedTestCase):
Expand Down

0 comments on commit 537c8c9

Please sign in to comment.