Skip to content

Commit

Permalink
feat(api): support order statistics on more types
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 20, 2023
1 parent 0821bb4 commit 51119e3
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 67 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ def _array_remove(t, op):
ops.JSONGetItem: fixed_arity(_json_get_item, 2),
ops.RowID: lambda *_: sa.literal_column("rowid"),
ops.StringToTimestamp: _strptime,
ops.Quantile: reduction(sa.func.quantile_cont),
ops.MultiQuantile: reduction(sa.func.quantile_cont),
ops.Quantile: reduction(sa.func.quantile),
ops.MultiQuantile: reduction(sa.func.quantile),
ops.TypeOf: unary(sa.func.typeof),
ops.IntervalAdd: fixed_arity(operator.add, 2),
ops.IntervalSubtract: fixed_arity(operator.sub, 2),
Expand Down
15 changes: 14 additions & 1 deletion ibis/backends/oracle/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def _string_join(t, op):
return sa.func.concat(*toolz.interpose(sep, values))


def _median(t, op):
    """Translate a median reduction into a SQLAlchemy expression.

    If ``op.where`` is set, the argument is first wrapped in
    ``ops.IfElse(where, arg, None)`` so that filtered-out rows become null.
    Numeric arguments are passed to ``median``; every other type goes
    through ``percentile_disc(0.5)`` ordered by the argument, and the result
    is cast to the SQLAlchemy type of ``op.dtype``.
    """
    where = op.where
    arg = op.arg if where is None else ops.IfElse(where, op.arg, None)

    if not arg.dtype.is_numeric():
        discrete = sa.func.percentile_disc(0.5).within_group(t.translate(arg))
        return sa.cast(discrete, t.get_sqla_type(op.dtype))
    return sa.func.median(t.translate(arg))


