Invalidate statement cache on schema changes affecting statement result.
PostgreSQL will raise an exception when it detects that the result type of the
query has changed from when the statement was prepared.  This may happen, for
example, after an ALTER TABLE or SET search_path.

When this happens, and there is no transaction running, we can simply
re-prepare the statement and try again.

If the transaction _is_ running, this error will put it into an error state,
and we have no choice but to raise an exception.  The original error is
somewhat cryptic, so we raise a custom InvalidCachedStatementError with the
original server exception as context.

In either case we clear the statement cache for this connection and all other
connections of the pool this connection belongs to (if any).
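
As an illustration only (not part of this commit), application code that needs its queries to stay inside an explicit transaction can catch the new exception and retry the whole transaction. A minimal sketch, assuming `con` is an established asyncpg connection and reusing the `tab1` table from the tests below; the `fetch_with_retry` helper name is hypothetical:

import asyncpg

async def fetch_with_retry(con):
    # The failed attempt clears the statement cache (for the whole pool,
    # if this connection belongs to one), so a single retry re-prepares
    # the statement against the new schema.
    for attempt in range(2):
        try:
            async with con.transaction():
                return await con.fetchrow('SELECT * FROM tab1')
        except asyncpg.InvalidCachedStatementError:
            if attempt:
                raise

Outside an explicit transaction no such handling is needed, since the retry happens transparently in Connection._do_execute (see the asyncpg/connection.py changes below).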

See #72 and #76 for discussion.

Fixes: #72.
elprans committed Mar 30, 2017
1 parent 0dd8fb6 commit 749d857
Showing 9 changed files with 274 additions and 60 deletions.
71 changes: 58 additions & 13 deletions asyncpg/connection.py
@@ -187,9 +187,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
if not args:
return await self._protocol.query(query, timeout)

stmt = await self._get_statement(query, timeout)
_, status, _ = await self._protocol.bind_execute(stmt, args, '', 0,
True, timeout)
_, status, _ = await self._do_execute(query, args, 0, timeout, True)
return status.decode()

async def executemany(self, command: str, args, timeout: float=None):
@@ -283,10 +281,7 @@ async def fetch(self, query, *args, timeout=None) -> list:
:return list: A list of :class:`Record` instances.
"""
stmt = await self._get_statement(query, timeout)
data = await self._protocol.bind_execute(stmt, args, '', 0,
False, timeout)
return data
return await self._do_execute(query, args, 0, timeout)

async def fetchval(self, query, *args, column=0, timeout=None):
"""Run a query and return a value in the first row.
@@ -302,9 +297,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
:return: The value of the specified column of the first record.
"""
stmt = await self._get_statement(query, timeout)
data = await self._protocol.bind_execute(stmt, args, '', 1,
False, timeout)
data = await self._do_execute(query, args, 1, timeout)
if not data:
return None
return data[0][column]
@@ -318,9 +311,7 @@ async def fetchrow(self, query, *args, timeout=None):
:return: The first row as a :class:`Record` instance.
"""
stmt = await self._get_statement(query, timeout)
data = await self._protocol.bind_execute(stmt, args, '', 1,
False, timeout)
data = await self._do_execute(query, args, 1, timeout)
if not data:
return None
return data[0]
@@ -551,6 +542,60 @@ def _set_proxy(self, proxy):

self._proxy = proxy

def _drop_local_statement_cache(self):
self._stmt_cache.clear()

def _drop_global_statement_cache(self):
if self._proxy is not None:
# This connection is a member of a pool, so we delegate
# the cache drop to the pool.
pool = self._proxy._holder._pool
pool._drop_statement_cache()
else:
self._drop_local_statement_cache()

async def _do_execute(self, query, args, limit, timeout,
return_status=False):
stmt = await self._get_statement(query, timeout)

try:
result = await self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)

except exceptions.InvalidCachedStatementError as e:
# PostgreSQL will raise an exception when it detects
# that the result type of the query has changed from
# when the statement was prepared. This may happen,
# for example, after an ALTER TABLE or SET search_path.
#
# When this happens, and there is no transaction running,
# we can simply re-prepare the statement and try once
# again. We deliberately retry only once as this is
# supposed to be a rare occurrence.
#
# If the transaction _is_ running, this error will put it
# into an error state, and we have no choice but to
# re-raise the exception.
#
# In either case we clear the statement cache for this
# connection and all other connections of the pool this
# connection belongs to (if any).
#
# See https://github.com/MagicStack/asyncpg/issues/72
# and https://github.com/MagicStack/asyncpg/issues/76
# for discussion.
#
self._drop_global_statement_cache()

if self._protocol.is_in_transaction():
raise
else:
stmt = await self._get_statement(query, timeout)
result = await self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)

return result


async def connect(dsn=None, *,
host=None, port=None,
30 changes: 14 additions & 16 deletions asyncpg/exceptions/__init__.py
@@ -1,10 +1,3 @@
# Copyright (C) 2016-present the ayncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


# GENERATED FROM postgresql/src/backend/utils/errcodes.txt
# DO NOT MODIFY, use tools/generate_exceptions.py to update

@@ -92,6 +85,10 @@ class FeatureNotSupportedError(_base.PostgresError):
sqlstate = '0A000'


class InvalidCachedStatementError(FeatureNotSupportedError):
pass


class InvalidTransactionInitiationError(_base.PostgresError):
sqlstate = '0B000'

@@ -1025,15 +1022,16 @@ class IndexCorruptedError(InternalServerError):
'InvalidArgumentForPowerFunctionError',
'InvalidArgumentForWidthBucketFunctionError',
'InvalidAuthorizationSpecificationError',
'InvalidBinaryRepresentationError', 'InvalidCatalogNameError',
'InvalidCharacterValueForCastError', 'InvalidColumnDefinitionError',
'InvalidColumnReferenceError', 'InvalidCursorDefinitionError',
'InvalidCursorNameError', 'InvalidCursorStateError',
'InvalidDatabaseDefinitionError', 'InvalidDatetimeFormatError',
'InvalidEscapeCharacterError', 'InvalidEscapeOctetError',
'InvalidEscapeSequenceError', 'InvalidForeignKeyError',
'InvalidFunctionDefinitionError', 'InvalidGrantOperationError',
'InvalidGrantorError', 'InvalidIndicatorParameterValueError',
'InvalidBinaryRepresentationError', 'InvalidCachedStatementError',
'InvalidCatalogNameError', 'InvalidCharacterValueForCastError',
'InvalidColumnDefinitionError', 'InvalidColumnReferenceError',
'InvalidCursorDefinitionError', 'InvalidCursorNameError',
'InvalidCursorStateError', 'InvalidDatabaseDefinitionError',
'InvalidDatetimeFormatError', 'InvalidEscapeCharacterError',
'InvalidEscapeOctetError', 'InvalidEscapeSequenceError',
'InvalidForeignKeyError', 'InvalidFunctionDefinitionError',
'InvalidGrantOperationError', 'InvalidGrantorError',
'InvalidIndicatorParameterValueError',
'InvalidLocatorSpecificationError', 'InvalidNameError',
'InvalidObjectDefinitionError', 'InvalidParameterValueError',
'InvalidPasswordError', 'InvalidPreparedStatementDefinitionError',
48 changes: 41 additions & 7 deletions asyncpg/exceptions/_base.py
@@ -12,6 +11,11 @@
'InterfaceError')


def _is_asyncpg_class(cls):
modname = cls.__module__
return modname == 'asyncpg' or modname.startswith('asyncpg.')


class PostgresMessageMeta(type):
_message_map = {}
_field_map = {
@@ -40,8 +45,7 @@ def __new__(mcls, name, bases, dct):
for f in mcls._field_map.values():
setattr(cls, f, None)

if (cls.__module__ == 'asyncpg' or
cls.__module__.startswith('asyncpg.')):
if _is_asyncpg_class(cls):
mod = sys.modules[cls.__module__]
if hasattr(mod, name):
raise RuntimeError('exception class redefinition: {}'.format(
@@ -74,21 +78,51 @@ def __str__(self):
return msg

@classmethod
def new(cls, fields, query=None):
def _get_error_template(cls, fields, query):
errcode = fields.get('C')
mcls = cls.__class__
exccls = mcls.get_message_class_for_sqlstate(errcode)
mapped = {
dct = {
'query': query
}

for k, v in fields.items():
field = mcls._field_map.get(k)
if field:
mapped[field] = v
dct[field] = v

e = exccls(mapped.get('message', ''))
e.__dict__.update(mapped)
return exccls, dct

@classmethod
def new(cls, fields, query=None):
exccls, dct = cls._get_error_template(fields, query)

message = dct.get('message', '')

# PostgreSQL will raise an exception when it detects
# that the result type of the query has changed from
# when the statement was prepared.
#
# The original error is somewhat cryptic and unspecific,
# so we raise a custom subclass that is easier to handle
# and identify.
#
# Note that we specifically do not rely on the error
# message, as it is localizable.
is_icse = (
exccls.__name__ == 'FeatureNotSupportedError' and
_is_asyncpg_class(exccls) and
dct.get('server_source_function') == 'RevalidateCachedQuery'
)

if is_icse:
exceptions = sys.modules[exccls.__module__]
exccls = exceptions.InvalidCachedStatementError
message = ('cached statement plan is invalid due to a database '
'schema or configuration change')

e = exccls(message)
e.__dict__.update(dct)

return e

6 changes: 6 additions & 0 deletions asyncpg/pool.py
@@ -402,6 +402,12 @@ def _check_init(self):
if self._closed:
raise exceptions.InterfaceError('pool is closed')

def _drop_statement_cache(self):
# Drop statement cache for all connections in the pool.
for ch in self._holders:
if ch._con is not None:
ch._con._drop_local_statement_cache()

def __await__(self):
return self._async__init__().__await__()

24 changes: 10 additions & 14 deletions asyncpg/prepared_stmt.py
@@ -154,11 +154,7 @@ async def fetch(self, *args, timeout=None):
:return: A list of :class:`Record` instances.
"""
self.__check_open()
protocol = self._connection._protocol
data, status, _ = await protocol.bind_execute(
self._state, args, '', 0, True, timeout)
self._last_status = status
data = await self.__bind_execute(args, 0, timeout)
return data

async def fetchval(self, *args, column=0, timeout=None):
@@ -174,11 +170,7 @@ async def fetchval(self, *args, column=0, timeout=None):
:return: The value of the specified column of the first record.
"""
self.__check_open()
protocol = self._connection._protocol
data, status, _ = await protocol.bind_execute(
self._state, args, '', 1, True, timeout)
self._last_status = status
data = await self.__bind_execute(args, 1, timeout)
if not data:
return None
return data[0][column]
@@ -192,14 +184,18 @@ async def fetchrow(self, *args, timeout=None):
:return: The first row as a :class:`Record` instance.
"""
data = await self.__bind_execute(args, 1, timeout)
if not data:
return None
return data[0]

async def __bind_execute(self, args, limit, timeout):
self.__check_open()
protocol = self._connection._protocol
data, status, _ = await protocol.bind_execute(
self._state, args, '', 1, True, timeout)
self._state, args, '', limit, True, timeout)
self._last_status = status
if not data:
return None
return data[0]
return data

def __check_open(self):
if self._state.closed:
4 changes: 3 additions & 1 deletion asyncpg/protocol/protocol.pyx
@@ -121,7 +121,9 @@ cdef class BaseProtocol(CoreProtocol):
return self.settings

def is_in_transaction(self):
return self.xact_status == PQTRANS_INTRANS
# PQTRANS_INTRANS = idle, within transaction block
# PQTRANS_INERROR = idle, within failed transaction
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)

async def prepare(self, stmt_name, query, timeout):
if self.cancel_waiter is not None:
85 changes: 85 additions & 0 deletions tests/test_cache_invalidation.py
@@ -0,0 +1,85 @@
# Copyright (C) 2016-present the ayncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncpg
from asyncpg import _testbase as tb


class TestCacheInvalidation(tb.ConnectedTestCase):
async def test_prepare_cache_invalidation_silent(self):
await self.con.execute('CREATE TABLE tab1(a int, b int)')

try:
await self.con.execute('INSERT INTO tab1 VALUES (1, 2)')
result = await self.con.fetchrow('SELECT * FROM tab1')
self.assertEqual(result, (1, 2))

await self.con.execute(
'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text')

result = await self.con.fetchrow('SELECT * FROM tab1')
self.assertEqual(result, (1, '2'))
finally:
await self.con.execute('DROP TABLE tab1')

async def test_prepare_cache_invalidation_in_transaction(self):
await self.con.execute('CREATE TABLE tab1(a int, b int)')

try:
await self.con.execute('INSERT INTO tab1 VALUES (1, 2)')
result = await self.con.fetchrow('SELECT * FROM tab1')
self.assertEqual(result, (1, 2))

await self.con.execute(
'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text')

with self.assertRaisesRegex(asyncpg.InvalidCachedStatementError,
'cached statement plan is invalid'):
async with self.con.transaction():
result = await self.con.fetchrow('SELECT * FROM tab1')

# This is now OK,
result = await self.con.fetchrow('SELECT * FROM tab1')
self.assertEqual(result, (1, '2'))
finally:
await self.con.execute('DROP TABLE tab1')

async def test_prepare_cache_invalidation_in_pool(self):
pool = await self.create_pool(database='postgres',
min_size=2, max_size=2)

await self.con.execute('CREATE TABLE tab1(a int, b int)')

try:
await self.con.execute('INSERT INTO tab1 VALUES (1, 2)')

con1 = await pool.acquire()
con2 = await pool.acquire()

result = await con1.fetchrow('SELECT * FROM tab1')
self.assertEqual(result, (1, 2))

result = await con2.fetchrow('SELECT * FROM tab1')
self.assertEqual(result, (1, 2))

await self.con.execute(
'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text')

# con1 tries the same plan, will invalidate the cache
# for the entire pool.
result = await con1.fetchrow('SELECT * FROM tab1')
self.assertEqual(result, (1, '2'))

async with con2.transaction():
# This should work, as con1 should have invalidated
# the plan cache.
result = await con2.fetchrow('SELECT * FROM tab1')
self.assertEqual(result, (1, '2'))

finally:
await self.con.execute('DROP TABLE tab1')
await pool.close()