diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 39549adfc50f..b9ac998532a2 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -12,6 +12,7 @@ import sqlglot as sg from sqlglot import exp, transforms from sqlglot.dialects import Postgres +from sqlglot.dialects.dialect import rename_func import ibis import ibis.common.exceptions as com @@ -58,6 +59,7 @@ class Generator(Postgres.Generator): transforms.eliminate_qualify, ] ), + exp.IsNan: rename_func("isnan"), } diff --git a/ibis/backends/datafusion/compiler/values.py b/ibis/backends/datafusion/compiler/values.py index 0447e36da547..d59822fc0abd 100644 --- a/ibis/backends/datafusion/compiler/values.py +++ b/ibis/backends/datafusion/compiler/values.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import math import operator from typing import Any @@ -11,7 +12,6 @@ import ibis.expr.operations as ops from ibis.backends.base.sqlglot import ( NULL, - STAR, AggGen, F, interval, @@ -202,6 +202,8 @@ def _literal(op, *, value, dtype, **kw): elif dtype.is_string() or dtype.is_macaddr(): return sg.exp.convert(str(value)) elif dtype.is_numeric(): + if isinstance(value, float) and math.isinf(value): + return sg.exp.Literal.number("'+Inf'::double") return sg.exp.convert(value) elif dtype.is_interval(): if dtype.unit.short in {"ms", "us", "ns"}: @@ -324,7 +326,7 @@ def count_distinct(op, *, arg, where, **_): @translate_val.register(ops.CountStar) def count_star(op, *, where, **_): - return agg.count(STAR, where=where) + return agg.count(1, where=where) @translate_val.register(ops.Sum) @@ -764,3 +766,13 @@ def correlation(op, *, left, right, where, **_): right = cast(right, dt.float64) return agg["corr"](left, right, where=where) + + +@translate_val.register(ops.IsNull) +def is_null(op, *, arg, **_): + return arg.is_(NULL) + + +@translate_val.register(ops.IsNan) +def is_nan(op, *, arg, **_): + return F.isnan(F.coalesce(arg, sg.exp.Literal.number("'NaN'::double"))) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 08fad24d604f..fb2e99ca2993 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -128,13 +128,13 @@ def test_scalar_fillna_nullif(con, expr, expected): param( "nan_col", _.nan_col.isnan(), - marks=pytest.mark.notimpl(["datafusion", "mysql", "sqlite"]), + marks=pytest.mark.notimpl(["mysql", "sqlite"]), id="nan_col", ), param( "none_col", _.none_col.isnull(), - marks=[pytest.mark.notimpl(["datafusion", "mysql"])], + marks=[pytest.mark.notimpl(["mysql"])], id="none_col", ), ], @@ -376,7 +376,7 @@ def test_case_where(backend, alltypes, df): # TODO: some of these are notimpl (datafusion) others are probably never -@pytest.mark.notimpl(["datafusion", "mysql", "sqlite", "mssql", "druid", "oracle"]) +@pytest.mark.notimpl(["mysql", "sqlite", "mssql", "druid", "oracle"]) @pytest.mark.notyet(["flink"], "NaN is not supported in Flink SQL", raises=ValueError) def test_select_filter_mutate(backend, alltypes, df): """Test that select, filter and mutate are executed in right order. @@ -565,7 +565,6 @@ def test_order_by_random(alltypes): raises=sa.exc.ProgrammingError, reason="Druid only supports trivial unions", ) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_table_info(alltypes): expr = alltypes.info() df = expr.execute() @@ -1372,11 +1371,6 @@ def test_try_cast_func(con, from_val, to_type, func): raises=BadRequest, reason="bigquery doesn't support OFFSET without LIMIT", ), - pytest.mark.notyet( - ["datafusion"], - raises=AssertionError, - reason="no support for offset yet", - ), pytest.mark.notyet( ["mssql"], raises=sa.exc.CompileError, @@ -1401,11 +1395,6 @@ def test_try_cast_func(con, from_val, to_type, func): lambda _: 1, id="[3:4]", marks=[ - pytest.mark.notyet( - ["datafusion"], - raises=AssertionError, - reason="no support for offset yet", - ), pytest.mark.notyet( ["mssql"], raises=sa.exc.CompileError, diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 3537a6a9080a..8fe6c8084b76 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -753,11 +753,17 @@ def test_decimal_literal(con, backend, expr, expected_types, expected_result): operator.methodcaller("isinf"), np.isinf, id="isinf", + marks=[ + pytest.mark.notimpl( + ["datafusion"], + raises=com.OperationNotDefinedError, + ) + ], ), ], ) @pytest.mark.notimpl( - ["mysql", "sqlite", "datafusion", "mssql", "oracle", "flink"], + ["mysql", "sqlite", "mssql", "oracle", "flink"], raises=com.OperationNotDefinedError, ) @pytest.mark.xfail(