Skip to content

Commit

Permalink
feat(datafusion): add case and if-else statements
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Oct 27, 2023
1 parent 3206dbc commit 851d560
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
16 changes: 16 additions & 0 deletions ibis/backends/datafusion/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,19 @@ def _window_boundary(op, *, value, preceding, **_):
# TODO: bit of a hack to return a dict, but there's no sqlglot expression
# that corresponds to _only_ this information
return {"value": value, "side": "preceding" if preceding else "following"}


@translate_val.register(ops.SimpleCase)
@translate_val.register(ops.SearchedCase)
def _case(op, *, base=None, cases, results, default, **_):
return sg.exp.Case(this=base, ifs=list(map(if_, cases, results)), default=default)


@translate_val.register(ops.IfElse)
def _if_else(op, *, bool_expr, true_expr, false_null_expr, **_):
return if_(bool_expr, true_expr, false_null_expr)


@translate_val.register(ops.NotNull)
def _not_null(op, *, arg, **_):
return sg.not_(arg.is_(NULL))
3 changes: 1 addition & 2 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ def collect_udf(v):
backend.assert_frame_equal(result, expected, check_like=True)


@pytest.mark.notimpl(["datafusion", "pyspark"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["pyspark"], raises=com.OperationNotDefinedError)
def test_binds_are_cast(alltypes):
expr = alltypes.aggregate(
high_line_count=(
Expand Down Expand Up @@ -1393,7 +1393,6 @@ def test_agg_name_in_output_column(alltypes):
assert "max" in df.columns[1].lower()


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_grouped_case(backend, con):
table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]})

Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ def test_filter_with_window_op(backend, alltypes, sorted_df):
backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(["datafusion"])
def test_case_where(backend, alltypes, df):
table = alltypes
table = table.mutate(
Expand Down Expand Up @@ -480,7 +479,6 @@ def test_dropna_invalid(alltypes):
@pytest.mark.parametrize(
"subset", [None, [], "col_1", ["col_1", "col_2"], ["col_1", "col_3"]]
)
@pytest.mark.notimpl(["datafusion"])
def test_dropna_table(backend, alltypes, how, subset):
is_two = alltypes.int_col == 2
is_four = alltypes.int_col == 4
Expand Down Expand Up @@ -684,7 +682,6 @@ def test_zeroifnull_column(backend, alltypes, df):
backend.assert_series_equal(result, expected)


@pytest.mark.notimpl(["datafusion"])
def test_ifelse_select(backend, alltypes, df):
table = alltypes
table = table.select(
Expand All @@ -708,7 +705,6 @@ def test_ifelse_select(backend, alltypes, df):
backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(["datafusion"])
def test_ifelse_column(backend, alltypes, df):
expr = ibis.ifelse(alltypes["int_col"] == 0, 42, -1).cast("int64").name("where_col")
result = expr.execute()
Expand Down
8 changes: 3 additions & 5 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def uses_java_re(t):
# pyspark doesn't support `cases` yet
marks=[
pytest.mark.notimpl(
["dask", "datafusion", "pyspark"],
["dask", "pyspark"],
raises=com.OperationNotDefinedError,
),
pytest.mark.broken(
Expand Down Expand Up @@ -866,7 +866,6 @@ def test_re_replace_global(con):
assert result == "cbc"


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
["mssql"],
raises=sa.exc.OperationalError,
Expand Down Expand Up @@ -1010,7 +1009,7 @@ def test_array_string_join(con):


@pytest.mark.notimpl(
["datafusion", "mssql", "mysql", "pyspark", "druid", "oracle"],
["mssql", "mysql", "pyspark", "druid", "oracle"],
raises=com.OperationNotDefinedError,
)
def test_subs_with_re_replace(con):
Expand All @@ -1019,7 +1018,7 @@ def test_subs_with_re_replace(con):
assert result == "k"


@pytest.mark.notimpl(["pyspark", "datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["pyspark"], raises=com.OperationNotDefinedError)
def test_multiple_subs(con):
m = {"foo": "FOO", "bar": "BAR"}
expr = ibis.literal("foo").substitute(m)
Expand Down Expand Up @@ -1053,7 +1052,6 @@ def test_levenshtein(con, right):
assert result == 3


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(
["mssql"],
reason="doesn't allow boolean expressions in select statements",
Expand Down

0 comments on commit 851d560

Please sign in to comment.