Skip to content

Commit

Permalink
feat(api): support median and quantile on more types
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 20, 2023
1 parent 0821bb4 commit 999f666
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 65 deletions.
12 changes: 10 additions & 2 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,16 @@ def _array_remove(t, op):
ops.JSONGetItem: fixed_arity(_json_get_item, 2),
ops.RowID: lambda *_: sa.literal_column("rowid"),
ops.StringToTimestamp: _strptime,
ops.Quantile: reduction(sa.func.quantile_cont),
ops.MultiQuantile: reduction(sa.func.quantile_cont),
ops.Quantile: lambda t, op: (
reduction(sa.func.quantile_cont)(t, op)
if op.arg.dtype.is_numeric()
else reduction(sa.func.quantile_disc)(t, op)
),
ops.MultiQuantile: lambda t, op: (
reduction(sa.func.quantile_cont)(t, op)
if op.arg.dtype.is_numeric()
else reduction(sa.func.quantile_disc)(t, op)
),
ops.TypeOf: unary(sa.func.typeof),
ops.IntervalAdd: fixed_arity(operator.add, 2),
ops.IntervalSubtract: fixed_arity(operator.sub, 2),
Expand Down
15 changes: 14 additions & 1 deletion ibis/backends/oracle/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def _string_join(t, op):
return sa.func.concat(*toolz.interpose(sep, values))


def _median(t, op):
arg = op.arg

Check warning on line 85 in ibis/backends/oracle/registry.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/oracle/registry.py#L85

Added line #L85 was not covered by tests
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)

Check warning on line 87 in ibis/backends/oracle/registry.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/oracle/registry.py#L87

Added line #L87 was not covered by tests

if arg.dtype.is_numeric():
return sa.func.median(t.translate(arg))
return sa.cast(

Check warning on line 91 in ibis/backends/oracle/registry.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/oracle/registry.py#L90-L91

Added lines #L90 - L91 were not covered by tests
sa.func.percentile_disc(0.5).within_group(t.translate(arg)),
t.get_sqla_type(op.dtype),
)


