From 083bdae6c76a45301ccc9f449652772c607470fb Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 6 Aug 2023 07:01:49 -0400 Subject: [PATCH] fix(snowflake): fix timestamp scale inference --- ibis/backends/snowflake/__init__.py | 8 +- ibis/backends/snowflake/datatypes.py | 13 +- .../snowflake/tests/test_datatypes.py | 158 ++++++++++++++++++ ibis/backends/tests/test_client.py | 2 +- 4 files changed, 173 insertions(+), 8 deletions(-) create mode 100644 ibis/backends/snowflake/tests/test_datatypes.py diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index a4bff450720c..695752babc2e 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -424,11 +424,9 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: result = cur.describe(query) for name, type_code, _, _, precision, scale, is_nullable in result: - if precision is not None and scale is not None: - typ = dt.Decimal(precision=precision, scale=scale, nullable=is_nullable) - else: - typ = parse(FIELD_ID_TO_NAME[type_code]).copy(nullable=is_nullable) - yield name, typ + typ_name = FIELD_ID_TO_NAME[type_code] + typ = parse(typ_name, precision=precision, scale=scale) + yield name, typ.copy(nullable=is_nullable) def list_databases(self, like: str | None = None) -> list[str]: with self.begin() as con: diff --git a/ibis/backends/snowflake/datatypes.py b/ibis/backends/snowflake/datatypes.py index 749ebc5cedf7..6699df980d61 100644 --- a/ibis/backends/snowflake/datatypes.py +++ b/ibis/backends/snowflake/datatypes.py @@ -21,7 +21,6 @@ def compiles_nulltype(element, compiler, **kw): _SNOWFLAKE_TYPES = { - "FIXED": dt.int64, "REAL": dt.float64, "TEXT": dt.string, "DATE": dt.date, @@ -38,8 +37,18 @@ def compiles_nulltype(element, compiler, **kw): } -def parse(text: str) -> dt.DataType: +def parse( + text: str, *, precision: int | None = None, scale: int | None = None +) -> dt.DataType: """Parse a Snowflake type into an ibis data type.""" + if text == "FIXED": + if (precision is None and scale is None) or (precision and not scale): + return dt.int64 + else: + return dt.Decimal(precision=precision or 38, scale=scale or 0) + elif text.startswith("TIMESTAMP"): + # timestamp columns have a specified scale, defaulting to 9 + return _SNOWFLAKE_TYPES[text].copy(scale=scale or 9) return _SNOWFLAKE_TYPES[text] diff --git a/ibis/backends/snowflake/tests/test_datatypes.py b/ibis/backends/snowflake/tests/test_datatypes.py new file mode 100644 index 000000000000..d092c1b26e5d --- /dev/null +++ b/ibis/backends/snowflake/tests/test_datatypes.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import pytest +from pytest import param + +import ibis +import ibis.expr.datatypes as dt +from ibis.backends.snowflake.datatypes import parse +from ibis.backends.snowflake.tests.conftest import _get_url +from ibis.util import gen_name + +dtypes = [ + ("FIXED", dt.int64), + ("REAL", dt.float64), + ("TEXT", dt.string), + ("DATE", dt.date), + ("TIMESTAMP", dt.Timestamp(scale=9)), + ("VARIANT", dt.json), + ("TIMESTAMP_LTZ", dt.Timestamp(scale=9)), + ("TIMESTAMP_TZ", dt.Timestamp(timezone="UTC", scale=9)), + ("TIMESTAMP_NTZ", dt.Timestamp(scale=9)), + ("OBJECT", dt.Map(dt.string, dt.json)), + ("ARRAY", dt.Array(dt.json)), + ("BINARY", dt.binary), + ("TIME", dt.time), + ("BOOLEAN", dt.boolean), +] + + +@pytest.mark.parametrize( + ("snowflake_type", "ibis_type"), + [ + param(snowflake_type, ibis_type, id=snowflake_type) + for snowflake_type, ibis_type in dtypes + ], +) +def test_parse(snowflake_type, ibis_type): + assert parse(snowflake_type.upper()) == ibis_type + + +@pytest.fixture(scope="module") +def con(): + return ibis.connect(_get_url()) + + +user_dtypes = [ + ("NUMBER", dt.int64), + ("DECIMAL", dt.int64), + ("NUMERIC", dt.int64), + ("NUMBER(5)", dt.int64), + ("DECIMAL(5, 2)", dt.Decimal(5, 2)), + ("NUMERIC(21, 17)", dt.Decimal(21, 17)), + ("INT", dt.int64), + ("INTEGER", dt.int64), + ("BIGINT", dt.int64), + ("SMALLINT", dt.int64), + ("TINYINT", dt.int64), + ("BYTEINT", dt.int64), + ("FLOAT", dt.float64), + ("FLOAT4", dt.float64), + ("FLOAT8", dt.float64), + ("DOUBLE", dt.float64), + ("DOUBLE PRECISION", dt.float64), + ("REAL", dt.float64), + ("VARCHAR", dt.string), + ("CHAR", dt.string), + ("CHARACTER", dt.string), + ("STRING", dt.string), + ("TEXT", dt.string), + ("BINARY", dt.binary), + ("VARBINARY", dt.binary), + ("BOOLEAN", dt.boolean), + ("DATE", dt.date), + ("TIME", dt.time), + ("VARIANT", dt.json), + ("OBJECT", dt.Map(dt.string, dt.json)), + ("ARRAY", dt.Array(dt.json)), +] + + +@pytest.mark.parametrize( + ("snowflake_type", "ibis_type"), + [ + param(snowflake_type, ibis_type, id=snowflake_type) + for snowflake_type, ibis_type in user_dtypes + ], +) +def test_extract_type_from_table_query(con, snowflake_type, ibis_type): + name = gen_name("test_extract_type_from_table") + with con.begin() as c: + c.exec_driver_sql(f'CREATE TEMP TABLE "{name}" ("a" {snowflake_type})') + + expected_schema = ibis.schema(dict(a=ibis_type)) + + t = con.sql(f'SELECT "a" FROM "{name}"') + assert t.schema() == expected_schema + + +broken_timestamps = pytest.mark.xfail( + raises=AssertionError, + reason=( + "snowflake-sqlalchemy timestamp types are broken and do not preserve scale " + "information" + ), +) + + +@pytest.mark.parametrize( + ("snowflake_type", "ibis_type"), + [ + # what the result SHOULD be + param("DATETIME", dt.Timestamp(scale=9), marks=broken_timestamps), + param("TIMESTAMP", dt.Timestamp(scale=9), marks=broken_timestamps), + param("TIMESTAMP(3)", dt.Timestamp(scale=3), marks=broken_timestamps), + param( + "TIMESTAMP_LTZ", + dt.Timestamp(timezone="UTC", scale=9), + marks=broken_timestamps, + ), + param( + "TIMESTAMP_LTZ(3)", + dt.Timestamp(timezone="UTC", scale=3), + marks=broken_timestamps, + ), + param("TIMESTAMP_NTZ", dt.Timestamp(scale=9), marks=broken_timestamps), + param("TIMESTAMP_NTZ(3)", dt.Timestamp(scale=3), marks=broken_timestamps), + param( + "TIMESTAMP_TZ", + dt.Timestamp(timezone="UTC", scale=9), + marks=broken_timestamps, + ), + param( + "TIMESTAMP_TZ(3)", + dt.Timestamp(timezone="UTC", scale=3), + marks=broken_timestamps, + ), + # what the result ACTUALLY is + ("DATETIME", dt.timestamp), + ("TIMESTAMP", dt.timestamp), + ("TIMESTAMP(3)", dt.timestamp), + ("TIMESTAMP_LTZ", dt.Timestamp(timezone="UTC")), + ("TIMESTAMP_LTZ(3)", dt.Timestamp(timezone="UTC")), + ("TIMESTAMP_NTZ", dt.timestamp), + ("TIMESTAMP_NTZ(3)", dt.timestamp), + ("TIMESTAMP_TZ", dt.Timestamp(timezone="UTC")), + ("TIMESTAMP_TZ(3)", dt.Timestamp(timezone="UTC")), + ], +) +def test_extract_timestamp_from_table_sqlalchemy(con, snowflake_type, ibis_type): + """snowflake-sqlalchemy doesn't preserve timestamp scale information""" + name = gen_name("test_extract_type_from_table") + with con.begin() as c: + c.exec_driver_sql(f'CREATE TEMP TABLE "{name}" ("a" {snowflake_type})') + + expected_schema = ibis.schema(dict(a=ibis_type)) + + t = con.table(name) + assert t.schema() == expected_schema diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 6141bd34c5ba..b21d932a9d4c 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -168,7 +168,7 @@ def test_query_schema(ddl_backend, expr_fn, expected): } -@pytest.mark.notimpl(["datafusion", "snowflake", "polars", "mssql"]) +@pytest.mark.notimpl(["datafusion", "polars", "mssql"]) @pytest.mark.notyet(["sqlite"]) @pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) @pytest.mark.never(["dask", "pandas"], reason="dask and pandas do not support SQL")