diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 338c0899..4b4d8a23 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -428,6 +428,32 @@ async def _introspect_types(self, typeoids, timeout): return await self.__execute( self._intro_query, (list(typeoids),), 0, timeout) + async def _introspect_type(self, typename, schema): + if ( + schema == 'pg_catalog' + and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP + ): + typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()] + rows = await self._execute( + introspection.TYPE_BY_OID, + [typeoid], + limit=0, + timeout=None, + ) + if rows: + typeinfo = rows[0] + else: + typeinfo = None + else: + typeinfo = await self.fetchrow( + introspection.TYPE_BY_NAME, typename, schema) + + if not typeinfo: + raise ValueError( + 'unknown type: {}.{}'.format(schema, typename)) + + return typeinfo + def cursor( self, query, @@ -1110,12 +1136,7 @@ async def set_type_codec(self, typename, *, ``format``. """ self._check_open() - - typeinfo = await self.fetchrow( - introspection.TYPE_BY_NAME, typename, schema) - if not typeinfo: - raise ValueError('unknown type: {}.{}'.format(schema, typename)) - + typeinfo = await self._introspect_type(typename, schema) if not introspection.is_scalar_type(typeinfo): raise ValueError( 'cannot use custom codec on non-scalar type {}.{}'.format( @@ -1142,15 +1163,9 @@ async def reset_type_codec(self, typename, *, schema='public'): .. versionadded:: 0.12.0 """ - typeinfo = await self.fetchrow( - introspection.TYPE_BY_NAME, typename, schema) - if not typeinfo: - raise ValueError('unknown type: {}.{}'.format(schema, typename)) - - oid = typeinfo['oid'] - + typeinfo = await self._introspect_type(typename, schema) self._protocol.get_settings().remove_python_codec( - oid, typename, schema) + typeinfo['oid'], typename, schema) # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() @@ -1191,13 +1206,7 @@ async def set_builtin_type_codec(self, typename, *, core data type. Added the *format* keyword argument. """ self._check_open() - - typeinfo = await self.fetchrow( - introspection.TYPE_BY_NAME, typename, schema) - if not typeinfo: - raise exceptions.InterfaceError( - 'unknown type: {}.{}'.format(schema, typename)) - + typeinfo = await self._introspect_type(typename, schema) if not introspection.is_scalar_type(typeinfo): raise exceptions.InterfaceError( 'cannot alias non-scalar type {}.{}'.format( diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py index 201f4341..4854e712 100644 --- a/asyncpg/introspection.py +++ b/asyncpg/introspection.py @@ -147,6 +147,18 @@ ''' +TYPE_BY_OID = '''\ +SELECT + t.oid, + t.typelem AS elemtype, + t.typtype AS kind +FROM + pg_catalog.pg_type AS t +WHERE + t.oid = $1 +''' + + # 'b' for a base type, 'd' for a domain, 'e' for enum. SCALAR_TYPE_KINDS = (b'b', b'd', b'e') diff --git a/asyncpg/protocol/__init__.py b/asyncpg/protocol/__init__.py index e872e2fa..8b3e06a0 100644 --- a/asyncpg/protocol/__init__.py +++ b/asyncpg/protocol/__init__.py @@ -4,5 +4,6 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +# flake8: NOQA -from .protocol import Protocol, Record, NO_TIMEOUT # NOQA +from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP diff --git a/asyncpg/protocol/pgtypes.pxi b/asyncpg/protocol/pgtypes.pxi index 14db69df..1be40fb2 100644 --- a/asyncpg/protocol/pgtypes.pxi +++ b/asyncpg/protocol/pgtypes.pxi @@ -216,5 +216,23 @@ BUILTIN_TYPE_NAME_MAP['double precision'] = \ BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \ BUILTIN_TYPE_NAME_MAP['timestamptz'] +BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \ + BUILTIN_TYPE_NAME_MAP['timestamp'] + BUILTIN_TYPE_NAME_MAP['time with timezone'] = \ BUILTIN_TYPE_NAME_MAP['timetz'] + +BUILTIN_TYPE_NAME_MAP['time without timezone'] = \ + BUILTIN_TYPE_NAME_MAP['time'] + +BUILTIN_TYPE_NAME_MAP['char'] = \ + BUILTIN_TYPE_NAME_MAP['bpchar'] + +BUILTIN_TYPE_NAME_MAP['character'] = \ + BUILTIN_TYPE_NAME_MAP['bpchar'] + +BUILTIN_TYPE_NAME_MAP['character varying'] = \ + BUILTIN_TYPE_NAME_MAP['varchar'] + +BUILTIN_TYPE_NAME_MAP['bit varying'] = \ + BUILTIN_TYPE_NAME_MAP['varbit'] diff --git a/tests/test_codecs.py b/tests/test_codecs.py index abd3c668..9b9c52b3 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -1255,6 +1255,39 @@ async def test_custom_codec_on_domain(self): finally: await self.con.execute('DROP DOMAIN custom_codec_t') + async def test_custom_codec_on_stdsql_types(self): + types = [ + 'smallint', + 'int', + 'integer', + 'bigint', + 'decimal', + 'real', + 'double precision', + 'timestamp with timezone', + 'time with timezone', + 'timestamp without timezone', + 'time without timezone', + 'char', + 'character', + 'character varying', + 'bit varying', + 'CHARACTER VARYING' + ] + + for t in types: + with self.subTest(type=t): + try: + await self.con.set_type_codec( + t, + schema='pg_catalog', + encoder=str, + decoder=str, + format='text' + ) + finally: + await self.con.reset_type_codec(t, schema='pg_catalog') + async def test_custom_codec_on_enum(self): """Test encoding/decoding using a custom codec on an enum.""" await self.con.execute('''