Skip to content

Commit

Permalink
feat(api): support median and quantile on more types
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 20, 2023
1 parent 0821bb4 commit 3537ec0
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 65 deletions.
12 changes: 10 additions & 2 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,16 @@ def _array_remove(t, op):
ops.JSONGetItem: fixed_arity(_json_get_item, 2),
ops.RowID: lambda *_: sa.literal_column("rowid"),
ops.StringToTimestamp: _strptime,
ops.Quantile: reduction(sa.func.quantile_cont),
ops.MultiQuantile: reduction(sa.func.quantile_cont),
ops.Quantile: lambda t, op: (
reduction(sa.func.quantile_cont)(t, op)
if op.arg.dtype.is_numeric()
else reduction(sa.func.quantile_disc)(t, op)
),
ops.MultiQuantile: lambda t, op: (
reduction(sa.func.quantile_cont)(t, op)
if op.arg.dtype.is_numeric()
else reduction(sa.func.quantile_disc)(t, op)
),
ops.TypeOf: unary(sa.func.typeof),
ops.IntervalAdd: fixed_arity(operator.add, 2),
ops.IntervalSubtract: fixed_arity(operator.sub, 2),
Expand Down
15 changes: 14 additions & 1 deletion ibis/backends/oracle/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def _string_join(t, op):
return sa.func.concat(*toolz.interpose(sep, values))


def _median(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)

if arg.dtype.is_numeric():
return sa.func.median(t.translate(arg))
return sa.cast(
sa.func.percentile_disc(0.5).within_group(t.translate(arg)),
t.get_sqla_type(op.dtype),
)


