From e8bb3dc0a8dac3b044d6518ae55d3f665e9e474f Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 16 Mar 2017 15:11:28 -0400 Subject: [PATCH] Add rudimentary server capability detection. Add basic server capability detection mechanism based on server version and parameters reported by the server through ParameterStatus messages. This allows altering certain asyncpg behaviour based on the connected server. Specifically, this allows asyncpg to connect to CochroachDB servers. Fixes #87. --- asyncpg/connection.py | 92 ++++++++++++++++++++++++++++------- asyncpg/protocol/settings.pyx | 3 ++ asyncpg/types.py | 2 +- 3 files changed, 78 insertions(+), 19 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 4a6e8679..2d3a47cf 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -31,7 +31,8 @@ class Connection: '_type_by_name_stmt', '_top_xact', '_uid', '_aborted', '_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close', '_addr', '_opts', '_command_timeout', '_listeners', - '_server_version', '_intro_query') + '_server_version', '_server_caps', '_intro_query', + '_reset_query') def __init__(self, protocol, transport, loop, addr, opts, *, statement_cache_size, command_timeout): @@ -55,15 +56,21 @@ def __init__(self, protocol, transport, loop, addr, opts, *, self._listeners = {} - ver_string = self._protocol.get_settings().server_version + settings = self._protocol.get_settings() + ver_string = settings.server_version self._server_version = \ serverversion.split_server_version_string(ver_string) + self._server_caps = _detect_server_capabilities( + self._server_version, settings) + if self._server_version < (9, 2): self._intro_query = introspection.INTRO_LOOKUP_TYPES_91 else: self._intro_query = introspection.INTRO_LOOKUP_TYPES + self._reset_query = None + async def add_listener(self, channel, callback): """Add a listener for Postgres notifications. @@ -107,6 +114,7 @@ def get_server_version(self): ServerVersion(major=9, minor=6, micro=1, releaselevel='final', serial=0) + .. versionadded:: 0.8.0 """ return self._server_version @@ -394,22 +402,10 @@ def terminate(self): self._protocol.abort() async def reset(self): - self._listeners = {} - - await self.execute(''' - DO $$ - BEGIN - PERFORM * FROM pg_listening_channels() LIMIT 1; - IF FOUND THEN - UNLISTEN *; - END IF; - END; - $$; - SET SESSION AUTHORIZATION DEFAULT; - RESET ALL; - CLOSE ALL; - SELECT pg_advisory_unlock_all(); - ''') + self._listeners.clear() + reset_query = self._get_reset_query() + if reset_query: + await self.execute(reset_query) def _get_unique_id(self): self._uid += 1 @@ -492,6 +488,35 @@ def _notify(self, pid, channel, payload): 'exception': ex }) + def _get_reset_query(self): + if self._reset_query is not None: + return self._reset_query + + caps = self._server_caps + + _reset_query = '' + if caps.advisory_locks: + _reset_query += 'SELECT pg_advisory_unlock_all();\n' + if caps.cursors: + _reset_query += 'CLOSE ALL;\n' + if caps.notifications and caps.plpgsql: + _reset_query += ''' + DO $$ + BEGIN + PERFORM * FROM pg_listening_channels() LIMIT 1; + IF FOUND THEN + UNLISTEN *; + END IF; + END; + $$; + ''' + if caps.sql_reset: + _reset_query += 'RESET ALL;\n' + + self._reset_query = _reset_query + + return _reset_query + async def connect(dsn=None, *, host=None, port=None, @@ -730,3 +755,34 @@ def _create_future(loop): return asyncio.Future(loop=loop) else: return create_future() + + +ServerCapabilities = collections.namedtuple( + 'ServerCapabilities', + ['advisory_locks', 'cursors', 'notifications', 'plpgsql', 'sql_reset']) +ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' + + +def _detect_server_capabilities(server_version, connection_settings): + if hasattr(connection_settings, 'crdb_version'): + # CocroachDB detected. + advisory_locks = False + cursors = False + notifications = False + plpgsql = False + sql_reset = False + else: + # Standard PostgreSQL server assumed. + advisory_locks = True + cursors = True + notifications = True + plpgsql = True + sql_reset = True + + return ServerCapabilities( + advisory_locks=advisory_locks, + cursors=cursors, + notifications=notifications, + plpgsql=plpgsql, + sql_reset=sql_reset + ) diff --git a/asyncpg/protocol/settings.pyx b/asyncpg/protocol/settings.pyx index 9360c459..ec904d9a 100644 --- a/asyncpg/protocol/settings.pyx +++ b/asyncpg/protocol/settings.pyx @@ -60,3 +60,6 @@ cdef class ConnectionSettings: raise AttributeError(name) from None return object.__getattr__(self, name) + + def __repr__(self): + return ''.format(self._settings) diff --git a/asyncpg/types.py b/asyncpg/types.py index b2a8e467..d82b08d3 100644 --- a/asyncpg/types.py +++ b/asyncpg/types.py @@ -10,7 +10,7 @@ __all__ = ( 'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon', - 'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion' + 'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion', )