Skip to content

Commit

Permalink
feat(postgres): implement _get_schema_using_query
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Mar 30, 2022
1 parent 93cd730 commit f2459eb
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 2 deletions.
5 changes: 4 additions & 1 deletion ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import duckdb
import sqlalchemy as sa

if TYPE_CHECKING:
import duckdb

import ibis.expr.schema as sch
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pytest import param

import ibis.expr.datatypes as dt
from ibis.expr.datatypes.duckdb import parse_type
from ibis.backends.duckdb.datatypes import parse_type

EXPECTED_SCHEMA = dict(
a=dt.int64,
Expand Down
94 changes: 94 additions & 0 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import sqlalchemy as sa

import ibis.backends.duckdb.datatypes as ddb
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis import util
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend

Expand Down Expand Up @@ -179,3 +182,94 @@ def udf(
name=name,
language=language,
)

def _get_schema_using_query(self, query: str) -> sch.Schema:
    """Infer the schema of *query* without executing it for real.

    The query is materialized as a temporary view, the view's column
    names and rendered types are read back from ``pg_attribute``, and
    the view is dropped again before the schema is returned.
    """
    view_id = util.guid()
    quoted = self.con.dialect.identifier_preparer.quote_identifier(view_id)
    # Column names and `format_type`-rendered types for the view's
    # relation, in attribute order, skipping system and dropped columns.
    type_info_sql = f"""\
SELECT
  attname,
  format_type(atttypid, atttypmod) AS type
FROM pg_attribute
WHERE attrelid = {view_id!r}::regclass
  AND attnum > 0
  AND NOT attisdropped
ORDER BY attnum
"""
    with self.con.connect() as con:
        con.execute(f"CREATE TEMPORARY VIEW {quoted} AS {query}")
        try:
            rows = con.execute(type_info_sql).fetchall()
        finally:
            # Always clean up the scratch view, even if introspection fails.
            con.execute(f"DROP VIEW {quoted}")
    pairs = [(column, _get_type(type_string)) for column, type_string in rows]
    return sch.Schema.from_tuples(pairs)


def _get_type(typestr: str) -> dt.DataType:
    """Map a PostgreSQL ``format_type`` string to an ibis datatype.

    Anything not found in the fixed mapping (e.g. parametrized types
    such as ``numeric(3, 2)``) is delegated to the duckdb type-string
    parser.
    """
    # No mapping value is None, so None is a safe missing-key sentinel here.
    mapped = _type_mapping.get(typestr)
    if mapped is not None:
        return mapped
    return ddb.parse_type(typestr)


# Mapping from PostgreSQL's `format_type` output to ibis datatypes.
# Every scalar type is paired with its array form (`<type>[]`).  Types
# not listed here fall back to the duckdb type-string parser via
# `_get_type`.
_type_mapping = {
    "boolean": dt.bool,
    "boolean[]": dt.Array(dt.bool),
    "bytea": dt.binary,
    "bytea[]": dt.Array(dt.binary),
    "character(1)": dt.string,
    "character(1)[]": dt.Array(dt.string),
    "bigint": dt.int64,
    "bigint[]": dt.Array(dt.int64),
    "smallint": dt.int16,
    "smallint[]": dt.Array(dt.int16),
    "integer": dt.int32,
    "integer[]": dt.Array(dt.int32),
    "text": dt.string,
    "text[]": dt.Array(dt.string),
    "json": dt.json,
    "json[]": dt.Array(dt.json),
    "point": dt.point,
    "point[]": dt.Array(dt.point),
    "polygon": dt.polygon,
    "polygon[]": dt.Array(dt.polygon),
    "line": dt.linestring,
    "line[]": dt.Array(dt.linestring),
    "real": dt.float32,
    "real[]": dt.Array(dt.float32),
    "double precision": dt.float64,
    "double precision[]": dt.Array(dt.float64),
    "macaddr8": dt.macaddr,
    "macaddr8[]": dt.Array(dt.macaddr),
    "macaddr": dt.macaddr,
    "macaddr[]": dt.Array(dt.macaddr),
    "inet": dt.inet,
    "inet[]": dt.Array(dt.inet),
    "character": dt.string,
    "character[]": dt.Array(dt.string),
    "character varying": dt.string,
    "character varying[]": dt.Array(dt.string),
    "date": dt.date,
    "date[]": dt.Array(dt.date),
    "time without time zone": dt.time,
    "time without time zone[]": dt.Array(dt.time),
    "timestamp without time zone": dt.timestamp,
    "timestamp without time zone[]": dt.Array(dt.timestamp),
    "timestamp with time zone": dt.Timestamp("UTC"),
    "timestamp with time zone[]": dt.Array(dt.Timestamp("UTC")),
    "interval": dt.interval,
    "interval[]": dt.Array(dt.interval),
    # NB: this isn't correct — the timezone is discarded — but we try
    # not to fail, so coerce to a plain time.  Was previously the bare
    # string "time", which broke `_get_type`'s `-> dt.DataType` contract
    # for this key; `dt.time` is the value that string coerced to anyway.
    "time with time zone": dt.time,
    "time with time zone[]": dt.Array(dt.time),
    "numeric": dt.decimal,
    "numeric[]": dt.Array(dt.decimal),
    "uuid": dt.uuid,
    "uuid[]": dt.Array(dt.uuid),
    "jsonb": dt.jsonb,
    "jsonb[]": dt.Array(dt.jsonb),
    "geometry": dt.geometry,
    "geometry[]": dt.Array(dt.geometry),
    "geography": dt.geography,
    "geography[]": dt.Array(dt.geography),
}
50 changes: 50 additions & 0 deletions ibis/backends/postgres/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import pandas as pd
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
Expand Down Expand Up @@ -213,3 +214,52 @@ def test_create_and_drop_table(con, temp_table, params):

with pytest.raises(sa.exc.NoSuchTableError):
con.table(temp_table, **params)


@pytest.mark.parametrize(
    ("pg_type", "expected_type"),
    [
        pytest.param(postgres_type, ibis_type, id=postgres_type.lower())
        for postgres_type, ibis_type in [
            ("boolean", dt.boolean),
            ("bytea", dt.binary),
            ("char", dt.string),
            ("bigint", dt.int64),
            ("smallint", dt.int16),
            ("integer", dt.int32),
            ("text", dt.string),
            ("json", dt.json),
            ("point", dt.point),
            ("polygon", dt.polygon),
            ("line", dt.linestring),
            ("real", dt.float32),
            ("double precision", dt.float64),
            ("macaddr", dt.macaddr),
            ("macaddr8", dt.macaddr),
            ("inet", dt.inet),
            ("character", dt.string),
            ("character varying", dt.string),
            ("date", dt.date),
            ("time", dt.time),
            ("time without time zone", dt.time),
            ("timestamp without time zone", dt.timestamp),
            ("timestamp with time zone", dt.Timestamp("UTC")),
            ("interval", dt.interval),
            ("numeric", dt.decimal),
            ("numeric(3, 2)", dt.Decimal(3, 2)),
            ("uuid", dt.uuid),
            ("jsonb", dt.jsonb),
            ("geometry", dt.geometry),
            ("geography", dt.geography),
        ]
    ],
)
def test_get_schema_from_query(con, pg_type, expected_type):
    """Schema inference from a raw query covers scalar and array columns."""
    table_id = ibis.util.guid()
    quoted = con.con.dialect.identifier_preparer.quote_identifier(table_id)
    # One column of the scalar type and one of its array form.
    con.raw_sql(f"CREATE TEMPORARY TABLE {quoted} (x {pg_type}, y {pg_type}[])")
    expected_schema = ibis.schema(dict(x=expected_type, y=dt.Array(expected_type)))
    result_schema = con._get_schema_using_query(f"SELECT x, y FROM {quoted}")
    assert result_schema == expected_schema

0 comments on commit f2459eb

Please sign in to comment.