Skip to content

Commit

Permalink
feat(polars): implement quantiles
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 20, 2023
1 parent 34c5271 commit adb604f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
35 changes: 25 additions & 10 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def struct_column(op, **kw):
ops.Mean: "mean",
ops.Median: "median",
ops.Min: "min",
ops.Mode: "mode",
ops.StandardDev: "std",
ops.Sum: "sum",
ops.Variance: "var",
Expand All @@ -682,22 +683,36 @@ def struct_column(op, **kw):

@translate.register(reduction)
def reduction(op, **kw):
arg = translate(op.arg, **kw)
args = [
translate(arg, **kw)
for name, arg in zip(op.argnames, op.args)
if name not in ("where", "how")
]

agg = _reductions[type(op)]
filt = arg.is_not_null()

predicates = [arg.is_not_null() for arg in args]
if (where := op.where) is not None:
filt &= translate(where, **kw)
arg = arg.filter(filt)
method = getattr(arg, agg)
return method().cast(dtype_to_polars(op.dtype))
predicates.append(translate(where, **kw))

first, *rest = args
method = operator.methodcaller(agg, *rest)
return method(first.filter(reduce(operator.and_, predicates))).cast(
dtype_to_polars(op.dtype)
)

@translate.register(ops.Mode)
def mode(op, **kw):

@translate.register(ops.Quantile)
def execute_quantile(op, **kw):
arg = translate(op.arg, **kw)
quantile = translate(op.quantile, **kw)
filt = arg.is_not_null() & quantile.is_not_null()
if (where := op.where) is not None:
arg = arg.filter(translate(where, **kw))
return arg.mode().min()
filt &= translate(where, **kw)

# we can't throw quantile into the _reductions mapping because Polars'
# default interpolation of "nearest" doesn't match the rest of our backends
return arg.filter(filt).quantile(quantile, interpolation="linear")


@translate.register(ops.Correlation)
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,6 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond):
"impala",
"mssql",
"mysql",
"polars",
"sqlite",
"druid",
"oracle",
Expand Down

0 comments on commit adb604f

Please sign in to comment.