Skip to content

Commit

Permalink
Add support for coroutine functions as listener callbacks
Browse files Browse the repository at this point in the history
The `Connection.add_listener()`, `Connection.add_log_listener()` and
`Connection.add_termination_listener()` now allow coroutine functions as
callbacks.

Fixes: #567.
  • Loading branch information
elprans committed Aug 9, 2021
1 parent 67ebbc9 commit 81cdff9
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 45 deletions.
102 changes: 57 additions & 45 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
import collections.abc
import functools
import itertools
import inspect
import os
import sys
import time
import traceback
import typing
import warnings
import weakref

Expand Down Expand Up @@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
:param str channel: Channel to listen on.
:param callable callback:
A callable receiving the following arguments:
A callable or a coroutine function receiving the following
arguments:
**connection**: a Connection the callback is registered with;
**pid**: PID of the Postgres server that sent the notification;
**channel**: name of the channel the notification was sent to;
**payload**: the payload.
.. versionchanged:: 0.24.0
The ``callback`` argument may be a coroutine function.
"""
self._check_open()
if channel not in self._listeners:
await self.fetch('LISTEN {}'.format(utils._quote_ident(channel)))
self._listeners[channel] = set()
self._listeners[channel].add(callback)
self._listeners[channel].add(_Callback.from_callable(callback))

async def remove_listener(self, channel, callback):
"""Remove a listening callback on the specified channel."""
if self.is_closed():
return
if channel not in self._listeners:
return
if callback not in self._listeners[channel]:
cb = _Callback.from_callable(callback)
if cb not in self._listeners[channel]:
return
self._listeners[channel].remove(callback)
self._listeners[channel].remove(cb)
if not self._listeners[channel]:
del self._listeners[channel]
await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel)))
Expand All @@ -166,44 +173,51 @@ def add_log_listener(self, callback):
DEBUG, INFO, or LOG.
:param callable callback:
A callable receiving the following arguments:
A callable or a coroutine function receiving the following
arguments:
**connection**: a Connection the callback is registered with;
**message**: the `exceptions.PostgresLogMessage` message.
.. versionadded:: 0.12.0
.. versionchanged:: 0.24.0
The ``callback`` argument may be a coroutine function.
"""
if self.is_closed():
raise exceptions.InterfaceError('connection is closed')
self._log_listeners.add(callback)
self._log_listeners.add(_Callback.from_callable(callback))

def remove_log_listener(self, callback):
"""Remove a listening callback for log messages.
.. versionadded:: 0.12.0
"""
self._log_listeners.discard(callback)
self._log_listeners.discard(_Callback.from_callable(callback))

def add_termination_listener(self, callback):
"""Add a listener that will be called when the connection is closed.
:param callable callback:
A callable receiving one argument:
A callable or a coroutine function receiving one argument:
**connection**: a Connection the callback is registered with.
.. versionadded:: 0.21.0
.. versionchanged:: 0.24.0
The ``callback`` argument may be a coroutine function.
"""
self._termination_listeners.add(callback)
self._termination_listeners.add(_Callback.from_callable(callback))

def remove_termination_listener(self, callback):
"""Remove a listening callback for connection termination.
:param callable callback:
The callable that was passed to
The callable or coroutine function that was passed to
:meth:`Connection.add_termination_listener`.
.. versionadded:: 0.21.0
"""
self._termination_listeners.discard(callback)
self._termination_listeners.discard(_Callback.from_callable(callback))

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
Expand Down Expand Up @@ -1430,35 +1444,21 @@ def _process_log_message(self, fields, last_query):

con_ref = self._unwrap()
for cb in self._log_listeners:
self._loop.call_soon(
self._call_log_listener, cb, con_ref, message)

def _call_log_listener(self, cb, con_ref, message):
try:
cb(con_ref, message)
except Exception as ex:
self._loop.call_exception_handler({
'message': 'Unhandled exception in asyncpg log message '
'listener callback {!r}'.format(cb),
'exception': ex
})
if cb.is_async:
self._loop.create_task(cb.cb(con_ref, message))
else:
self._loop.call_soon(cb.cb, con_ref, message)

def _call_termination_listeners(self):
if not self._termination_listeners:
return

con_ref = self._unwrap()
for cb in self._termination_listeners:
try:
cb(con_ref)
except Exception as ex:
self._loop.call_exception_handler({
'message': (
'Unhandled exception in asyncpg connection '
'termination listener callback {!r}'.format(cb)
),
'exception': ex
})
if cb.is_async:
self._loop.create_task(cb.cb(con_ref))
else:
self._loop.call_soon(cb.cb, con_ref)

self._termination_listeners.clear()

Expand All @@ -1468,18 +1468,10 @@ def _process_notification(self, pid, channel, payload):

