From 066c1582cab394896625b93bf8b05d07c2a321db Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Fri, 10 Mar 2023 07:24:19 -0500 Subject: [PATCH] fix(duckdb): support casting to unsigned integer types --- ibis/backends/duckdb/compiler.py | 20 +++++++++++++++---- .../test_cast_uints/uint16/out.sql | 3 +++ .../test_cast_uints/uint32/out.sql | 3 +++ .../test_cast_uints/uint64/out.sql | 3 +++ .../test_cast_uints/uint8/out.sql | 3 +++ ibis/backends/duckdb/tests/test_datatypes.py | 17 ++++++++++++++++ poetry.lock | 6 +++--- requirements.txt | 2 +- 8 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint16/out.sql create mode 100644 ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint32/out.sql create mode 100644 ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint64/out.sql create mode 100644 ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint8/out.sql diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index 5bc0d10f0c27..bae7ed49d35d 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -20,12 +20,24 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator): _dialect_name = "duckdb" -@compiles(sat.UInt64, "duckdb") -@compiles(sat.UInt32, "duckdb") -@compiles(sat.UInt16, "duckdb") @compiles(sat.UInt8, "duckdb") +def compile_uint8(element, compiler, **kw): + return "UTINYINT" + + +@compiles(sat.UInt16, "duckdb") +def compile_uint16(element, compiler, **kw): + return "USMALLINT" + + +@compiles(sat.UInt32, "duckdb") +def compile_uint32(element, compiler, **kw): + return "UINTEGER" + + +@compiles(sat.UInt64, "duckdb") def compile_uint(element, compiler, **kw): - return element.__class__.__name__.upper() + return "UBIGINT" @compiles(sat.ArrayType, "duckdb") diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint16/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint16/out.sql new file mode 100644 index 000000000000..abb420080b20 --- /dev/null +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint16/out.sql @@ -0,0 +1,3 @@ +SELECT + CAST(t0.a AS USMALLINT) AS "Cast(a, uint16)" +FROM t AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint32/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint32/out.sql new file mode 100644 index 000000000000..b2ec0d726884 --- /dev/null +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint32/out.sql @@ -0,0 +1,3 @@ +SELECT + CAST(t0.a AS UINTEGER) AS "Cast(a, uint32)" +FROM t AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint64/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint64/out.sql new file mode 100644 index 000000000000..6cefd3bb478b --- /dev/null +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint64/out.sql @@ -0,0 +1,3 @@ +SELECT + CAST(t0.a AS UBIGINT) AS "Cast(a, uint64)" +FROM t AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint8/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint8/out.sql new file mode 100644 index 000000000000..dae9dbdc41cb --- /dev/null +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint8/out.sql @@ -0,0 +1,3 @@ +SELECT + CAST(t0.a AS UTINYINT) AS "Cast(a, uint8)" +FROM t AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/test_datatypes.py b/ibis/backends/duckdb/tests/test_datatypes.py index 8bf4e9331309..dd90c10e789d 100644 --- a/ibis/backends/duckdb/tests/test_datatypes.py +++ b/ibis/backends/duckdb/tests/test_datatypes.py @@ -1,4 +1,6 @@ import pytest +import sqlglot as sg +from packaging.version import parse as vparse from pytest import param import ibis.expr.datatypes as dt @@ -78,3 +80,18 @@ def test_parser(typ, expected): ty = parse(typ) assert ty == expected + + +@pytest.mark.parametrize("uint_type", ["uint8", "uint16", "uint32", "uint64"]) +@pytest.mark.xfail( + vparse(sg.__version__) < vparse("11.3.4"), + raises=sg.ParseError, + reason="sqlglot version doesn't support duckdb unsigned integer types", +) +def test_cast_uints(uint_type, snapshot): + import ibis + + t = ibis.table(dict(a="int8"), name="t") + snapshot.assert_match( + str(ibis.to_sql(t.a.cast(uint_type), dialect="duckdb")), "out.sql" + ) diff --git a/poetry.lock b/poetry.lock index f65a0c8b5439..b55d5e42233e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5067,14 +5067,14 @@ sqlalchemy = ">=1.0.0" [[package]] name = "sqlglot" -version = "11.3.0" +version = "11.3.6" description = "An easily customizable SQL parser and transpiler" category = "main" optional = false python-versions = "*" files = [ - {file = "sqlglot-11.3.0-py3-none-any.whl", hash = "sha256:a95d22c4d5de61ba3bf96414f5c000525b4345e337d4b73c9736bfd421e354e7"}, - {file = "sqlglot-11.3.0.tar.gz", hash = "sha256:5bd317d8d08c77d7459a3043fe8c4fda942d64054461daa424e717e55642892e"}, + {file = "sqlglot-11.3.6-py3-none-any.whl", hash = "sha256:c16e8889faa09caa43943fa16c0735b8dddcf97f3700b9b1e681227375357aa8"}, + {file = "sqlglot-11.3.6.tar.gz", hash = "sha256:70dcaa528c9c99ef7fc328cfe0dbd7c1ae25843ea3cd5c61bd5c9fda57d1b467"}, ] [package.extras] diff --git a/requirements.txt b/requirements.txt index 5aee2682e010..820deca1d84e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -197,7 +197,7 @@ sortedcontainers==2.4.0 ; python_version >= "3.8" and python_version < "4.0" soupsieve==2.4 ; python_version >= "3.8" and python_version < "4" sqlalchemy-views==0.3.2 ; python_version >= "3.8" and python_version < "4.0" sqlalchemy==1.4.46 ; python_version >= "3.8" and python_version < "4.0" -sqlglot==11.3.0 ; python_version >= "3.8" and python_version < "4.0" +sqlglot==11.3.6 ; python_version >= "3.8" and python_version < "4.0" stack-data==0.6.2 ; python_version >= "3.8" and python_version < "4.0" termcolor==2.2.0 ; python_version >= "3.8" and python_version < "4.0" thrift-sasl==0.4.3 ; python_version >= "3.8" and python_version < "4.0"