Skip to content

Commit

Permalink
Add rudimentary server capability detection.
Browse files Browse the repository at this point in the history
Add basic server capability detection mechanism based on server version
and parameters reported by the server through ParameterStatus messages.
This allows altering certain asyncpg behaviour based on the connected
server.

Specifically, this allows asyncpg to connect to CochroachDB servers.

Fixes #87.
  • Loading branch information
elprans authored and Elvis Pranskevichus committed Mar 17, 2017
1 parent 8d17ecc commit e8bb3dc
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 19 deletions.
92 changes: 74 additions & 18 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class Connection:
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
'_addr', '_opts', '_command_timeout', '_listeners',
'_server_version', '_intro_query')
'_server_version', '_server_caps', '_intro_query',
'_reset_query')

def __init__(self, protocol, transport, loop, addr, opts, *,
statement_cache_size, command_timeout):
Expand All @@ -55,15 +56,21 @@ def __init__(self, protocol, transport, loop, addr, opts, *,

self._listeners = {}

ver_string = self._protocol.get_settings().server_version
settings = self._protocol.get_settings()
ver_string = settings.server_version
self._server_version = \
serverversion.split_server_version_string(ver_string)

self._server_caps = _detect_server_capabilities(
self._server_version, settings)

if self._server_version < (9, 2):
self._intro_query = introspection.INTRO_LOOKUP_TYPES_91
else:
self._intro_query = introspection.INTRO_LOOKUP_TYPES

self._reset_query = None

async def add_listener(self, channel, callback):
"""Add a listener for Postgres notifications.
Expand Down Expand Up @@ -107,6 +114,7 @@ def get_server_version(self):
ServerVersion(major=9, minor=6, micro=1,
releaselevel='final', serial=0)
.. versionadded:: 0.8.0
"""
return self._server_version

Expand Down Expand Up @@ -394,22 +402,10 @@ def terminate(self):
self._protocol.abort()

async def reset(self):
self._listeners = {}

await self.execute('''
DO $$
BEGIN
PERFORM * FROM pg_listening_channels() LIMIT 1;
IF FOUND THEN
UNLISTEN *;
END IF;
END;
$$;
SET SESSION AUTHORIZATION DEFAULT;
RESET ALL;
CLOSE ALL;
SELECT pg_advisory_unlock_all();
''')
self._listeners.clear()
reset_query = self._get_reset_query()
if reset_query:
await self.execute(reset_query)

def _get_unique_id(self):
self._uid += 1
Expand Down Expand Up @@ -492,6 +488,35 @@ def _notify(self, pid, channel, payload):
'exception': ex
})

def _get_reset_query(self):
if self._reset_query is not None:
return self._reset_query

caps = self._server_caps

_reset_query = ''
if caps.advisory_locks:
_reset_query += 'SELECT pg_advisory_unlock_all();\n'
if caps.cursors:
_reset_query += 'CLOSE ALL;\n'
if caps.notifications and caps.plpgsql:
_reset_query += '''
DO $$
BEGIN
PERFORM * FROM pg_listening_channels() LIMIT 1;
IF FOUND THEN
UNLISTEN *;
END IF;
END;
$$;
'''
if caps.sql_reset:
_reset_query += 'RESET ALL;\n'

self._reset_query = _reset_query

return _reset_query


async def connect(dsn=None, *,
host=None, port=None,
Expand Down Expand Up @@ -730,3 +755,34 @@ def _create_future(loop):
return asyncio.Future(loop=loop)
else:
return create_future()


ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'cursors', 'notifications', 'plpgsql', 'sql_reset'])
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'


def _detect_server_capabilities(server_version, connection_settings):
if hasattr(connection_settings, 'crdb_version'):
# CocroachDB detected.
advisory_locks = False
cursors = False
notifications = False
plpgsql = False
sql_reset = False
else:
# Standard PostgreSQL server assumed.
advisory_locks = True
cursors = True
notifications = True
plpgsql = True
sql_reset = True

return ServerCapabilities(
advisory_locks=advisory_locks,
cursors=cursors,
notifications=notifications,
plpgsql=plpgsql,
sql_reset=sql_reset
)
3 changes: 3 additions & 0 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,6 @@ cdef class ConnectionSettings:
raise AttributeError(name) from None

return object.__getattr__(self, name)

def __repr__(self):
return '<ConnectionSettings {!r}>'.format(self._settings)
2 changes: 1 addition & 1 deletion asyncpg/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

__all__ = (
'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion'
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion',
)


Expand Down

0 comments on commit e8bb3dc

Please sign in to comment.