operation_registry.update(
{
ops.Log2: unary(lambda arg: sa.func.log(2, arg)),
Expand All @@ -96,7 +109,7 @@ def _string_join(t, op):
ops.Covariance: _cov,
ops.Correlation: _corr,
ops.ApproxMedian: reduction(sa.func.approx_median),
ops.Median: reduction(sa.func.median),
ops.Median: _median,
# Temporal
ops.ExtractSecond: _second,
# String
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,23 @@ def _quantile(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)
return sa.func.percentile_cont(t.translate(op.quantile)).within_group(
t.translate(arg)
)
if arg.dtype.is_numeric():
func = sa.func.percentile_cont
else:
func = sa.func.percentile_disc
return func(t.translate(op.quantile)).within_group(t.translate(arg))


def _median(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)

return sa.func.percentile_cont(0.5).within_group(t.translate(arg))
if arg.dtype.is_numeric():
func = sa.func.percentile_cont
else:
func = sa.func.percentile_disc
return func(0.5).within_group(t.translate(arg))


def _binary_variance_reduction(func):
Expand Down
78 changes: 72 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from datetime import date
from operator import methodcaller

import numpy as np
import pandas as pd
import pytest
Expand All @@ -23,11 +26,6 @@
except ImportError:
GoogleBadRequest = None

try:
from polars.exceptions import ComputeError
except ImportError:
ComputeError = None

try:
from clickhouse_connect.driver.exceptions import (
DatabaseError as ClickhouseDatabaseError,
Expand All @@ -40,12 +38,16 @@
except ImportError:
Py4JError = None


try:
from pyexasol.exceptions import ExaQueryError
except ImportError:
ExaQueryError = None

try:
from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
except ImportError:
PolarsInvalidOperationError = None


@reduction(input_type=[dt.double], output_type=dt.double)
def mean_udf(s):
Expand Down Expand Up @@ -1210,6 +1212,70 @@ def test_median(alltypes, df):
assert result == expected


@pytest.mark.notimpl(
["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
["clickhouse"],
raises=ClickhouseDatabaseError,
reason="doesn't support median of strings",
)
@pytest.mark.notyet(
["oracle"],
raises=sa.exc.DatabaseError,
reason="doesn't support median of strings",
)
@pytest.mark.notyet(["dask"], raises=NotImplementedError)
@pytest.mark.notyet(
["snowflake"],
raises=sa.exc.ProgrammingError,
reason="doesn't support median of strings",
)
@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
@pytest.mark.notimpl(
["pandas"], raises=TypeError, reason="results aren't correctly typed"
)
@pytest.mark.parametrize(
"func", [methodcaller("quantile", 0.5), methodcaller("median")]
)
def test_string_quantile(alltypes, func):
expr = func(alltypes.select(col=ibis.literal("a")).limit(5).col)
result = expr.execute()
assert result == "a"


@pytest.mark.notimpl(
["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
["snowflake"],
raises=sa.exc.ProgrammingError,
reason="doesn't support median of dates",
)
@pytest.mark.notyet(["dask"], raises=NotImplementedError)
@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
@pytest.mark.broken(
["pandas"], raises=AssertionError, reason="possibly incorrect results"
)
@pytest.mark.parametrize(
"func", [methodcaller("quantile", 0.5), methodcaller("median")]
)
def test_date_quantile(alltypes, func):
expr = func(alltypes.timestamp_col.date())
result = expr.execute()
assert result == date(2009, 12, 31)


@pytest.mark.parametrize(
("result_fn", "expected_fn"),
[
Expand Down
31 changes: 20 additions & 11 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,29 +187,38 @@ def dtype(self):
return dt.higher_precedence(dtype, dt.float64)


@public
class Median(Filterable, Reduction):
arg: Column[dt.Numeric | dt.Boolean]
class QuantileBase(Filterable, Reduction):
arg: Column

@attribute
def dtype(self):
return dt.higher_precedence(self.arg.dtype, dt.float64)
dtype = self.arg.dtype
if dtype.is_numeric():
dtype = dt.higher_precedence(dtype, dt.float64)
return dtype


@public
class Quantile(Filterable, Reduction):
arg: Value
quantile: Value[dt.Numeric]
class Median(QuantileBase):
pass

dtype = dt.float64

@public
class Quantile(QuantileBase):
quantile: Value[dt.Numeric]


@public
class MultiQuantile(Filterable, Reduction):
arg: Value
quantile: Value[dt.Array[dt.Float64]]
arg: Column
quantile: Value[dt.Array[dt.Numeric]]

dtype = dt.Array(dt.float64)
@attribute
def dtype(self):
dtype = self.arg.dtype
if dtype.is_numeric():
dtype = dt.higher_precedence(dtype, dt.float64)
return dt.Array(dtype)


@public
Expand Down
42 changes: 42 additions & 0 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,48 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
self, key=key, where=self._bind_reduction_filter(where)
).to_expr()

def median(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the median of the column.
Parameters
----------
where
Optional boolean expression. If given, only the values where
`where` evaluates to true will be considered for the median.
Returns
-------
Scalar
Median of the column
"""
return ops.Median(self, where=where).to_expr()

def quantile(
self,
quantile: float | ir.NumericValue | Sequence[ir.NumericValue | float],
where: ir.BooleanValue | None = None,
) -> Scalar:
"""Return value at the given quantile.
Parameters
----------
quantile
`0 <= quantile <= 1`, or an array of such values
indicating the quantile or quantiles to compute
where
Boolean filter for input values
Returns
-------
Scalar
Quantile of the input
"""
if isinstance(quantile, Sequence):
op = ops.MultiQuantile
else:
op = ops.Quantile
return op(self, quantile, where=where).to_expr()

def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of distinct rows in an expression.
Expand Down
41 changes: 0 additions & 41 deletions ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,47 +755,6 @@ class NumericScalar(Scalar, NumericValue):

@public
class NumericColumn(Column, NumericValue):
def median(self, where: ir.BooleanValue | None = None) -> NumericScalar:
"""Return the median of the column.
Parameters
----------
where
Optional boolean expression. If given, only the values where
`where` evaluates to true will be considered for the median.
Returns
-------
NumericScalar
Median of the column
"""
return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr()

def quantile(
self,
quantile: Sequence[NumericValue | float],
where: ir.BooleanValue | None = None,
) -> NumericScalar:
"""Return value at the given quantile.
Parameters
----------
quantile
`0 <= quantile <= 1`, the quantile(s) to compute
where
Boolean filter for input values
Returns
-------
NumericScalar
Quantile of the input
"""
if isinstance(quantile, collections.abc.Sequence):
op = ops.MultiQuantile
else:
op = ops.Quantile
return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr()

def std(
self,
where: ir.BooleanValue | None = None,
Expand Down

0 comments on commit 3537ec0

Please sign in to comment.