operation_registry.update(
{
ops.Log2: unary(lambda arg: sa.func.log(2, arg)),
Expand All @@ -96,7 +109,7 @@ def _string_join(t, op):
ops.Covariance: _cov,
ops.Correlation: _corr,
ops.ApproxMedian: reduction(sa.func.approx_median),
ops.Median: reduction(sa.func.median),
ops.Median: _median,
# Temporal
ops.ExtractSecond: _second,
# String
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,23 @@ def _quantile(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)
return sa.func.percentile_cont(t.translate(op.quantile)).within_group(
t.translate(arg)
)
if arg.dtype.is_numeric():
func = sa.func.percentile_cont
else:
func = sa.func.percentile_disc
return func(t.translate(op.quantile)).within_group(t.translate(arg))


def _median(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.IfElse(where, arg, None)

return sa.func.percentile_cont(0.5).within_group(t.translate(arg))
if arg.dtype.is_numeric():
func = sa.func.percentile_cont
else:
func = sa.func.percentile_disc
return func(0.5).within_group(t.translate(arg))


def _binary_variance_reduction(func):
Expand Down
30 changes: 29 additions & 1 deletion ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,33 @@ def _timestamp_range(t, op):
)


def _quantile(t, op):
    """Translate a quantile reduction into a SQLAlchemy expression.

    If ``op.where`` is set, the argument is first wrapped in
    ``ops.IfElse(where, arg, None)`` so that filtered-out rows become null.
    Numeric arguments use ``percentile_cont``; every other type uses
    ``percentile_disc``. The aggregate is ordered by the argument and the
    result is cast to the SQLAlchemy type of ``op.dtype``.
    """
    where = op.where
    arg = op.arg if where is None else ops.IfElse(where, op.arg, None)

    percentile = (
        sa.func.percentile_cont
        if arg.dtype.is_numeric()
        else sa.func.percentile_disc
    )
    aggregate = percentile(t.translate(op.quantile)).within_group(t.translate(arg))
    return sa.cast(aggregate, t.get_sqla_type(op.dtype))


def _median(t, op):
    """Translate a median reduction into a SQLAlchemy expression.

    If ``op.where`` is set, the argument is first wrapped in
    ``ops.IfElse(where, arg, None)`` so that filtered-out rows become null.
    Numeric arguments are passed to ``median``; every other type goes
    through ``percentile_disc(0.5)`` ordered by the argument, and the result
    is cast to the SQLAlchemy type of ``op.dtype``.
    """
    where = op.where
    arg = op.arg if where is None else ops.IfElse(where, op.arg, None)

    if not arg.dtype.is_numeric():
        discrete = sa.func.percentile_disc(0.5).within_group(t.translate(arg))
        return sa.cast(discrete, t.get_sqla_type(op.dtype))
    return sa.func.median(t.translate(arg))


_TIMESTAMP_UNITS_TO_SCALE = {"s": 0, "ms": 3, "us": 6, "ns": 9}

_SF_POS_INF = sa.func.to_double("Inf")
Expand Down Expand Up @@ -542,7 +569,8 @@ def _timestamp_range(t, op):
ops.GroupConcat: _group_concat,
ops.Hash: unary(sa.func.hash),
ops.ApproxMedian: reduction(lambda x: sa.func.approx_percentile(x, 0.5)),
ops.Median: reduction(sa.func.median),
ops.Quantile: _quantile,
ops.Median: _median,
ops.TableColumn: _table_column,
ops.Levenshtein: fixed_arity(sa.func.editdistance, 2),
ops.TimeDelta: fixed_arity(
Expand Down
64 changes: 58 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from datetime import date

import numpy as np
import pandas as pd
import pytest
Expand All @@ -23,11 +25,6 @@
except ImportError:
GoogleBadRequest = None

try:
from polars.exceptions import ComputeError
except ImportError:
ComputeError = None

try:
from clickhouse_connect.driver.exceptions import (
DatabaseError as ClickhouseDatabaseError,
Expand All @@ -40,7 +37,6 @@
except ImportError:
Py4JError = None


try:
from pyexasol.exceptions import ExaQueryError
except ImportError:
Expand Down Expand Up @@ -1210,6 +1206,62 @@ def test_median(alltypes, df):
assert result == expected


@pytest.mark.notimpl(
    ["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
    ["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
    ["clickhouse"],
    raises=ClickhouseDatabaseError,
    reason="doesn't support median of strings",
)
@pytest.mark.notyet(
    ["oracle"],
    raises=sa.exc.DatabaseError,
    reason="doesn't support median of strings",
)
@pytest.mark.notyet(["dask"], raises=NotImplementedError)
@pytest.mark.notyet(
    ["snowflake"],
    raises=sa.exc.ProgrammingError,
    reason="doesn't support median of strings",
)
@pytest.mark.broken(["polars"], raises=AssertionError, reason="incorrect results")
@pytest.mark.notimpl(
    ["pandas"], raises=TypeError, reason="results aren't correctly typed"
)
def test_string_median(alltypes):
    """The median of a column holding only the string "a" is "a"."""
    constant = ibis.literal("a")
    five_rows = alltypes.select(col=constant).limit(5)
    result = five_rows.col.median().execute()
    assert result == "a"


@pytest.mark.notimpl(
    ["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
    ["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
    ["snowflake"],
    raises=sa.exc.ProgrammingError,
    reason="doesn't support median of dates",
)
@pytest.mark.notyet(["dask"], raises=NotImplementedError)
@pytest.mark.broken(["polars"], raises=AssertionError, reason="incorrect results")
@pytest.mark.broken(
    ["pandas"], raises=AssertionError, reason="possibly incorrect results"
)
def test_date_median(alltypes):
    """The median of the date part of ``timestamp_col`` is 2009-12-31."""
    dates = alltypes.timestamp_col.date()
    result = dates.median().execute()
    assert result == date(2009, 12, 31)


@pytest.mark.parametrize(
("result_fn", "expected_fn"),
[
Expand Down
30 changes: 18 additions & 12 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,29 +187,35 @@ def dtype(self):
return dt.higher_precedence(dtype, dt.float64)


@public
class Median(Filterable, Reduction):
arg: Column[dt.Numeric | dt.Boolean]
class OrderStatistic(Filterable, Reduction):
arg: Column

@attribute
def dtype(self):
return dt.higher_precedence(self.arg.dtype, dt.float64)
dtype = self.arg.dtype
if dtype.is_numeric():
return dt.higher_precedence(dtype, dt.float64)
else:
return dtype


@public
class Quantile(Filterable, Reduction):
arg: Value
quantile: Value[dt.Numeric]
class Median(OrderStatistic):
pass

dtype = dt.float64

@public
class Quantile(OrderStatistic):
quantile: Value[dt.Numeric]


@public
class MultiQuantile(Filterable, Reduction):
arg: Value
quantile: Value[dt.Array[dt.Float64]]
class MultiQuantile(OrderStatistic):
quantile: Value[dt.Array[dt.Numeric]]

dtype = dt.Array(dt.float64)
@attribute
def dtype(self):
return dt.Array(super().dtype)


@public
Expand Down
42 changes: 42 additions & 0 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,48 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
self, key=key, where=self._bind_reduction_filter(where)
).to_expr()

def median(self, where: ir.BooleanValue | None = None) -> Scalar:
    """Return the median of the column.

    Parameters
    ----------
    where
        Optional boolean expression. If given, only the values where
        `where` evaluates to true will be considered for the median.

    Returns
    -------
    Scalar
        Median of the column
    """
    # Pass the filter through `_bind_reduction_filter`, as the other
    # reductions on this class do (e.g. `argmin`), instead of handing the
    # raw `where` argument straight to the operation.
    return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr()

def quantile(
    self,
    quantile: float | ir.NumericValue | Sequence[ir.NumericValue | float],
    where: ir.BooleanValue | None = None,
) -> Scalar:
    """Return value at the given quantile.

    Parameters
    ----------
    quantile
        `0 <= quantile <= 1`, or an array of such values
        indicating the quantile or quantiles to compute
    where
        Boolean filter for input values

    Returns
    -------
    Scalar
        Quantile of the input
    """
    # A sequence of quantiles selects the multi-quantile operation; a single
    # value selects the scalar one.
    if isinstance(quantile, Sequence):
        op = ops.MultiQuantile
    else:
        op = ops.Quantile
    # Pass the filter through `_bind_reduction_filter`, as the other
    # reductions on this class do (e.g. `argmin`), instead of handing the
    # raw `where` argument straight to the operation.
    return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr()

def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of distinct rows in an expression.
Expand Down
41 changes: 0 additions & 41 deletions ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,47 +755,6 @@ class NumericScalar(Scalar, NumericValue):

@public
class NumericColumn(Column, NumericValue):
def median(self, where: ir.BooleanValue | None = None) -> NumericScalar:
"""Return the median of the column.
Parameters
----------
where
Optional boolean expression. If given, only the values where
`where` evaluates to true will be considered for the median.
Returns
-------
NumericScalar
Median of the column
"""
return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr()

def quantile(
self,
quantile: Sequence[NumericValue | float],
where: ir.BooleanValue | None = None,
) -> NumericScalar:
"""Return value at the given quantile.
Parameters
----------
quantile
`0 <= quantile <= 1`, the quantile(s) to compute
where
Boolean filter for input values
Returns
-------
NumericScalar
Quantile of the input
"""
if isinstance(quantile, collections.abc.Sequence):
op = ops.MultiQuantile
else:
op = ops.Quantile
return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr()

def std(
self,
where: ir.BooleanValue | None = None,
Expand Down

0 comments on commit 51119e3

Please sign in to comment.