diff --git a/ibis/backends/base/sql/alchemy/datatypes.py b/ibis/backends/base/sql/alchemy/datatypes.py index 214a80956259..1608faee29aa 100644 --- a/ibis/backends/base/sql/alchemy/datatypes.py +++ b/ibis/backends/base/sql/alchemy/datatypes.py @@ -41,6 +41,19 @@ def compiles_array(element, compiler, **kw): return f"ARRAY({compiler.process(element.value_type, **kw)})" +@compiles(sat.FLOAT, "duckdb") +def compiles_float(element, compiler, **kw): + precision = element.precision + if precision is None or 1 <= precision <= 24: + return "FLOAT" + elif 24 < precision <= 53: + return "DOUBLE" + else: + raise ValueError( + "FLOAT precision must be between 1 and 53 inclusive, or `None`" + ) + + class StructType(sat.UserDefinedType): cache_ok = True diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_to_floating_point_type/float32/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_to_floating_point_type/float32/out.sql new file mode 100644 index 000000000000..68251010b09f --- /dev/null +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_to_floating_point_type/float32/out.sql @@ -0,0 +1,2 @@ +SELECT + CAST('1.0' AS REAL) AS "Cast('1.0', float32)" \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_to_floating_point_type/float64/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_to_floating_point_type/float64/out.sql new file mode 100644 index 000000000000..cc84473347bb --- /dev/null +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_to_floating_point_type/float64/out.sql @@ -0,0 +1,2 @@ +SELECT + CAST('1.0' AS DOUBLE) AS "Cast('1.0', float64)" \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/snapshots/test_geospatial/test_geospatial_dwithin/out.sql b/ibis/backends/duckdb/tests/snapshots/test_geospatial/test_geospatial_dwithin/out.sql index 677936b16a31..e68c65813913 100644 --- a/ibis/backends/duckdb/tests/snapshots/test_geospatial/test_geospatial_dwithin/out.sql +++ b/ibis/backends/duckdb/tests/snapshots/test_geospatial/test_geospatial_dwithin/out.sql @@ -1,3 +1,3 @@ SELECT - ST_DWITHIN(t0.geom, t0.geom, CAST(3.0 AS REAL(53))) AS tmp + ST_DWITHIN(t0.geom, t0.geom, CAST(3.0 AS DOUBLE)) AS tmp FROM t AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/test_datatypes.py b/ibis/backends/duckdb/tests/test_datatypes.py index d2e3bfc97c8f..7f0212f21374 100644 --- a/ibis/backends/duckdb/tests/test_datatypes.py +++ b/ibis/backends/duckdb/tests/test_datatypes.py @@ -1,11 +1,13 @@ from __future__ import annotations import duckdb_engine +import numpy as np import pytest import sqlalchemy as sa from packaging.version import parse as vparse from pytest import param +import ibis import ibis.backends.base.sql.alchemy.datatypes as sat import ibis.common.exceptions as exc import ibis.expr.datatypes as dt @@ -66,8 +68,6 @@ def test_parser(typ, expected): @pytest.mark.parametrize("uint_type", ["uint8", "uint16", "uint32", "uint64"]) def test_cast_uints(uint_type, snapshot): - import ibis - t = ibis.table(dict(a="int8"), name="t") snapshot.assert_match( str(ibis.to_sql(t.a.cast(uint_type), dialect="duckdb")), "out.sql" @@ -75,8 +75,6 @@ def test_cast_uints(uint_type, snapshot): def test_null_dtype(): - import ibis - con = ibis.connect("duckdb://:memory:") t = ibis.memtable({"a": [None, None]}) @@ -110,10 +108,6 @@ def test_generate_quoted_struct(): reason="mapping from UINTEGER query metadata fixed in 0.9.2", ) def test_read_uint8_from_parquet(tmp_path): - import numpy as np - - import ibis - con = ibis.duckdb.connect() # There is an incorrect mapping in duckdb-engine from UInteger -> UInt8 @@ -129,3 +123,16 @@ def test_read_uint8_from_parquet(tmp_path): t2 = con.read_parquet(parqpath) assert t2.schema() == t.schema() + + +@pytest.mark.parametrize("typ", ["float32", "float64"]) +def test_cast_to_floating_point_type(con, snapshot, typ): + expected = 1.0 + value = ibis.literal(str(expected)) + expr = value.cast(typ) + + result = con.execute(expr) + assert result == expected + + sql = str(ibis.to_sql(expr, dialect="duckdb")) + snapshot.assert_match(sql, "out.sql") diff --git a/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql b/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql index 8ec16703aeee..eae15c8677d5 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql @@ -4,5 +4,5 @@ FROM main.lineitem AS t0 WHERE t0.l_shipdate >= MAKE_DATE(1994, 1, 1) AND t0.l_shipdate < MAKE_DATE(1995, 1, 1) - AND t0.l_discount BETWEEN CAST(0.05 AS REAL(53)) AND CAST(0.07 AS REAL(53)) + AND t0.l_discount BETWEEN CAST(0.05 AS DOUBLE) AND CAST(0.07 AS DOUBLE) AND t0.l_quantity < CAST(24 AS TINYINT) \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h11/test_tpc_h11/duckdb/h11.sql b/ibis/backends/tests/tpch/snapshots/test_h11/test_tpc_h11/duckdb/h11.sql index c19193299b09..edbba7a0223d 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h11/test_tpc_h11/duckdb/h11.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h11/test_tpc_h11/duckdb/h11.sql @@ -35,7 +35,7 @@ FROM ( WHERE t4.n_name = 'GERMANY' ) AS anon_1 - ) * CAST(0.0001 AS REAL(53)) + ) * CAST(0.0001 AS DOUBLE) ) AS t1 ORDER BY t1.value DESC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql b/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql index 4e9c6e9f6da4..e0adc83afc3b 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql @@ -1,5 +1,5 @@ SELECT - SUM(t0.l_extendedprice) / CAST(7.0 AS REAL(53)) AS avg_yearly + SUM(t0.l_extendedprice) / CAST(7.0 AS DOUBLE) AS avg_yearly FROM main.lineitem AS t0 JOIN main.part AS t1 ON t1.p_partkey = t0.l_partkey @@ -12,4 +12,4 @@ WHERE FROM main.lineitem AS t0 WHERE t0.l_partkey = t1.p_partkey - ) * CAST(0.2 AS REAL(53)) \ No newline at end of file + ) * CAST(0.2 AS DOUBLE) \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/duckdb/h20.sql b/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/duckdb/h20.sql index 4b61b55158f5..ec72d90e4bac 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/duckdb/h20.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/duckdb/h20.sql @@ -56,7 +56,7 @@ WITH t0 AS ( AND t6.l_suppkey = t5.ps_suppkey AND t6.l_shipdate >= MAKE_DATE(1994, 1, 1) AND t6.l_shipdate < MAKE_DATE(1995, 1, 1) - ) * CAST(0.5 AS REAL(53)) + ) * CAST(0.5 AS DOUBLE) ) AS t4 ) ) diff --git a/ibis/backends/tests/tpch/snapshots/test_h22/test_tpc_h22/duckdb/h22.sql b/ibis/backends/tests/tpch/snapshots/test_h22/test_tpc_h22/duckdb/h22.sql index a9c96a30d190..76fa737c1056 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h22/test_tpc_h22/duckdb/h22.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h22/test_tpc_h22/duckdb/h22.sql @@ -25,7 +25,7 @@ WITH t0 AS ( AVG(t2.c_acctbal) AS avg_bal FROM main.customer AS t2 WHERE - t2.c_acctbal > CAST(0.0 AS REAL(53)) + t2.c_acctbal > CAST(0.0 AS DOUBLE) AND CASE WHEN ( CAST(0 AS TINYINT) + 1 >= 1