operation_registry.update(
{
ops.Log2: unary(lambda arg: sa.func.log(2, arg)),
Expand All @@ -96,7 +109,7 @@ def _string_join(t, op):
ops.Covariance: _cov,
ops.Correlation: _corr,
ops.ApproxMedian: reduction(sa.func.approx_median),
ops.Median: reduction(sa.func.median),
ops.Median: _median,
# Temporal
ops.ExtractSecond: _second,
# String
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,23 @@ def _quantile(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)
return sa.func.percentile_cont(t.translate(op.quantile)).within_group(
t.translate(arg)
)
if arg.dtype.is_numeric():
func = sa.func.percentile_cont
else:
func = sa.func.percentile_disc
return func(t.translate(op.quantile)).within_group(t.translate(arg))


def _median(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)

return sa.func.percentile_cont(0.5).within_group(t.translate(arg))
if arg.dtype.is_numeric():
func = sa.func.percentile_cont
else:
func = sa.func.percentile_disc
return func(0.5).within_group(t.translate(arg))


def _binary_variance_reduction(func):
Expand Down
78 changes: 72 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from datetime import date
from operator import methodcaller

import numpy as np
import pandas as pd
import pytest
Expand All @@ -23,11 +26,6 @@
except ImportError:
GoogleBadRequest = None

try:
from polars.exceptions import ComputeError
except ImportError:
ComputeError = None

try:
from clickhouse_connect.driver.exceptions import (
DatabaseError as ClickhouseDatabaseError,
Expand All @@ -40,12 +38,16 @@
except ImportError:
Py4JError = None


try:
from pyexasol.exceptions import ExaQueryError
except ImportError:
ExaQueryError = None

try:
from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
except ImportError:
PolarsInvalidOperationError = None


@reduction(input_type=[dt.double], output_type=dt.double)
def mean_udf(s):
Expand Down Expand Up @@ -1210,6 +1212,70 @@ def test_median(alltypes, df):
assert result == expected


@pytest.mark.notimpl(
["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
["clickhouse"],
raises=ClickhouseDatabaseError,
reason="doesn't support median of strings",
)
@pytest.mark.notyet(
["oracle"],
raises=sa.exc.DatabaseError,
reason="doesn't support median of strings",
)
@pytest.mark.notyet(["dask"], raises=NotImplementedError)
@pytest.mark.notyet(
["snowflake"],
raises=sa.exc.ProgrammingError,
reason="doesn't support median of strings",
)
@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
@pytest.mark.notimpl(
["pandas"], raises=TypeError, reason="results aren't correctly typed"
)
@pytest.mark.parametrize(
"func", [methodcaller("quantile", 0.5), methodcaller("median")]
)
def test_string_quantile(alltypes, func):
expr = func(alltypes.select(col=ibis.literal("a")).limit(5).col)
result = expr.execute()
assert result == "a"


@pytest.mark.notimpl(
["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
["snowflake"],
raises=sa.exc.ProgrammingError,
reason="doesn't support median of dates",
)
@pytest.mark.notyet(["dask"], raises=NotImplementedError)
@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
@pytest.mark.broken(
["pandas"], raises=AssertionError, reason="possibly incorrect results"
)
@pytest.mark.parametrize(
"func", [methodcaller("quantile", 0.5), methodcaller("median")]
)
def test_date_quantile(alltypes, func):
expr = func(alltypes.timestamp_col.date())
result = expr.execute()
assert result == date(2009, 12, 31)


@pytest.mark.parametrize(
("result_fn", "expected_fn"),
[
Expand Down
31 changes: 20 additions & 11 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,29 +187,38 @@ def dtype(self):
return dt.higher_precedence(dtype, dt.float64)


@public
class Median(Filterable, Reduction):
arg: Column[dt.Numeric | dt.Boolean]
class QuantileBase(Filterable, Reduction):
arg: Column

@attribute
def dtype(self):
return dt.higher_precedence(self.arg.dtype, dt.float64)
dtype = self.arg.dtype
if dtype.is_numeric():
dtype = dt.higher_precedence(dtype, dt.float64)
return dtype


@public
class Quantile(Filterable, Reduction):
arg: Value
quantile: Value[dt.Numeric]
class Median(QuantileBase):
pass

dtype = dt.float64

@public
class Quantile(QuantileBase):
quantile: Value[dt.Numeric]


@public
class MultiQuantile(Filterable, Reduction):
arg: Value
quantile: Value[dt.Array[dt.Float64]]
arg: Column
quantile: Value[dt.Array[dt.Numeric]]

dtype = dt.Array(dt.float64)
@attribute
def dtype(self):
dtype = self.arg.dtype
if dtype.is_numeric():
dtype = dt.higher_precedence(dtype, dt.float64)
return dt.Array(dtype)


@public
Expand Down
125 changes: 125 additions & 0 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,131 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
self, key=key, where=self._bind_reduction_filter(where)
).to_expr()

def median(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the median of the column.
Parameters
----------
where
Optional boolean expression. If given, only the values where
`where` evaluates to true will be considered for the median.
Returns
-------
Scalar
Median of the column
Examples
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.examples.penguins.fetch()
Compute the median of `bill_depth_mm`
>>> t.bill_depth_mm.median()
17.3
>>> t.group_by(t.species).agg(
... median_bill_depth=t.bill_depth_mm.median()
... ).order_by(ibis.desc("median_bill_depth"))
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ species ┃ median_bill_depth ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ string │ float64 │
├───────────┼───────────────────┤
│ Chinstrap │ 18.45 │
│ Adelie │ 18.40 │
│ Gentoo │ 15.00 │
└───────────┴───────────────────┘
In addition to numeric types, any orderable non-numeric types such as
strings and dates work with `median`.
>>> t.group_by(t.island).agg(
... median_species=t.species.median()
... ).order_by(ibis.desc("median_species"))
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃ island ┃ median_species ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│ string │ string │
├───────────┼────────────────┤
│ Biscoe │ Gentoo │
│ Dream │ Chinstrap │
│ Torgersen │ Adelie │
└───────────┴────────────────┘
"""
return ops.Median(self, where=where).to_expr()

def quantile(
self,
quantile: float | ir.NumericValue | Sequence[ir.NumericValue | float],
where: ir.BooleanValue | None = None,
) -> Scalar:
"""Return value at the given quantile.
The output of this method is a continuous quantile if the input is
numeric, otherwise the output is a discrete quantile.
Parameters
----------
quantile
`0 <= quantile <= 1`, or an array of such values
indicating the quantile or quantiles to compute
where
Boolean filter for input values
Returns
-------
Scalar
Quantile of the input
Examples
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.examples.penguins.fetch()
Compute the 99th percentile of `bill_depth`
>>> t.bill_depth_mm.quantile(0.99)
21.1
>>> t.group_by(t.species).agg(
... p99_bill_depth=t.bill_depth_mm.quantile(0.99)
... ).order_by(ibis.desc("p99_bill_depth"))
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃ species ┃ p99_bill_depth ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│ string │ float64 │
├───────────┼────────────────┤
│ Adelie │ 21.200 │
│ Chinstrap │ 20.733 │
│ Gentoo │ 17.256 │
└───────────┴────────────────┘
In addition to numeric types, any orderable non-numeric types such as
strings and dates work with `quantile`.
Let's compute the 99th percentile of the `species` column
>>> t.group_by(t.island).agg(
... p99_species=t.species.quantile(0.99)
... ).order_by(ibis.desc("p99_species"))
┏━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ island ┃ p99_species ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ string │ string │
├───────────┼─────────────┤
│ Biscoe │ Gentoo │
│ Dream │ Chinstrap │
│ Torgersen │ Adelie │
└───────────┴─────────────┘
"""
if isinstance(quantile, Sequence):
op = ops.MultiQuantile
else:
op = ops.Quantile
return op(self, quantile, where=where).to_expr()

def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of distinct rows in an expression.
Expand Down
Loading

0 comments on commit 999f666

Please sign in to comment.