From f2459ebf3995c968c59cd16a7a684297d08ce4f2 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 29 Mar 2022 07:27:13 -0400
Subject: [PATCH] feat(postgres): implement _get_schema_using_query

---
 ibis/backends/duckdb/__init__.py            |  5 +-
 .../duckdb/datatypes.py}                    |  0
 .../duckdb/tests/test_datatypes.py}         |  2 +-
 ibis/backends/postgres/__init__.py          | 94 +++++++++++++++++++
 ibis/backends/postgres/tests/test_client.py | 50 ++++++++++
 5 files changed, 149 insertions(+), 2 deletions(-)
 rename ibis/{expr/datatypes/duckdb.py => backends/duckdb/datatypes.py} (100%)
 rename ibis/{tests/expr/test_duckdb_datatype_parser.py => backends/duckdb/tests/test_datatypes.py} (98%)

diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py
index 8d12e2418a7e..d256bc420fda 100644
--- a/ibis/backends/duckdb/__init__.py
+++ b/ibis/backends/duckdb/__init__.py
@@ -3,10 +3,13 @@
 from __future__ import annotations
 
 from pathlib import Path
+from typing import TYPE_CHECKING
 
-import duckdb
 import sqlalchemy as sa
 
+if TYPE_CHECKING:
+    import duckdb
+
 import ibis.expr.schema as sch
 from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
 
diff --git a/ibis/expr/datatypes/duckdb.py b/ibis/backends/duckdb/datatypes.py
similarity index 100%
rename from ibis/expr/datatypes/duckdb.py
rename to ibis/backends/duckdb/datatypes.py
diff --git a/ibis/tests/expr/test_duckdb_datatype_parser.py b/ibis/backends/duckdb/tests/test_datatypes.py
similarity index 98%
rename from ibis/tests/expr/test_duckdb_datatype_parser.py
rename to ibis/backends/duckdb/tests/test_datatypes.py
index 12d1d40435ca..21a87fef609d 100644
--- a/ibis/tests/expr/test_duckdb_datatype_parser.py
+++ b/ibis/backends/duckdb/tests/test_datatypes.py
@@ -2,7 +2,7 @@
 from pytest import param
 
 import ibis.expr.datatypes as dt
-from ibis.expr.datatypes.duckdb import parse_type
+from ibis.backends.duckdb.datatypes import parse_type
 
 EXPECTED_SCHEMA = dict(
     a=dt.int64,
diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py
index 72ebd5838561..d2057473ed1b 100644
--- a/ibis/backends/postgres/__init__.py
+++ b/ibis/backends/postgres/__init__.py
@@ -7,6 +7,9 @@
 
 import sqlalchemy as sa
 
+import ibis.backends.duckdb.datatypes as ddb
+import ibis.expr.datatypes as dt
+import ibis.expr.schema as sch
 from ibis import util
 from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
 
@@ -179,3 +182,94 @@ def udf(
             name=name,
             language=language,
         )
+
+    def _get_schema_using_query(self, query: str) -> sch.Schema:
+        raw_name = util.guid()
+        name = self.con.dialect.identifier_preparer.quote_identifier(raw_name)
+        type_info_sql = f"""\
+SELECT
+  attname,
+  format_type(atttypid, atttypmod) AS type
+FROM pg_attribute
+WHERE attrelid = {raw_name!r}::regclass
+  AND attnum > 0
+  AND NOT attisdropped
+ORDER BY attnum
+"""
+        with self.con.connect() as con:
+            con.execute(f"CREATE TEMPORARY VIEW {name} AS {query}")
+            try:
+                type_info = con.execute(type_info_sql).fetchall()
+            finally:
+                con.execute(f"DROP VIEW {name}")
+        tuples = [(col, _get_type(typestr)) for col, typestr in type_info]
+        return sch.Schema.from_tuples(tuples)
+
+
+def _get_type(typestr: str) -> dt.DataType:
+    try:
+        return _type_mapping[typestr]
+    except KeyError:
+        return ddb.parse_type(typestr)
+
+
+_type_mapping = {
+    "boolean": dt.bool,
+    "boolean[]": dt.Array(dt.bool),
+    "bytea": dt.binary,
+    "bytea[]": dt.Array(dt.binary),
+    "character(1)": dt.string,
+    "character(1)[]": dt.Array(dt.string),
+    "bigint": dt.int64,
+    "bigint[]": dt.Array(dt.int64),
+    "smallint": dt.int16,
+    "smallint[]": dt.Array(dt.int16),
+    "integer": dt.int32,
+    "integer[]": dt.Array(dt.int32),
+    "text": dt.string,
+    "text[]": dt.Array(dt.string),
+    "json": dt.json,
+    "json[]": dt.Array(dt.json),
+    "point": dt.point,
+    "point[]": dt.Array(dt.point),
+    "polygon": dt.polygon,
+    "polygon[]": dt.Array(dt.polygon),
+    "line": dt.linestring,
+    "line[]": dt.Array(dt.linestring),
+    "real": dt.float32,
+    "real[]": dt.Array(dt.float32),
+    "double precision": dt.float64,
+    "double precision[]": dt.Array(dt.float64),
+    "macaddr8": dt.macaddr,
+    "macaddr8[]": dt.Array(dt.macaddr),
+    "macaddr": dt.macaddr,
+    "macaddr[]": dt.Array(dt.macaddr),
+    "inet": dt.inet,
+    "inet[]": dt.Array(dt.inet),
+    "character": dt.string,
+    "character[]": dt.Array(dt.string),
+    "character varying": dt.string,
+    "character varying[]": dt.Array(dt.string),
+    "date": dt.date,
+    "date[]": dt.Array(dt.date),
+    "time without time zone": dt.time,
+    "time without time zone[]": dt.Array(dt.time),
+    "timestamp without time zone": dt.timestamp,
+    "timestamp without time zone[]": dt.Array(dt.timestamp),
+    "timestamp with time zone": dt.Timestamp("UTC"),
+    "timestamp with time zone[]": dt.Array(dt.Timestamp("UTC")),
+    "interval": dt.interval,
+    "interval[]": dt.Array(dt.interval),
+    # NB: this isn't correct, but we try not to fail
+    "time with time zone": "time",
+    "numeric": dt.decimal,
+    "numeric[]": dt.Array(dt.decimal),
+    "uuid": dt.uuid,
+    "uuid[]": dt.Array(dt.uuid),
+    "jsonb": dt.jsonb,
+    "jsonb[]": dt.Array(dt.jsonb),
+    "geometry": dt.geometry,
+    "geometry[]": dt.Array(dt.geometry),
+    "geography": dt.geography,
+    "geography[]": dt.Array(dt.geography),
+}
diff --git a/ibis/backends/postgres/tests/test_client.py b/ibis/backends/postgres/tests/test_client.py
index f097d7113253..6d082d062eff 100644
--- a/ibis/backends/postgres/tests/test_client.py
+++ b/ibis/backends/postgres/tests/test_client.py
@@ -18,6 +18,7 @@
 import numpy as np
 import pandas as pd
 import pytest
+from pytest import param
 
 import ibis
 import ibis.expr.datatypes as dt
@@ -213,3 +214,52 @@ def test_create_and_drop_table(con, temp_table, params):
 
     with pytest.raises(sa.exc.NoSuchTableError):
         con.table(temp_table, **params)
+
+
+@pytest.mark.parametrize(
+    ("pg_type", "expected_type"),
+    [
+        param(pg_type, ibis_type, id=pg_type.lower())
+        for (pg_type, ibis_type) in [
+            ("boolean", dt.boolean),
+            ("bytea", dt.binary),
+            ("char", dt.string),
+            ("bigint", dt.int64),
+            ("smallint", dt.int16),
+            ("integer", dt.int32),
+            ("text", dt.string),
+            ("json", dt.json),
+            ("point", dt.point),
+            ("polygon", dt.polygon),
+            ("line", dt.linestring),
+            ("real", dt.float32),
+            ("double precision", dt.float64),
+            ("macaddr", dt.macaddr),
+            ("macaddr8", dt.macaddr),
+            ("inet", dt.inet),
+            ("character", dt.string),
+            ("character varying", dt.string),
+            ("date", dt.date),
+            ("time", dt.time),
+            ("time without time zone", dt.time),
+            ("timestamp without time zone", dt.timestamp),
+            ("timestamp with time zone", dt.Timestamp("UTC")),
+            ("interval", dt.interval),
+            ("numeric", dt.decimal),
+            ("numeric(3, 2)", dt.Decimal(3, 2)),
+            ("uuid", dt.uuid),
+            ("jsonb", dt.jsonb),
+            ("geometry", dt.geometry),
+            ("geography", dt.geography),
+        ]
+    ],
+)
+def test_get_schema_from_query(con, pg_type, expected_type):
+    raw_name = ibis.util.guid()
+    name = con.con.dialect.identifier_preparer.quote_identifier(raw_name)
+    con.raw_sql(f"CREATE TEMPORARY TABLE {name} (x {pg_type}, y {pg_type}[])")
+    expected_schema = ibis.schema(
+        dict(x=expected_type, y=dt.Array(expected_type))
+    )
+    result_schema = con._get_schema_using_query(f"SELECT x, y FROM {name}")
+    assert result_schema == expected_schema
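
Usage sketch (not part of the patch): one way the new method could be exercised
against a local PostgreSQL test database. The connection parameters below are
placeholder assumptions, and _get_schema_using_query is private API, so this is
illustrative only.

    import ibis

    # Placeholder credentials for a local test database; adjust as needed.
    con = ibis.postgres.connect(
        host="localhost",
        user="postgres",
        password="postgres",
        database="ibis_testing",
    )

    # The backend wraps the query in a temporary view, reads column names and
    # formatted types from pg_attribute, and maps them to ibis dtypes.
    schema = con._get_schema_using_query("SELECT 1::bigint AS x, now() AS ts")
    print(schema)  # roughly: x -> int64, ts -> timestamp('UTC')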