diff --git a/ibis/backends/datafusion/compiler/values.py b/ibis/backends/datafusion/compiler/values.py index d2713816c972..a37ca10c8b63 100644 --- a/ibis/backends/datafusion/compiler/values.py +++ b/ibis/backends/datafusion/compiler/values.py @@ -694,3 +694,19 @@ def _window_boundary(op, *, value, preceding, **_): # TODO: bit of a hack to return a dict, but there's no sqlglot expression # that corresponds to _only_ this information return {"value": value, "side": "preceding" if preceding else "following"} + + +@translate_val.register(ops.SimpleCase) +@translate_val.register(ops.SearchedCase) +def _case(op, *, base=None, cases, results, default, **_): + return sg.exp.Case(this=base, ifs=list(map(if_, cases, results)), default=default) + + +@translate_val.register(ops.IfElse) +def _if_else(op, *, bool_expr, true_expr, false_null_expr, **_): + return if_(bool_expr, true_expr, false_null_expr) + + +@translate_val.register(ops.NotNull) +def _not_null(op, *, arg, **_): + return sg.not_(arg.is_(NULL)) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 5a3b621f22cc..308ce55795dc 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -1345,7 +1345,7 @@ def collect_udf(v): backend.assert_frame_equal(result, expected, check_like=True) -@pytest.mark.notimpl(["datafusion", "pyspark"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["pyspark"], raises=com.OperationNotDefinedError) def test_binds_are_cast(alltypes): expr = alltypes.aggregate( high_line_count=( @@ -1393,7 +1393,6 @@ def test_agg_name_in_output_column(alltypes): assert "max" in df.columns[1].lower() -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_grouped_case(backend, con): table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]}) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index d93e40b8c4bf..7d3a13005610 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -332,7 +332,6 @@ def test_filter_with_window_op(backend, alltypes, sorted_df): backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["datafusion"]) def test_case_where(backend, alltypes, df): table = alltypes table = table.mutate( @@ -480,7 +479,6 @@ def test_dropna_invalid(alltypes): @pytest.mark.parametrize( "subset", [None, [], "col_1", ["col_1", "col_2"], ["col_1", "col_3"]] ) -@pytest.mark.notimpl(["datafusion"]) def test_dropna_table(backend, alltypes, how, subset): is_two = alltypes.int_col == 2 is_four = alltypes.int_col == 4 @@ -684,7 +682,6 @@ def test_zeroifnull_column(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["datafusion"]) def test_ifelse_select(backend, alltypes, df): table = alltypes table = table.select( @@ -708,7 +705,6 @@ def test_ifelse_select(backend, alltypes, df): backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["datafusion"]) def test_ifelse_column(backend, alltypes, df): expr = ibis.ifelse(alltypes["int_col"] == 0, 42, -1).cast("int64").name("where_col") result = expr.execute() diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 1052582f9567..e64abde8a2c1 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -532,7 +532,7 @@ def uses_java_re(t): # pyspark doesn't support `cases` yet marks=[ pytest.mark.notimpl( - ["dask", "datafusion", "pyspark"], + ["dask", "pyspark"], raises=com.OperationNotDefinedError, ), pytest.mark.broken( @@ -866,7 +866,6 @@ def test_re_replace_global(con): assert result == "cbc" -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["mssql"], raises=sa.exc.OperationalError, @@ -1010,7 +1009,7 @@ def test_array_string_join(con): @pytest.mark.notimpl( - ["datafusion", "mssql", "mysql", "pyspark", "druid", "oracle"], + ["mssql", "mysql", "pyspark", "druid", "oracle"], raises=com.OperationNotDefinedError, ) def test_subs_with_re_replace(con): @@ -1019,7 +1018,7 @@ def test_subs_with_re_replace(con): assert result == "k" -@pytest.mark.notimpl(["pyspark", "datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["pyspark"], raises=com.OperationNotDefinedError) def test_multiple_subs(con): m = {"foo": "FOO", "bar": "BAR"} expr = ibis.literal("foo").substitute(m) @@ -1053,7 +1052,6 @@ def test_levenshtein(con, right): assert result == 3 -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.notyet( ["mssql"], reason="doesn't allow boolean expressions in select statements",