diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 4b4d8a23..b7266471 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -342,13 +342,16 @@ async def _get_statement( *, named: bool=False, use_cache: bool=True, + ignore_custom_codec=False, record_class=None ): if record_class is None: record_class = self._protocol.get_record_class() if use_cache: - statement = self._stmt_cache.get((query, record_class)) + statement = self._stmt_cache.get( + (query, record_class, ignore_custom_codec) + ) if statement is not None: return statement @@ -371,6 +374,7 @@ async def _get_statement( query, timeout, record_class=record_class, + ignore_custom_codec=ignore_custom_codec, ) need_reprepare = False types_with_missing_codecs = statement._init_types() @@ -415,7 +419,8 @@ async def _get_statement( ) if use_cache: - self._stmt_cache.put((query, record_class), statement) + self._stmt_cache.put( + (query, record_class, ignore_custom_codec), statement) # If we've just created a new statement object, check if there # are any statements for GC. @@ -426,7 +431,12 @@ async def _get_statement( async def _introspect_types(self, typeoids, timeout): return await self.__execute( - self._intro_query, (list(typeoids),), 0, timeout) + self._intro_query, + (list(typeoids),), + 0, + timeout, + ignore_custom_codec=True, + ) async def _introspect_type(self, typename, schema): if ( @@ -439,20 +449,22 @@ async def _introspect_type(self, typename, schema): [typeoid], limit=0, timeout=None, + ignore_custom_codec=True, ) - if rows: - typeinfo = rows[0] - else: - typeinfo = None else: - typeinfo = await self.fetchrow( - introspection.TYPE_BY_NAME, typename, schema) + rows = await self._execute( + introspection.TYPE_BY_NAME, + [typename, schema], + limit=1, + timeout=None, + ignore_custom_codec=True, + ) - if not typeinfo: + if not rows: raise ValueError( 'unknown type: {}.{}'.format(schema, typename)) - return typeinfo + return rows[0] def cursor( self, @@ -1325,7 +1337,9 @@ def _mark_stmts_as_closed(self): def _maybe_gc_stmt(self, stmt): if ( stmt.refs == 0 - and not self._stmt_cache.has((stmt.query, stmt.record_class)) + and not self._stmt_cache.has( + (stmt.query, stmt.record_class, stmt.ignore_custom_codec) + ) ): # If low-level `stmt` isn't referenced from any high-level # `PreparedStatement` object and is not in the `_stmt_cache`: @@ -1589,6 +1603,7 @@ async def _execute( timeout, *, return_status=False, + ignore_custom_codec=False, record_class=None ): with self._stmt_exclusive_section: @@ -1599,6 +1614,7 @@ async def _execute( timeout, return_status=return_status, record_class=record_class, + ignore_custom_codec=ignore_custom_codec, ) return result @@ -1610,6 +1626,7 @@ async def __execute( timeout, *, return_status=False, + ignore_custom_codec=False, record_class=None ): executor = lambda stmt, timeout: self._protocol.bind_execute( @@ -1620,6 +1637,7 @@ async def __execute( executor, timeout, record_class=record_class, + ignore_custom_codec=ignore_custom_codec, ) async def _executemany(self, query, args, timeout): @@ -1637,6 +1655,7 @@ async def _do_execute( timeout, retry=True, *, + ignore_custom_codec=False, record_class=None ): if timeout is None: @@ -1644,6 +1663,7 @@ async def _do_execute( query, None, record_class=record_class, + ignore_custom_codec=ignore_custom_codec, ) else: before = time.monotonic() @@ -1651,6 +1671,7 @@ async def _do_execute( query, timeout, record_class=record_class, + ignore_custom_codec=ignore_custom_codec, ) after = time.monotonic() timeout -= after - before diff --git a/asyncpg/protocol/codecs/base.pxd b/asyncpg/protocol/codecs/base.pxd index be1f0a3f..e8136f7b 100644 --- a/asyncpg/protocol/codecs/base.pxd +++ b/asyncpg/protocol/codecs/base.pxd @@ -166,5 +166,6 @@ cdef class DataCodecConfig: dict _derived_type_codecs dict _custom_type_codecs - cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format) + cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format, + bint ignore_custom_codec=*) cdef inline Codec get_any_local_codec(self, uint32_t oid) diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx index 238fa280..1c930cd0 100644 --- a/asyncpg/protocol/codecs/base.pyx +++ b/asyncpg/protocol/codecs/base.pyx @@ -692,18 +692,20 @@ cdef class DataCodecConfig: return codec - cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format): + cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format, + bint ignore_custom_codec=False): cdef Codec codec - codec = self.get_any_local_codec(oid) - if codec is not None: - if codec.format != format: - # The codec for this OID has been overridden by - # set_{builtin}_type_codec with a different format. - # We must respect that and not return a core codec. - return None - else: - return codec + if not ignore_custom_codec: + codec = self.get_any_local_codec(oid) + if codec is not None: + if codec.format != format: + # The codec for this OID has been overridden by + # set_{builtin}_type_codec with a different format. + # We must respect that and not return a core codec. + return None + else: + return codec codec = get_core_codec(oid, format) if codec is not None: diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd index 90944c1a..4427bfdc 100644 --- a/asyncpg/protocol/prepared_stmt.pxd +++ b/asyncpg/protocol/prepared_stmt.pxd @@ -12,6 +12,7 @@ cdef class PreparedStatementState: readonly bint closed readonly int refs readonly type record_class + readonly bint ignore_custom_codec list row_desc diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index 60094be6..fd9f5a26 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -16,7 +16,8 @@ cdef class PreparedStatementState: str name, str query, BaseProtocol protocol, - type record_class + type record_class, + bint ignore_custom_codec ): self.name = name self.query = query @@ -28,6 +29,7 @@ cdef class PreparedStatementState: self.closed = False self.refs = 0 self.record_class = record_class + self.ignore_custom_codec = ignore_custom_codec def _get_parameters(self): cdef Codec codec @@ -205,7 +207,8 @@ cdef class PreparedStatementState: cols_mapping[col_name] = i cols_names.append(col_name) oid = row[3] - codec = self.settings.get_data_codec(oid) + codec = self.settings.get_data_codec( + oid, ignore_custom_codec=self.ignore_custom_codec) if codec is None or not codec.has_decoder(): raise exceptions.InternalClientError( 'no decoder for OID {}'.format(oid)) @@ -230,7 +233,8 @@ cdef class PreparedStatementState: for i from 0 <= i < self.args_num: p_oid = self.parameters_desc[i] - codec = self.settings.get_data_codec(p_oid) + codec = self.settings.get_data_codec( + p_oid, ignore_custom_codec=self.ignore_custom_codec) if codec is None or not codec.has_encoder(): raise exceptions.InternalClientError( 'no encoder for OID {}'.format(p_oid)) diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 4f7ce675..a6d9ad5d 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -145,6 +145,7 @@ cdef class BaseProtocol(CoreProtocol): async def prepare(self, stmt_name, query, timeout, *, PreparedStatementState state=None, + ignore_custom_codec=False, record_class): if self.cancel_waiter is not None: await self.cancel_waiter @@ -161,7 +162,7 @@ cdef class BaseProtocol(CoreProtocol): self.last_query = query if state is None: state = PreparedStatementState( - stmt_name, query, self, record_class) + stmt_name, query, self, record_class, ignore_custom_codec) self.statement = state except Exception as ex: waiter.set_exception(ex) diff --git a/asyncpg/protocol/settings.pxd b/asyncpg/protocol/settings.pxd index 44b673c2..41131cdc 100644 --- a/asyncpg/protocol/settings.pxd +++ b/asyncpg/protocol/settings.pxd @@ -26,4 +26,5 @@ cdef class ConnectionSettings(pgproto.CodecContext): cpdef inline set_builtin_type_codec( self, typeoid, typename, typeschema, typekind, alias_to, format) cpdef inline Codec get_data_codec( - self, uint32_t oid, ServerDataFormat format=*) + self, uint32_t oid, ServerDataFormat format=*, + bint ignore_custom_codec=*) diff --git a/asyncpg/protocol/settings.pyx b/asyncpg/protocol/settings.pyx index 2ea72169..6121fce4 100644 --- a/asyncpg/protocol/settings.pyx +++ b/asyncpg/protocol/settings.pyx @@ -87,14 +87,18 @@ cdef class ConnectionSettings(pgproto.CodecContext): typekind, alias_to, _format) cpdef inline Codec get_data_codec(self, uint32_t oid, - ServerDataFormat format=PG_FORMAT_ANY): + ServerDataFormat format=PG_FORMAT_ANY, + bint ignore_custom_codec=False): if format == PG_FORMAT_ANY: - codec = self._data_codecs.get_codec(oid, PG_FORMAT_BINARY) + codec = self._data_codecs.get_codec( + oid, PG_FORMAT_BINARY, ignore_custom_codec) if codec is None: - codec = self._data_codecs.get_codec(oid, PG_FORMAT_TEXT) + codec = self._data_codecs.get_codec( + oid, PG_FORMAT_TEXT, ignore_custom_codec) return codec else: - return self._data_codecs.get_codec(oid, format) + return self._data_codecs.get_codec( + oid, format, ignore_custom_codec) def __getattr__(self, name): if not name.startswith('_'): diff --git a/tests/test_introspection.py b/tests/test_introspection.py index eb3258f9..7de4236f 100644 --- a/tests/test_introspection.py +++ b/tests/test_introspection.py @@ -43,6 +43,20 @@ def tearDownClass(cls): super().tearDownClass() + def setUp(self): + super().setUp() + self.loop.run_until_complete(self._add_custom_codec(self.con)) + + async def _add_custom_codec(self, conn): + # mess up with the codec - builtin introspection shouldn't be affected + await conn.set_type_codec( + "oid", + schema="pg_catalog", + encoder=lambda value: None, + decoder=lambda value: None, + format="text", + ) + @tb.with_connection_options(database='asyncpg_intro_test') async def test_introspection_on_large_db(self): await self.con.execute( @@ -142,6 +156,7 @@ async def test_introspection_retries_after_cache_bust(self): # query would cause introspection to retry. slow_intro_conn = await self.connect( connection_class=SlowIntrospectionConnection) + await self._add_custom_codec(slow_intro_conn) try: await self.con.execute(''' CREATE DOMAIN intro_1_t AS int;