Skip to content

Commit

Permalink
feat(dask): implement mode aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Nov 3, 2022
1 parent fc023b5 commit 017f07a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
45 changes: 45 additions & 0 deletions ibis/backends/dask/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import decimal
import numbers
from operator import methodcaller

import dask.array as da
import dask.dataframe as dd
Expand Down Expand Up @@ -214,6 +215,50 @@ def execute_arbitrary_series_groupby(op, data, _, aggcontext=None, **kwargs):
return aggcontext.agg(data, how)


def _mode_agg(df):
    """Return the label with the largest summed count.

    Used as the ``aggregate`` step of the Series mode reduction.

    Parameters
    ----------
    df
        Combined ``value_counts`` results.
        NOTE(review): presumably one row per partition and one column per
        distinct value, as dask's ``reduction`` concatenates Series chunk
        outputs -- confirm against the dask documentation.

    Returns
    -------
    The index label whose total count is highest after sorting the totals
    in descending order.
    """
    totals = df.sum()
    # After a descending sort the first label is the most frequent value.
    return totals.sort_values(ascending=False).index[0]


@execute_node.register(ops.Mode, dd.Series, (dd.Series, type(None)))
def execute_mode_series(_, data, mask, **kwargs):
    """Compute the mode of a dask Series via a chunked reduction.

    Parameters
    ----------
    _
        The ``ops.Mode`` node being executed (unused).
    data
        The series to aggregate.
    mask
        Optional boolean series used to filter ``data`` before counting,
        or ``None`` to use every row.
    **kwargs
        Extra execution arguments (unused).

    Returns
    -------
    The result of ``data.reduction(...)``, which yields the most frequent
    value once computed.
    """
    if mask is not None:
        data = data[mask]
    return data.reduction(
        # Count the occurrences of each value within one partition.
        chunk=methodcaller("value_counts"),
        # Add together the counts coming from several partitions.
        combine=methodcaller("sum"),
        # Pick the value with the largest total count.
        aggregate=_mode_agg,
        meta=data.dtype,
    )


def _grouped_mode_agg(gb):
    """Sum the counts of a grouped series over every level of its index.

    Used as the ``agg`` step of the grouped mode aggregation.

    Parameters
    ----------
    gb
        A groupby object whose ``obj`` attribute is the underlying series
        of counts.

    Returns
    -------
    The underlying series regrouped by all of its index level names and
    summed, so entries sharing the same full index key are merged.
    """
    counts = gb.obj
    # Regroup on the full index rather than on the groupby's own keys so
    # that counts are merged per complete index entry.
    return counts.groupby(counts.index.names).sum()


def _grouped_mode_finalize(series):
    """Select, for each group, the value with the highest count.

    Used as the ``finalize`` step of the grouped mode aggregation.

    Parameters
    ----------
    series
        Counts indexed by a multi-level index. The last index level holds
        the candidate values; the remaining levels identify the group.

    Returns
    -------
    A series indexed by the remaining (group) levels containing, for each
    group, the value from the row with the largest count.
    """
    counts = "__counts__"
    values = series.index.names[-1]
    # Move the last index level into a column so that only the group
    # levels remain in the index; the counts get an explicit column name.
    df = series.reset_index(-1, name=counts)

    def _top_row(group):
        # The first row after a descending sort has the largest count.
        return group.sort_values(counts, ascending=False).iloc[0]

    out = df.groupby(df.index.names).apply(_top_row)
    return out[values]


@execute_node.register(ops.Mode, ddgb.SeriesGroupBy, (ddgb.SeriesGroupBy, type(None)))
def execute_mode_series_group_by(_, data, mask, **kwargs):
    """Compute the per-group mode of a grouped dask Series.

    Parameters
    ----------
    _
        The ``ops.Mode`` node being executed (unused).
    data
        The grouped series to aggregate.
    mask
        Optional grouped mask applied to ``data`` before counting, or
        ``None`` to use every row.
    **kwargs
        Extra execution arguments (unused).

    Returns
    -------
    The result of ``data.agg`` with a custom ``dd.Aggregation`` named
    ``"mode"``.
    """
    if mask is not None:
        # NOTE(review): this indexes one groupby object with another;
        # confirm it filters rows as intended for grouped masks.
        data = data[mask]
    return data.agg(
        dd.Aggregation(
            name="mode",
            # Count the occurrences of each value within each group.
            chunk=methodcaller("value_counts"),
            # Merge the counts that share the same (group, value) key.
            agg=_grouped_mode_agg,
            # Keep the value with the highest count in each group.
            finalize=_grouped_mode_finalize,
        )
    )


@execute_node.register(ops.Cast, ddgb.SeriesGroupBy, dt.DataType)
def execute_cast_series_group_by(op, data, type, **kwargs):
result = execute_cast_series_generic(op, make_selected_obj(data), type, **kwargs)
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def mean_udf(s):
marks=pytest.mark.notyet(
[
"clickhouse",
"dask",
"datafusion",
"impala",
"mysql",
Expand Down Expand Up @@ -273,7 +272,6 @@ def mean_and_std(v):
marks=pytest.mark.notyet(
[
"clickhouse",
"dask",
"datafusion",
"impala",
"mysql",
Expand Down

0 comments on commit 017f07a

Please sign in to comment.