Skip to content

Commit

Permalink
feat(duckdb): implement mode aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Nov 3, 2022
1 parent 4ce9f13 commit 36fd152
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 3 deletions.
1 change: 1 addition & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def _struct_column(t, op):
),
ops.HLLCardinality: reduction(sa.func.approx_count_distinct),
ops.ApproxCountDistinct: reduction(sa.func.approx_count_distinct),
ops.Mode: reduction(sa.func.mode),
ops.Strftime: _strftime,
ops.Arbitrary: _arbitrary,
ops.GroupConcat: _string_agg,
Expand Down
43 changes: 40 additions & 3 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,25 @@ def mean_udf(s):
lambda t: t.double_col.max(),
id='max',
),
param(
# int_col % 3 so there are no ties for most common value
lambda t: (t.int_col % 3).mode(),
lambda t: (t.int_col % 3).mode().iloc[0],
id='mode',
marks=pytest.mark.notyet(
[
"clickhouse",
"dask",
"datafusion",
"impala",
"mysql",
"pandas",
"polars",
"pyspark",
"sqlite",
]
),
),
param(
lambda t: (t.double_col + 5).sum(),
lambda t: (t.double_col + 5).sum(),
Expand Down Expand Up @@ -247,6 +266,27 @@ def mean_and_std(v):
lambda t, where: t.double_col[where].max(),
id='max',
),
param(
# int_col % 3 so there are no ties for most common value
lambda t, where: (t.int_col % 3).mode(where=where),
lambda t, where: (t.int_col % 3)[where].mode().iloc[0],
id='mode',
marks=pytest.mark.notyet(
[
"clickhouse",
"dask",
"datafusion",
"impala",
"mysql",
"pandas",
"polars",
"postgres",
"pyspark",
"snowflake",
"sqlite",
]
),
),
param(
lambda t, where: t.double_col.argmin(t.int_col, where=where),
lambda t, where: t.double_col[where].iloc[t.int_col[where].argmin()],
Expand Down Expand Up @@ -387,9 +427,6 @@ def mean_and_std(v):
lambda t, where: t.count(where=where),
lambda t, where: len(t[where]),
id='count_star',
marks=[
# pytest.mark.notimpl(["polars"]),
],
),
],
)
Expand Down
6 changes: 6 additions & 0 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ class Covariance(Filterable, Reduction):
output_dtype = dt.float64


@public
class Mode(Filterable, Reduction):
arg = rlz.column(rlz.any)
output_dtype = rlz.dtype_like('arg')


@public
class Max(Filterable, Reduction):
arg = rlz.column(rlz.any)
Expand Down
4 changes: 4 additions & 0 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,10 @@ def approx_median(
"""
return ops.ApproxMedian(self, where).to_expr().name("approx_median")

def mode(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the mode of a column."""
return ops.Mode(self, where).to_expr().name("mode")

def max(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the maximum of a column."""
return ops.Max(self, where).to_expr().name("max")
Expand Down

0 comments on commit 36fd152

Please sign in to comment.