Skip to content

Commit

Permalink
feat(api): support median and quantile on more types (#7810)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Dec 20, 2023
1 parent 14b82eb commit 49c75a8
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 76 deletions.
12 changes: 10 additions & 2 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,16 @@ def _array_remove(t, op):
ops.JSONGetItem: fixed_arity(_json_get_item, 2),
ops.RowID: lambda *_: sa.literal_column("rowid"),
ops.StringToTimestamp: _strptime,
ops.Quantile: reduction(sa.func.quantile_cont),
ops.MultiQuantile: reduction(sa.func.quantile_cont),
ops.Quantile: lambda t, op: (
reduction(sa.func.quantile_cont)(t, op)
if op.arg.dtype.is_numeric()
else reduction(sa.func.quantile_disc)(t, op)
),
ops.MultiQuantile: lambda t, op: (
reduction(sa.func.quantile_cont)(t, op)
if op.arg.dtype.is_numeric()
else reduction(sa.func.quantile_disc)(t, op)
),
ops.TypeOf: unary(sa.func.typeof),
ops.IntervalAdd: fixed_arity(operator.add, 2),
ops.IntervalSubtract: fixed_arity(operator.sub, 2),
Expand Down
15 changes: 14 additions & 1 deletion ibis/backends/oracle/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def _string_join(t, op):
return sa.func.concat(*toolz.interpose(sep, values))


def _median(t, op):
    """Translate a median reduction into an Oracle SQLAlchemy expression.

    When ``op.where`` is set, the argument is wrapped in
    ``ops.IfElse(where, arg, None)`` so rows failing the filter become NULL
    before aggregation. Numeric arguments compile to ``MEDIAN``; every other
    dtype compiles to ``PERCENTILE_DISC(0.5) WITHIN GROUP (...)`` cast to the
    SQLAlchemy type of ``op.dtype``.
    """
    where = op.where
    arg = op.arg if where is None else ops.IfElse(where, op.arg, None)

    if not arg.dtype.is_numeric():
        # Non-numeric input: pick an actual value with the discrete
        # percentile and cast it back to the operation's output type.
        ordered_set = sa.func.percentile_disc(0.5).within_group(t.translate(arg))
        return sa.cast(ordered_set, t.get_sqla_type(op.dtype))

    return sa.func.median(t.translate(arg))


operation_registry.update(
{
ops.Log2: unary(lambda arg: sa.func.log(2, arg)),
Expand All @@ -96,7 +109,7 @@ def _string_join(t, op):
ops.Covariance: _cov,
ops.Correlation: _corr,
ops.ApproxMedian: reduction(sa.func.approx_median),
ops.Median: reduction(sa.func.median),
ops.Median: _median,
# Temporal
ops.ExtractSecond: _second,
# String
Expand Down
35 changes: 25 additions & 10 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def struct_column(op, **kw):
ops.Mean: "mean",
ops.Median: "median",
ops.Min: "min",
ops.Mode: "mode",
ops.StandardDev: "std",
ops.Sum: "sum",
ops.Variance: "var",
Expand All @@ -682,22 +683,36 @@ def struct_column(op, **kw):

@translate.register(reduction)
def reduction(op, **kw):
arg = translate(op.arg, **kw)
args = [
translate(arg, **kw)
for name, arg in zip(op.argnames, op.args)
if name not in ("where", "how")
]

agg = _reductions[type(op)]
filt = arg.is_not_null()

predicates = [arg.is_not_null() for arg in args]
if (where := op.where) is not None:
filt &= translate(where, **kw)
arg = arg.filter(filt)
method = getattr(arg, agg)
return method().cast(dtype_to_polars(op.dtype))
predicates.append(translate(where, **kw))

first, *rest = args
method = operator.methodcaller(agg, *rest)
return method(first.filter(reduce(operator.and_, predicates))).cast(
dtype_to_polars(op.dtype)
)

@translate.register(ops.Mode)
def mode(op, **kw):

@translate.register(ops.Quantile)
def execute_quantile(op, **kw):
arg = translate(op.arg, **kw)
quantile = translate(op.quantile, **kw)
filt = arg.is_not_null() & quantile.is_not_null()
if (where := op.where) is not None:
arg = arg.filter(translate(where, **kw))
return arg.mode().min()
filt &= translate(where, **kw)

# we can't throw quantile into the _reductions mapping because Polars'
# default interpolation of "nearest" doesn't match the rest of our backends
return arg.filter(filt).quantile(quantile, interpolation="linear")


@translate.register(ops.Correlation)
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,23 @@ def _quantile(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)
return sa.func.percentile_cont(t.translate(op.quantile)).within_group(
t.translate(arg)
)
if arg.dtype.is_numeric():
func = sa.func.percentile_cont
else:
func = sa.func.percentile_disc
return func(t.translate(op.quantile)).within_group(t.translate(arg))


def _median(t, op):
    """Translate a median reduction into a PostgreSQL ordered-set aggregate.

    When ``op.where`` is set, the argument is wrapped in
    ``ops.IfElse(where, arg, None)`` so rows failing the filter become NULL
    before aggregation. Numeric arguments use ``percentile_cont(0.5)``;
    every other dtype uses ``percentile_disc(0.5)``.
    """
    arg = op.arg
    if (where := op.where) is not None:
        arg = ops.IfElse(where, arg, None)

    # The dtype check must run before anything is returned: an unconditional
    # percentile_cont return here would make the non-numeric branch dead code.
    if arg.dtype.is_numeric():
        func = sa.func.percentile_cont
    else:
        func = sa.func.percentile_disc
    return func(0.5).within_group(t.translate(arg))


def _binary_variance_reduction(func):
Expand Down
111 changes: 104 additions & 7 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from datetime import date
from operator import methodcaller

import numpy as np
import pandas as pd
import pytest
Expand All @@ -23,11 +26,6 @@
except ImportError:
GoogleBadRequest = None

try:
from polars.exceptions import ComputeError
except ImportError:
ComputeError = None

try:
from clickhouse_connect.driver.exceptions import (
DatabaseError as ClickhouseDatabaseError,
Expand All @@ -40,12 +38,16 @@
except ImportError:
Py4JError = None


try:
from pyexasol.exceptions import ExaQueryError
except ImportError:
ExaQueryError = None

try:
from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
except ImportError:
PolarsInvalidOperationError = None


@reduction(input_type=[dt.double], output_type=dt.double)
def mean_udf(s):
Expand Down Expand Up @@ -899,7 +901,6 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond):
"impala",
"mssql",
"mysql",
"polars",
"sqlite",
"druid",
"oracle",
Expand Down Expand Up @@ -1210,6 +1211,102 @@ def test_median(alltypes, df):
assert result == expected


@pytest.mark.notimpl(
    ["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
    ["impala", "mysql", "mssql", "trino", "exasol", "flink"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
    ["clickhouse"],
    raises=ClickhouseDatabaseError,
    reason="doesn't support median of strings",
)
@pytest.mark.notyet(
    ["oracle"], raises=sa.exc.DatabaseError, reason="doesn't support median of strings"
)
@pytest.mark.broken(
    ["pyspark"], raises=AssertionError, reason="pyspark returns null for string median"
)
@pytest.mark.notimpl(["dask"], raises=(AssertionError, NotImplementedError, TypeError))
@pytest.mark.notyet(
    ["snowflake"],
    raises=sa.exc.ProgrammingError,
    reason="doesn't support median of strings",
)
@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
@pytest.mark.notimpl(
    ["pandas"], raises=TypeError, reason="results aren't correctly typed"
)
@pytest.mark.parametrize(
    "func",
    [
        param(
            methodcaller("quantile", 0.5),
            id="quantile",
            marks=[
                pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
            ],
        ),
        param(
            methodcaller("median"),
            id="median",
            marks=[
                pytest.mark.notimpl(["pyspark"], raises=com.OperationNotDefinedError)
            ],
        ),
    ],
)
def test_string_quantile(alltypes, func):
    """Check ``median`` / ``quantile(0.5)`` over a constant string column.

    The column holds the literal ``"a"`` in each of five rows, so the
    executed reduction is expected to equal ``"a"``.
    """
    constant_col = alltypes.select(col=ibis.literal("a")).limit(5).col
    assert func(constant_col).execute() == "a"


@pytest.mark.notimpl(["bigquery", "sqlite"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(
    ["impala", "mysql", "mssql", "trino", "exasol", "flink"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(["druid"], raises=AttributeError)
@pytest.mark.notyet(
    ["snowflake"],
    raises=sa.exc.ProgrammingError,
    reason="doesn't support median of dates",
)
@pytest.mark.notimpl(["dask"], raises=(AssertionError, NotImplementedError, TypeError))
@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
@pytest.mark.broken(
    ["pandas"], raises=AssertionError, reason="possibly incorrect results"
)
@pytest.mark.parametrize(
    "func",
    [
        param(
            methodcaller("quantile", 0.5),
            id="quantile",
            marks=[
                pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
            ],
        ),
        param(
            methodcaller("median"),
            id="median",
            marks=[
                pytest.mark.notimpl(["pyspark"], raises=com.OperationNotDefinedError)
            ],
        ),
    ],
)
def test_date_quantile(alltypes, func):
    """Check ``median`` / ``quantile(0.5)`` over a date column.

    The reduction is applied to ``timestamp_col`` truncated to a date and
    the executed result is compared against ``date(2009, 12, 31)``.
    """
    date_col = alltypes.timestamp_col.date()
    assert func(date_col).execute() == date(2009, 12, 31)


@pytest.mark.parametrize(
("result_fn", "expected_fn"),
[
Expand Down
31 changes: 20 additions & 11 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,29 +187,38 @@ def dtype(self):
return dt.higher_precedence(dtype, dt.float64)


@public
class Median(Filterable, Reduction):
arg: Column[dt.Numeric | dt.Boolean]
class QuantileBase(Filterable, Reduction):
    """Shared base for quantile-style reductions over a column of any dtype.

    The output dtype follows the argument: numeric arguments are promoted
    with ``dt.higher_precedence(dtype, dt.float64)``, all other dtypes are
    returned unchanged.
    """

    arg: Column

    @attribute
    def dtype(self):
        """Return the result dtype derived from ``self.arg.dtype``."""
        dtype = self.arg.dtype
        # Only promote numeric input; promoting unconditionally would bypass
        # the pass-through for non-numeric dtypes.
        if dtype.is_numeric():
            dtype = dt.higher_precedence(dtype, dt.float64)
        return dtype


@public
class Quantile(Filterable, Reduction):
arg: Value
quantile: Value[dt.Numeric]
class Median(QuantileBase):
pass

dtype = dt.float64

@public
class Quantile(QuantileBase):
quantile: Value[dt.Numeric]


@public
class MultiQuantile(Filterable, Reduction):
    """Reduction taking a column and an array of numeric quantile values.

    The result is an array whose element dtype follows the argument: numeric
    arguments are promoted with ``dt.higher_precedence(dtype, dt.float64)``,
    all other dtypes are kept as they are.
    """

    # Each field is declared exactly once; an earlier duplicate declaration
    # would be silently shadowed by the later one.
    arg: Column
    quantile: Value[dt.Array[dt.Numeric]]

    @attribute
    def dtype(self):
        """Return ``dt.Array`` of the element dtype derived from ``self.arg``."""
        dtype = self.arg.dtype
        if dtype.is_numeric():
            dtype = dt.higher_precedence(dtype, dt.float64)
        return dt.Array(dtype)


@public
Expand Down
Loading

0 comments on commit 49c75a8

Please sign in to comment.