Skip to content

Commit

Permalink
fix(duckdb): ensure that casting to floating point values produces va…
Browse files Browse the repository at this point in the history
…lid types in generated sql
  • Loading branch information
cpcloud authored and kszucs committed Jan 22, 2024
1 parent 7750a1a commit 424b206
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 15 deletions.
13 changes: 13 additions & 0 deletions ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ def compiles_array(element, compiler, **kw):
return f"ARRAY({compiler.process(element.value_type, **kw)})"


@compiles(sat.FLOAT, "duckdb")
def compiles_float(element, compiler, **kw):
    """Return the DuckDB type name to emit for a SQLAlchemy ``FLOAT``.

    The name is chosen from ``element.precision``:

    * ``None`` or a value from 1 to 24 inclusive gives ``"FLOAT"``
    * a value above 24 and up to 53 inclusive gives ``"DOUBLE"``

    ``compiler`` and ``kw`` are accepted but not used by this function.

    Raises
    ------
    ValueError
        If the precision is neither ``None`` nor within 1 to 53 inclusive.
    """
    bits = element.precision

    # An unspecified precision is rendered the same as the 1..24 range.
    if bits is None:
        return "FLOAT"
    if 1 <= bits <= 24:
        return "FLOAT"
    if 24 < bits <= 53:
        return "DOUBLE"

    raise ValueError(
        "FLOAT precision must be between 1 and 53 inclusive, or `None`"
    )


class StructType(sat.UserDefinedType):
cache_ok = True

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT
CAST('1.0' AS REAL) AS "Cast('1.0', float32)"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT
CAST('1.0' AS DOUBLE) AS "Cast('1.0', float64)"
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
ST_DWITHIN(t0.geom, t0.geom, CAST(3.0 AS REAL(53))) AS tmp
ST_DWITHIN(t0.geom, t0.geom, CAST(3.0 AS DOUBLE)) AS tmp
FROM t AS t0
23 changes: 15 additions & 8 deletions ibis/backends/duckdb/tests/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import duckdb_engine
import numpy as np
import pytest
import sqlalchemy as sa
from packaging.version import parse as vparse
from pytest import param

import ibis
import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
Expand Down Expand Up @@ -66,17 +68,13 @@ def test_parser(typ, expected):

@pytest.mark.parametrize("uint_type", ["uint8", "uint16", "uint32", "uint64"])
def test_cast_uints(uint_type, snapshot):
    """Snapshot the DuckDB SQL for casting an ``int8`` column to ``uint_type``.

    The SQL string produced by ``ibis.to_sql`` is compared against the stored
    ``out.sql`` snapshot for each parametrized unsigned integer type.
    """
    # ``ibis`` is imported at module level, so no function-local import is
    # needed here.
    t = ibis.table(dict(a="int8"), name="t")
    snapshot.assert_match(
        str(ibis.to_sql(t.a.cast(uint_type), dialect="duckdb")), "out.sql"
    )


def test_null_dtype():
import ibis

con = ibis.connect("duckdb://:memory:")

t = ibis.memtable({"a": [None, None]})
Expand Down Expand Up @@ -110,10 +108,6 @@ def test_generate_quoted_struct():
reason="mapping from UINTEGER query metadata fixed in 0.9.2",
)
def test_read_uint8_from_parquet(tmp_path):
import numpy as np

import ibis

con = ibis.duckdb.connect()

# There is an incorrect mapping in duckdb-engine from UInteger -> UInt8
Expand All @@ -129,3 +123,16 @@ def test_read_uint8_from_parquet(tmp_path):
t2 = con.read_parquet(parqpath)

assert t2.schema() == t.schema()


@pytest.mark.parametrize("typ", ["float32", "float64"])
def test_cast_to_floating_point_type(con, snapshot, typ):
    """Cast the string ``'1.0'`` to ``typ``; check the value and the SQL.

    The expression is first executed through ``con`` and compared with the
    expected float, then its DuckDB SQL is matched against the ``out.sql``
    snapshot.
    """
    expected = 1.0
    expr = ibis.literal(str(expected)).cast(typ)

    # Executing the cast must give back the numeric value of the string.
    assert con.execute(expr) == expected

    # The rendered SQL is pinned by the snapshot file.
    generated_sql = str(ibis.to_sql(expr, dialect="duckdb"))
    snapshot.assert_match(generated_sql, "out.sql")
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ FROM main.lineitem AS t0
WHERE
t0.l_shipdate >= MAKE_DATE(1994, 1, 1)
AND t0.l_shipdate < MAKE_DATE(1995, 1, 1)
AND t0.l_discount BETWEEN CAST(0.05 AS REAL(53)) AND CAST(0.07 AS REAL(53))
AND t0.l_discount BETWEEN CAST(0.05 AS DOUBLE) AND CAST(0.07 AS DOUBLE)
AND t0.l_quantity < CAST(24 AS TINYINT)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ FROM (
WHERE
t4.n_name = 'GERMANY'
) AS anon_1
) * CAST(0.0001 AS REAL(53))
) * CAST(0.0001 AS DOUBLE)
) AS t1
ORDER BY
t1.value DESC
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SELECT
SUM(t0.l_extendedprice) / CAST(7.0 AS REAL(53)) AS avg_yearly
SUM(t0.l_extendedprice) / CAST(7.0 AS DOUBLE) AS avg_yearly
FROM main.lineitem AS t0
JOIN main.part AS t1
ON t1.p_partkey = t0.l_partkey
Expand All @@ -12,4 +12,4 @@ WHERE
FROM main.lineitem AS t0
WHERE
t0.l_partkey = t1.p_partkey
) * CAST(0.2 AS REAL(53))
) * CAST(0.2 AS DOUBLE)
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ WITH t0 AS (
AND t6.l_suppkey = t5.ps_suppkey
AND t6.l_shipdate >= MAKE_DATE(1994, 1, 1)
AND t6.l_shipdate < MAKE_DATE(1995, 1, 1)
) * CAST(0.5 AS REAL(53))
) * CAST(0.5 AS DOUBLE)
) AS t4
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ WITH t0 AS (
AVG(t2.c_acctbal) AS avg_bal
FROM main.customer AS t2
WHERE
t2.c_acctbal > CAST(0.0 AS REAL(53))
t2.c_acctbal > CAST(0.0 AS DOUBLE)
AND CASE
WHEN (
CAST(0 AS TINYINT) + 1 >= 1
Expand Down

0 comments on commit 424b206

Please sign in to comment.