diff --git a/aiopg/connection.py b/aiopg/connection.py index cae0222d..9d870a02 100755 --- a/aiopg/connection.py +++ b/aiopg/connection.py @@ -30,7 +30,12 @@ import psycopg2.extras from .log import logger -from .utils import _ContextManager, create_completed_future, get_running_loop +from .utils import ( + ClosableQueue, + _ContextManager, + create_completed_future, + get_running_loop, +) TIMEOUT = 60.0 @@ -762,6 +767,7 @@ def __init__( self._writing = False self._echo = echo self._notifies = asyncio.Queue() # type: ignore + self._notifies_proxy = ClosableQueue(self._notifies) self._weakref = weakref.ref(self) self._loop.add_reader( self._fileno, self._ready, self._weakref # type: ignore @@ -806,6 +812,7 @@ def _ready(weak_self: "weakref.ref[Any]") -> None: # chain exception otherwise exc2.__cause__ = exc exc = exc2 + self.notifies.close(exc) if waiter is not None and not waiter.done(): waiter.set_exception(exc) else: @@ -1182,9 +1189,9 @@ def __del__(self) -> None: self._loop.call_exception_handler(context) @property - def notifies(self) -> asyncio.Queue: # type: ignore - """Return notification queue.""" - return self._notifies + def notifies(self) -> ClosableQueue: + """Return notification queue (an asyncio.Queue -like object).""" + return self._notifies_proxy async def _get_oids(self) -> Tuple[Any, Any]: cursor = await self.cursor() diff --git a/aiopg/utils.py b/aiopg/utils.py index 5debc5b2..3a45284d 100644 --- a/aiopg/utils.py +++ b/aiopg/utils.py @@ -122,3 +122,53 @@ async def __anext__(self) -> _TObj: finally: self._obj = None raise + + +class ClosableQueue: + """ + Proxy object for an asyncio.Queue that is "closable" + + When the ClosableQueue is closed, with an exception object as parameter, + subsequent or ongoing attempts to read from the queue will result in that + exception being result in that exception being raised. + + Note: closing a queue with exception will still allow to read any items + pending in the queue. The close exception is raised only once all items + are consumed. + """ + + def __init__(self, queue: asyncio.Queue): + self._queue = queue + self._close_exception = asyncio.Future() # type: asyncio.Future[None] + + def close(self, exception: Exception) -> None: + if not self._close_exception.done(): + self._close_exception.set_exception(exception) + + async def get(self) -> Any: + loop = get_running_loop() + get = loop.create_task(self._queue.get()) + + _, pending = await asyncio.wait( + [get, self._close_exception], + return_when=asyncio.FIRST_COMPLETED, + ) + + if get.done(): + return get.result() + get.cancel() + self._close_exception.result() + + def empty(self) -> bool: + return self._queue.empty() + + def qsize(self) -> int: + return self._queue.qsize() + + def get_nowait(self) -> Any: + try: + return self._queue.get_nowait() + except asyncio.QueueEmpty: + if self._close_exception.done(): + self._close_exception.result() + raise diff --git a/docs/core.rst b/docs/core.rst index bf8d2bf9..67c49ed1 100644 --- a/docs/core.rst +++ b/docs/core.rst @@ -289,7 +289,7 @@ Example:: .. attribute:: notifies - An :class:`asyncio.Queue` instance for received notifications. + An instance of an :class:`asyncio.Queue` subclass for received notifications. .. seealso:: :ref:`aiopg-core-notifications` @@ -983,6 +983,12 @@ Receiving part should establish listening on notification channel by `LISTEN`_ call and wait notification events from :attr:`Connection.notifies` queue. +.. note:: + + calling `await connection.notifies.get()` may raise a psycopg2 exception + if the underlying connection gets disconnected while you're waiting for + notifications. + There is usage example: .. literalinclude:: ../examples/notify.py diff --git a/examples/notify.py b/examples/notify.py index 3df75665..3d550cde 100644 --- a/examples/notify.py +++ b/examples/notify.py @@ -1,5 +1,7 @@ import asyncio +import psycopg2 + import aiopg dsn = "dbname=aiopg user=aiopg password=passwd host=127.0.0.1" @@ -19,8 +21,12 @@ async def listen(conn): async with conn.cursor() as cur: await cur.execute("LISTEN channel") while True: - msg = await conn.notifies.get() - if msg.payload == "finish": + try: + msg = await conn.notifies.get() + except psycopg2.Error as ex: + print("ERROR: ", ex) + return + if msg.payload == 'finish': return else: print("Receive <-", msg.payload) diff --git a/tests/test_connection.py b/tests/test_connection.py index 3cdcc7ca..47d06657 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -593,3 +593,38 @@ async def test_connection_on_server_restart(connect, pg_server, docker): delay *= 2 else: pytest.fail("Cannot connect to the restarted server") + + +async def test_connection_notify_on_server_restart(connect, pg_server, docker, + loop): + conn = await connect() + + async def read_notifies(): + while True: + await conn.notifies.get() + + reader = loop.create_task(read_notifies()) + await asyncio.sleep(0.1) + + docker.restart(container=pg_server['Id']) + + try: + with pytest.raises(psycopg2.OperationalError): + await asyncio.wait_for(reader, 10) + finally: + conn.close() + reader.cancel() + + # Wait for postgres to be up and running again before moving on + # so as the restart won't affect other tests + delay = 0.001 + for i in range(100): + try: + conn = await connect() + conn.close() + break + except psycopg2.Error: + time.sleep(delay) + delay *= 2 + else: + pytest.fail("Cannot connect to the restarted server") diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..1d308834 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,70 @@ +import asyncio + +import pytest + +from aiopg.utils import ClosableQueue + + +async def test_closable_queue_noclose(): + the_queue = asyncio.Queue() + queue = ClosableQueue(the_queue) + assert queue.empty() + assert queue.qsize() == 0 + + await the_queue.put(1) + assert not queue.empty() + assert queue.qsize() == 1 + v = await queue.get() + assert v == 1 + + await the_queue.put(2) + v = queue.get_nowait() + assert v == 2 + + +async def test_closable_queue_close(loop): + the_queue = asyncio.Queue() + queue = ClosableQueue(the_queue) + v1 = None + + async def read(): + nonlocal v1 + v1 = await queue.get() + await queue.get() + + reader = loop.create_task(read()) + await the_queue.put(1) + await asyncio.sleep(0.1) + assert v1 == 1 + + queue.close(RuntimeError("connection closed")) + with pytest.raises(RuntimeError) as excinfo: + await reader + assert excinfo.value.args == ("connection closed",) + + +async def test_closable_queue_close_get_nowait(loop): + the_queue = asyncio.Queue() + queue = ClosableQueue(the_queue) + + await the_queue.put(1) + queue.close(RuntimeError("connection closed")) + + # even when the queue is closed, while there are items in the queu, we + # allow reading them. + assert queue.get_nowait() == 1 + + # when there are no more items in the queue, if there is a close exception + # then it will get raises here + with pytest.raises(RuntimeError) as excinfo: + queue.get_nowait() + assert excinfo.value.args == ("connection closed",) + + +async def test_closable_queue_get_nowait_noclose(loop): + the_queue = asyncio.Queue() + queue = ClosableQueue(the_queue) + await the_queue.put(1) + assert queue.get_nowait() == 1 + with pytest.raises(asyncio.QueueEmpty): + queue.get_nowait()