diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 4a6e8679..2d3a47cf 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -31,7 +31,8 @@ class Connection: '_type_by_name_stmt', '_top_xact', '_uid', '_aborted', '_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close', '_addr', '_opts', '_command_timeout', '_listeners', - '_server_version', '_intro_query') + '_server_version', '_server_caps', '_intro_query', + '_reset_query') def __init__(self, protocol, transport, loop, addr, opts, *, statement_cache_size, command_timeout): @@ -55,15 +56,21 @@ def __init__(self, protocol, transport, loop, addr, opts, *, self._listeners = {} - ver_string = self._protocol.get_settings().server_version + settings = self._protocol.get_settings() + ver_string = settings.server_version self._server_version = \ serverversion.split_server_version_string(ver_string) + self._server_caps = _detect_server_capabilities( + self._server_version, settings) + if self._server_version < (9, 2): self._intro_query = introspection.INTRO_LOOKUP_TYPES_91 else: self._intro_query = introspection.INTRO_LOOKUP_TYPES + self._reset_query = None + async def add_listener(self, channel, callback): """Add a listener for Postgres notifications. @@ -107,6 +114,7 @@ def get_server_version(self): ServerVersion(major=9, minor=6, micro=1, releaselevel='final', serial=0) + .. versionadded:: 0.8.0 """ return self._server_version @@ -394,22 +402,10 @@ def terminate(self): self._protocol.abort() async def reset(self): - self._listeners = {} - - await self.execute(''' - DO $$ - BEGIN - PERFORM * FROM pg_listening_channels() LIMIT 1; - IF FOUND THEN - UNLISTEN *; - END IF; - END; - $$; - SET SESSION AUTHORIZATION DEFAULT; - RESET ALL; - CLOSE ALL; - SELECT pg_advisory_unlock_all(); - ''') + self._listeners.clear() + reset_query = self._get_reset_query() + if reset_query: + await self.execute(reset_query) def _get_unique_id(self): self._uid += 1 @@ -492,6 +488,35 @@ def _notify(self, pid, channel, payload): 'exception': ex }) + def _get_reset_query(self): + if self._reset_query is not None: + return self._reset_query + + caps = self._server_caps + + _reset_query = '' + if caps.advisory_locks: + _reset_query += 'SELECT pg_advisory_unlock_all();\n' + if caps.cursors: + _reset_query += 'CLOSE ALL;\n' + if caps.notifications and caps.plpgsql: + _reset_query += ''' + DO $$ + BEGIN + PERFORM * FROM pg_listening_channels() LIMIT 1; + IF FOUND THEN + UNLISTEN *; + END IF; + END; + $$; + ''' + if caps.sql_reset: + _reset_query += 'RESET ALL;\n' + + self._reset_query = _reset_query + + return _reset_query + async def connect(dsn=None, *, host=None, port=None, @@ -730,3 +755,34 @@ def _create_future(loop): return asyncio.Future(loop=loop) else: return create_future() + + +ServerCapabilities = collections.namedtuple( + 'ServerCapabilities', + ['advisory_locks', 'cursors', 'notifications', 'plpgsql', 'sql_reset']) +ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' + + +def _detect_server_capabilities(server_version, connection_settings): + if hasattr(connection_settings, 'crdb_version'): + # CocroachDB detected. + advisory_locks = False + cursors = False + notifications = False + plpgsql = False + sql_reset = False + else: + # Standard PostgreSQL server assumed. + advisory_locks = True + cursors = True + notifications = True + plpgsql = True + sql_reset = True + + return ServerCapabilities( + advisory_locks=advisory_locks, + cursors=cursors, + notifications=notifications, + plpgsql=plpgsql, + sql_reset=sql_reset + ) diff --git a/asyncpg/protocol/settings.pyx b/asyncpg/protocol/settings.pyx index 9360c459..ec904d9a 100644 --- a/asyncpg/protocol/settings.pyx +++ b/asyncpg/protocol/settings.pyx @@ -60,3 +60,6 @@ cdef class ConnectionSettings: raise AttributeError(name) from None return object.__getattr__(self, name) + + def __repr__(self): + return ''.format(self._settings) diff --git a/asyncpg/types.py b/asyncpg/types.py index b2a8e467..d82b08d3 100644 --- a/asyncpg/types.py +++ b/asyncpg/types.py @@ -10,7 +10,7 @@ __all__ = ( 'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon', - 'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion' + 'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion', )