con_ref = self._unwrap()
for cb in self._listeners[channel]:
self._loop.call_soon(
self._call_listener, cb, con_ref, pid, channel, payload)

def _call_listener(self, cb, con_ref, pid, channel, payload):
try:
cb(con_ref, pid, channel, payload)
except Exception as ex:
self._loop.call_exception_handler({
'message': 'Unhandled exception in asyncpg notification '
'listener callback {!r}'.format(cb),
'exception': ex
})
if cb.is_async:
self._loop.create_task(cb.cb(con_ref, pid, channel, payload))
else:
self._loop.call_soon(cb.cb, con_ref, pid, channel, payload)

def _unwrap(self):
if self._proxy is None:
Expand Down Expand Up @@ -2154,6 +2146,26 @@ def _maybe_cleanup(self):
self._on_remove(old_entry._statement)


class _Callback(typing.NamedTuple):

cb: typing.Callable[..., None]
is_async: bool

@classmethod
def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback':
if inspect.iscoroutinefunction(cb):
is_async = True
elif callable(cb):
is_async = False
else:
raise exceptions.InterfaceError(
'expected a callable or an `async def` function,'
'got {!r}'.format(cb)
)

return cls(cb, is_async)


class _Atomic:
__slots__ = ('_acquired',)

Expand Down
39 changes: 39 additions & 0 deletions tests/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,20 @@ async def test_listen_01(self):

q1 = asyncio.Queue()
q2 = asyncio.Queue()
q3 = asyncio.Queue()

def listener1(*args):
q1.put_nowait(args)

def listener2(*args):
q2.put_nowait(args)

async def async_listener3(*args):
q3.put_nowait(args)

await con.add_listener('test', listener1)
await con.add_listener('test', listener2)
await con.add_listener('test', async_listener3)

await con.execute("NOTIFY test, 'aaaa'")

Expand All @@ -41,8 +46,12 @@ def listener2(*args):
self.assertEqual(
await q2.get(),
(con, con.get_server_pid(), 'test', 'aaaa'))
self.assertEqual(
await q3.get(),
(con, con.get_server_pid(), 'test', 'aaaa'))

await con.remove_listener('test', listener2)
await con.remove_listener('test', async_listener3)

await con.execute("NOTIFY test, 'aaaa'")

Expand Down Expand Up @@ -117,13 +126,20 @@ class TestLogListeners(tb.ConnectedTestCase):
})
async def test_log_listener_01(self):
q1 = asyncio.Queue()
q2 = asyncio.Queue()

def notice_callb(con, message):
# Message fields depend on PG version, hide some values.
dct = message.as_dict()
del dct['server_source_line']
q1.put_nowait((con, type(message), dct))

async def async_notice_callb(con, message):
# Message fields depend on PG version, hide some values.
dct = message.as_dict()
del dct['server_source_line']
q2.put_nowait((con, type(message), dct))

async def raise_notice():
await self.con.execute(
"""DO $$
Expand All @@ -140,6 +156,7 @@ async def raise_warning():

con = self.con
con.add_log_listener(notice_callb)
con.add_log_listener(async_notice_callb)

expected_msg = {
'context': 'PL/pgSQL function inline_code_block line 2 at RAISE',
Expand Down Expand Up @@ -182,7 +199,21 @@ async def raise_warning():
msg,
(con, exceptions.PostgresWarning, expected_msg_warn))

msg = await q2.get()
msg[2].pop('server_source_filename', None)
self.assertEqual(
msg,
(con, exceptions.PostgresLogMessage, expected_msg_notice))

msg = await q2.get()
msg[2].pop('server_source_filename', None)
self.assertEqual(
msg,
(con, exceptions.PostgresWarning, expected_msg_warn))

con.remove_log_listener(notice_callb)
con.remove_log_listener(async_notice_callb)

await raise_notice()
self.assertTrue(q1.empty())

Expand Down Expand Up @@ -291,19 +322,26 @@ class TestConnectionTerminationListener(tb.ProxiedClusterTestCase):
async def test_connection_termination_callback_called_on_remote(self):

called = False
async_called = False

def close_cb(con):
nonlocal called
called = True

async def async_close_cb(con):
nonlocal async_called
async_called = True

con = await self.connect()
con.add_termination_listener(close_cb)
con.add_termination_listener(async_close_cb)
self.proxy.close_all_connections()
try:
await con.fetchval('SELECT 1')
except Exception:
pass
self.assertTrue(called)
self.assertTrue(async_called)

async def test_connection_termination_callback_called_on_local(self):

Expand All @@ -316,4 +354,5 @@ def close_cb(con):
con = await self.connect()
con.add_termination_listener(close_cb)
await con.close()
await asyncio.sleep(0)
self.assertTrue(called)

0 comments on commit 81cdff9

Please sign in to comment.