fix(snowflake): fix timestamp scale inference
cpcloud authored and kszucs committed Aug 8, 2023
1 parent 22ceba7 commit 083bdae
Showing 4 changed files with 173 additions and 8 deletions.
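In brief: the column precision and scale returned by the Snowflake driver's describe() call are now forwarded into parse, so NUMBER columns resolve to int64 or a decimal as appropriate and TIMESTAMP columns keep their fractional-second scale instead of collapsing to a scale-less timestamp. A rough sketch of the user-visible effect (a reachable Snowflake instance is assumed and the connection URL is a placeholder):

```python
# Rough sketch of the effect of this commit; the URL is a placeholder and a
# reachable Snowflake instance is assumed.
import ibis

con = ibis.connect("snowflake://user:pass@account/database/schema")
t = con.sql("SELECT CAST('2023-08-08' AS TIMESTAMP_NTZ(3)) AS ts")
# The inferred schema should now report a timestamp carrying scale=3 rather
# than a plain, scale-less timestamp.
print(t.schema())
```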
ibis/backends/snowflake/__init__.py (3 additions, 5 deletions)
@@ -424,11 +424,9 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
             result = cur.describe(query)
 
         for name, type_code, _, _, precision, scale, is_nullable in result:
-            if precision is not None and scale is not None:
-                typ = dt.Decimal(precision=precision, scale=scale, nullable=is_nullable)
-            else:
-                typ = parse(FIELD_ID_TO_NAME[type_code]).copy(nullable=is_nullable)
-            yield name, typ
+            typ_name = FIELD_ID_TO_NAME[type_code]
+            typ = parse(typ_name, precision=precision, scale=scale)
+            yield name, typ.copy(nullable=is_nullable)
 
     def list_databases(self, like: str | None = None) -> list[str]:
         with self.begin() as con:
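For orientation, a minimal sketch of the mapping performed by the reworked loop above. The seven-field row shape mirrors the driver's cursor.describe() metadata; the type-code excerpt and sample rows below are illustrative assumptions, not values taken from this commit:

```python
# Minimal sketch of the describe() -> ibis dtype mapping; not the backend code.
# The type codes and rows here are assumptions for illustration only.
from ibis.backends.snowflake.datatypes import parse

FIELD_ID_TO_NAME = {0: "FIXED", 8: "TIMESTAMP_NTZ"}  # tiny assumed excerpt

rows = [
    # (name, type_code, display_size, internal_size, precision, scale, is_nullable)
    ("id", 0, None, None, 10, 0, False),     # NUMBER(10, 0)    -> int64
    ("amount", 0, None, None, 38, 2, True),  # NUMBER(38, 2)    -> decimal(38, 2)
    ("ts", 8, None, None, 0, 3, True),       # TIMESTAMP_NTZ(3) -> timestamp(3)
]

for name, type_code, _, _, precision, scale, is_nullable in rows:
    typ = parse(FIELD_ID_TO_NAME[type_code], precision=precision, scale=scale)
    print(name, typ.copy(nullable=is_nullable))
```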
ibis/backends/snowflake/datatypes.py (11 additions, 2 deletions)
@@ -21,7 +21,6 @@ def compiles_nulltype(element, compiler, **kw):


 _SNOWFLAKE_TYPES = {
-    "FIXED": dt.int64,
     "REAL": dt.float64,
     "TEXT": dt.string,
     "DATE": dt.date,
@@ -38,8 +37,18 @@ def compiles_nulltype(element, compiler, **kw):
 }
 
 
-def parse(text: str) -> dt.DataType:
+def parse(
+    text: str, *, precision: int | None = None, scale: int | None = None
+) -> dt.DataType:
     """Parse a Snowflake type into an ibis data type."""
+    if text == "FIXED":
+        if (precision is None and scale is None) or (precision and not scale):
+            return dt.int64
+        else:
+            return dt.Decimal(precision=precision or 38, scale=scale or 0)
+    elif text.startswith("TIMESTAMP"):
+        # timestamp columns have a specified scale, defaulting to 9
+        return _SNOWFLAKE_TYPES[text].copy(scale=scale or 9)
     return _SNOWFLAKE_TYPES[text]


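To make the new dispatch rules concrete, a small doctest-style sketch (it assumes ibis is installed with the Snowflake extras so this module imports):

```python
# Behavior of the patched parse(), following the rules added above.
import ibis.expr.datatypes as dt
from ibis.backends.snowflake.datatypes import parse

# FIXED with no precision/scale, or an integral (zero) scale, stays int64.
assert parse("FIXED") == dt.int64
assert parse("FIXED", precision=10, scale=0) == dt.int64

# FIXED with a nonzero scale becomes a decimal.
assert parse("FIXED", precision=38, scale=2) == dt.Decimal(precision=38, scale=2)

# TIMESTAMP* types keep their scale, defaulting to 9 when none is reported.
assert parse("TIMESTAMP_NTZ") == dt.Timestamp(scale=9)
assert parse("TIMESTAMP_TZ", scale=3) == dt.Timestamp(timezone="UTC", scale=3)
```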
ibis/backends/snowflake/tests/test_datatypes.py (new file: 158 additions, 0 deletions)
@@ -0,0 +1,158 @@
from __future__ import annotations

import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis.backends.snowflake.datatypes import parse
from ibis.backends.snowflake.tests.conftest import _get_url
from ibis.util import gen_name

dtypes = [
("FIXED", dt.int64),
("REAL", dt.float64),
("TEXT", dt.string),
("DATE", dt.date),
("TIMESTAMP", dt.Timestamp(scale=9)),
("VARIANT", dt.json),
("TIMESTAMP_LTZ", dt.Timestamp(scale=9)),
("TIMESTAMP_TZ", dt.Timestamp(timezone="UTC", scale=9)),
("TIMESTAMP_NTZ", dt.Timestamp(scale=9)),
("OBJECT", dt.Map(dt.string, dt.json)),
("ARRAY", dt.Array(dt.json)),
("BINARY", dt.binary),
("TIME", dt.time),
("BOOLEAN", dt.boolean),
]


@pytest.mark.parametrize(
("snowflake_type", "ibis_type"),
[
param(snowflake_type, ibis_type, id=snowflake_type)
for snowflake_type, ibis_type in dtypes
],
)
def test_parse(snowflake_type, ibis_type):
assert parse(snowflake_type.upper()) == ibis_type


@pytest.fixture(scope="module")
def con():
    return ibis.connect(_get_url())


user_dtypes = [
("NUMBER", dt.int64),
("DECIMAL", dt.int64),
("NUMERIC", dt.int64),
("NUMBER(5)", dt.int64),
("DECIMAL(5, 2)", dt.Decimal(5, 2)),
("NUMERIC(21, 17)", dt.Decimal(21, 17)),
("INT", dt.int64),
("INTEGER", dt.int64),
("BIGINT", dt.int64),
("SMALLINT", dt.int64),
("TINYINT", dt.int64),
("BYTEINT", dt.int64),
("FLOAT", dt.float64),
("FLOAT4", dt.float64),
("FLOAT8", dt.float64),
("DOUBLE", dt.float64),
("DOUBLE PRECISION", dt.float64),
("REAL", dt.float64),
("VARCHAR", dt.string),
("CHAR", dt.string),
("CHARACTER", dt.string),
("STRING", dt.string),
("TEXT", dt.string),
("BINARY", dt.binary),
("VARBINARY", dt.binary),
("BOOLEAN", dt.boolean),
("DATE", dt.date),
("TIME", dt.time),
("VARIANT", dt.json),
("OBJECT", dt.Map(dt.string, dt.json)),
("ARRAY", dt.Array(dt.json)),
]


@pytest.mark.parametrize(
("snowflake_type", "ibis_type"),
[
param(snowflake_type, ibis_type, id=snowflake_type)
for snowflake_type, ibis_type in user_dtypes
],
)
def test_extract_type_from_table_query(con, snowflake_type, ibis_type):
name = gen_name("test_extract_type_from_table")
with con.begin() as c:
c.exec_driver_sql(f'CREATE TEMP TABLE "{name}" ("a" {snowflake_type})')

expected_schema = ibis.schema(dict(a=ibis_type))

t = con.sql(f'SELECT "a" FROM "{name}"')
assert t.schema() == expected_schema


broken_timestamps = pytest.mark.xfail(
    raises=AssertionError,
    reason=(
        "snowflake-sqlalchemy timestamp types are broken and do not preserve scale "
        "information"
    ),
)


@pytest.mark.parametrize(
("snowflake_type", "ibis_type"),
[
# what the result SHOULD be
param("DATETIME", dt.Timestamp(scale=9), marks=broken_timestamps),
param("TIMESTAMP", dt.Timestamp(scale=9), marks=broken_timestamps),
param("TIMESTAMP(3)", dt.Timestamp(scale=3), marks=broken_timestamps),
param(
"TIMESTAMP_LTZ",
dt.Timestamp(timezone="UTC", scale=9),
marks=broken_timestamps,
),
param(
"TIMESTAMP_LTZ(3)",
dt.Timestamp(timezone="UTC", scale=3),
marks=broken_timestamps,
),
param("TIMESTAMP_NTZ", dt.Timestamp(scale=9), marks=broken_timestamps),
param("TIMESTAMP_NTZ(3)", dt.Timestamp(scale=3), marks=broken_timestamps),
param(
"TIMESTAMP_TZ",
dt.Timestamp(timezone="UTC", scale=9),
marks=broken_timestamps,
),
param(
"TIMESTAMP_TZ(3)",
dt.Timestamp(timezone="UTC", scale=3),
marks=broken_timestamps,
),
# what the result ACTUALLY is
("DATETIME", dt.timestamp),
("TIMESTAMP", dt.timestamp),
("TIMESTAMP(3)", dt.timestamp),
("TIMESTAMP_LTZ", dt.Timestamp(timezone="UTC")),
("TIMESTAMP_LTZ(3)", dt.Timestamp(timezone="UTC")),
("TIMESTAMP_NTZ", dt.timestamp),
("TIMESTAMP_NTZ(3)", dt.timestamp),
("TIMESTAMP_TZ", dt.Timestamp(timezone="UTC")),
("TIMESTAMP_TZ(3)", dt.Timestamp(timezone="UTC")),
],
)
def test_extract_timestamp_from_table_sqlalchemy(con, snowflake_type, ibis_type):
"""snowflake-sqlalchemy doesn't preserve timestamp scale information"""
name = gen_name("test_extract_type_from_table")
with con.begin() as c:
c.exec_driver_sql(f'CREATE TEMP TABLE "{name}" ("a" {snowflake_type})')

expected_schema = ibis.schema(dict(a=ibis_type))

t = con.table(name)
assert t.schema() == expected_schema
ibis/backends/tests/test_client.py (1 addition, 1 deletion)
@@ -168,7 +168,7 @@ def test_query_schema(ddl_backend, expr_fn, expected):
}


@pytest.mark.notimpl(["datafusion", "snowflake", "polars", "mssql"])
@pytest.mark.notimpl(["datafusion", "polars", "mssql"])
@pytest.mark.notyet(["sqlite"])
@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError)
@pytest.mark.never(["dask", "pandas"], reason="dask and pandas do not support SQL")
