Skip to content

Commit

Permalink
fix(pandas): grouped aggregation using a case statement
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed May 1, 2023
1 parent 433f7b7 commit d4ac345
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 18 deletions.
66 changes: 48 additions & 18 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,34 +1402,64 @@ def wrap_case_result(raw, expr):
return result


@execute_node.register(ops.SearchedCase, tuple, tuple, object)
def execute_searched_case(op, whens, thens, otherwise, **kwargs):
whens = [execute(arg, **kwargs) for arg in whens]
thens = [execute(arg, **kwargs) for arg in thens]
def _build_select(op, whens, thens, otherwise, func=None, **kwargs):
    """Evaluate the branches of a case operation with ``np.select``.

    Parameters
    ----------
    op
        The case operation; ``op.to_expr()`` is handed to
        ``wrap_case_result`` when the result is not grouped.
    whens
        Condition arguments, each evaluated with ``execute(arg, **kwargs)``.
    thens
        Result arguments, each evaluated with ``execute(arg, **kwargs)``.
    otherwise
        Default value for rows matching no condition; ``None`` is replaced
        by ``np.nan``.
    func
        Optional callable applied to the list of evaluated conditions
        before it is passed to ``np.select``. When ``None`` the evaluated
        conditions are used unchanged.
    **kwargs
        Forwarded unchanged to every ``execute`` call.

    Returns
    -------
    A ``pd.Series`` of the selected values regrouped by the grouper of the
    most recently evaluated grouped branch if any branch was grouped,
    otherwise whatever ``wrap_case_result`` returns.
    """
    # Most recently evaluated branch that was a grouped wrapper, if any.
    # Keeping a reference to it (instead of relying on whatever happened to
    # be executed last) avoids an AttributeError when an earlier branch is
    # grouped but the final ``then`` is a plain value.
    grouped_res = None

    def _execute_ungrouped(arg):
        nonlocal grouped_res
        res = execute(arg, **kwargs)
        # Results exposing ``obj`` are unwrapped to the underlying data
        # (presumably pandas groupby objects -- confirm against ``execute``).
        obj = getattr(res, "obj", res)
        if obj is not res:
            grouped_res = res
        return obj

    whens_ = [_execute_ungrouped(when) for when in whens]
    thens_ = [_execute_ungrouped(then) for then in thens]

    if otherwise is None:
        otherwise = np.nan

    conditions = whens_ if func is None else func(whens_)
    raw = np.select(conditions, thens_, otherwise)

    if grouped_res is not None:
        # Re-apply the grouping that was stripped from the branches above.
        grouping = get_grouping(grouped_res.grouper.groupings)
        return pd.Series(raw).groupby(grouping)
    return wrap_case_result(raw, op.to_expr())


@execute_node.register(ops.SearchedCase, tuple, tuple, object)
def execute_searched_case(op, whens, thens, otherwise, **kwargs):
    """Execute a searched case by delegating to ``_build_select``.

    The evaluated conditions are used as-is; no ``func`` is supplied.
    """
    selected = _build_select(op, whens, thens, otherwise, **kwargs)
    return selected


@execute_node.register(ops.SimpleCase, object, tuple, tuple, object)
def execute_simple_case_scalar(op, value, whens, thens, otherwise, **kwargs):
    """Execute a simple case whose tested ``value`` is a generic object.

    Each evaluated ``when`` is compared for equality against ``value``; the
    branch evaluation, default handling and (re)grouping of the result are
    delegated to ``_build_select``.
    """
    # Unwrap anything exposing ``obj`` so the comparison below runs on the
    # underlying data rather than on the wrapper.
    value = getattr(value, "obj", value)

    def _matches_value(executed_whens):
        # One vectorised comparison over all evaluated ``when`` values.
        return np.asarray(executed_whens) == value

    return _build_select(
        op,
        whens,
        thens,
        otherwise,
        func=_matches_value,
        **kwargs,
    )


@execute_node.register(ops.SimpleCase, (pd.Series, SeriesGroupBy), tuple, tuple, object)
def execute_simple_case_series(op, value, whens, thens, otherwise, **kwargs):
    """Execute a simple case whose tested ``value`` is a (grouped) series.

    Each evaluated ``when`` is compared element-wise against ``value``; the
    branch evaluation, default handling and (re)grouping of the result are
    delegated to ``_build_select``.
    """
    # A ``SeriesGroupBy`` is unwrapped to its underlying series via ``obj``;
    # a plain ``pd.Series`` has no such attribute and is used unchanged.
    # NOTE(review): the grouping of ``value`` itself is discarded here, so
    # the result is only regrouped if some ``when``/``then`` is grouped --
    # confirm that is intended.
    value = getattr(value, "obj", value)

    def _matches_value(executed_whens):
        # One boolean mask per evaluated ``when``.
        return [value == when for when in executed_whens]

    return _build_select(
        op,
        whens,
        thens,
        otherwise,
        func=_matches_value,
        **kwargs,
    )


@execute_node.register(ops.Distinct, pd.DataFrame)
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sqlglot
from pytest import mark, param

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
from ibis import _
Expand Down Expand Up @@ -1347,3 +1348,15 @@ def test_agg_name_in_output_column(alltypes):
df = query.execute()
assert "min" in df.columns[0].lower()
assert "max" in df.columns[1].lower()


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_grouped_case(backend, con):
    """Aggregate a searched case expression inside a ``group_by``."""
    t = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]})

    # Keep values below 25; every other row becomes NULL.
    below_25 = t.value < 25
    small_values = ibis.case().when(below_25, t.value).else_(ibis.null()).end()

    per_key_max = t.group_by("key").aggregate(mx=small_values.max())
    result = con.execute(per_key_max.order_by("key"))

    expected = pd.DataFrame({"key": [1, 2], "mx": [10, 20]})
    backend.assert_frame_equal(result, expected)

0 comments on commit d4ac345

Please sign in to comment.