From 7fa987c600227823f6e6f236b1bc03f806d553ff Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Fri, 29 Sep 2023 05:24:33 -0400
Subject: [PATCH] feat(api): support order statistics on more types

---
 ibis/backends/duckdb/registry.py        |  4 +-
 ibis/backends/oracle/registry.py        | 15 +++++-
 ibis/backends/postgres/registry.py      | 14 ++++--
 ibis/backends/tests/test_aggregation.py | 64 ++++++++++++++++++++++---
 ibis/expr/operations/reductions.py      | 30 +++++++-----
 ibis/expr/types/generic.py              | 42 ++++++++++++++++
 ibis/expr/types/numeric.py              | 41 ----------------
 7 files changed, 144 insertions(+), 66 deletions(-)

diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py
index a839297950073..f6d028a60421f 100644
--- a/ibis/backends/duckdb/registry.py
+++ b/ibis/backends/duckdb/registry.py
@@ -489,8 +489,8 @@ def _array_remove(t, op):
         ops.JSONGetItem: fixed_arity(_json_get_item, 2),
         ops.RowID: lambda *_: sa.literal_column("rowid"),
         ops.StringToTimestamp: _strptime,
-        ops.Quantile: reduction(sa.func.quantile_cont),
-        ops.MultiQuantile: reduction(sa.func.quantile_cont),
+        ops.Quantile: reduction(sa.func.quantile),
+        ops.MultiQuantile: reduction(sa.func.quantile),
         ops.TypeOf: unary(sa.func.typeof),
         ops.IntervalAdd: fixed_arity(operator.add, 2),
         ops.IntervalSubtract: fixed_arity(operator.sub, 2),
diff --git a/ibis/backends/oracle/registry.py b/ibis/backends/oracle/registry.py
index a502fb6963f03..8c6b074bd21d0 100644
--- a/ibis/backends/oracle/registry.py
+++ b/ibis/backends/oracle/registry.py
@@ -81,6 +81,19 @@ def _string_join(t, op):
     return sa.func.concat(*toolz.interpose(sep, values))
 
 
+def _median(t, op):
+    arg = op.arg
+    if (where := op.where) is not None:
+        arg = ops.IfElse(where, arg, None)
+
+    if arg.dtype.is_numeric():
+        return sa.func.median(t.translate(arg))
+    return sa.cast(
+        sa.func.percentile_disc(0.5).within_group(t.translate(arg)),
+        t.get_sqla_type(op.dtype),
+    )
+
+
 operation_registry.update(
     {
         ops.Log2: unary(lambda arg: sa.func.log(2, arg)),
@@ -96,7 +109,7 @@ def _string_join(t, op):
         ops.Covariance: _cov,
         ops.Correlation: _corr,
         ops.ApproxMedian: reduction(sa.func.approx_median),
-        ops.Median: reduction(sa.func.median),
+        ops.Median: _median,
         # Temporal
         ops.ExtractSecond: _second,
         # String
diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py
index 6bd3070422b37..bbca6e9c20db9 100644
--- a/ibis/backends/postgres/registry.py
+++ b/ibis/backends/postgres/registry.py
@@ -457,9 +457,11 @@ def _quantile(t, op):
     arg = op.arg
     if (where := op.where) is not None:
         arg = ops.IfElse(where, arg, None)
-    return sa.func.percentile_cont(t.translate(op.quantile)).within_group(
-        t.translate(arg)
-    )
+    if arg.dtype.is_numeric():
+        func = sa.func.percentile_cont
+    else:
+        func = sa.func.percentile_disc
+    return func(t.translate(op.quantile)).within_group(t.translate(arg))
 
 
 def _median(t, op):
@@ -467,7 +469,11 @@ def _median(t, op):
     if (where := op.where) is not None:
         arg = ops.IfElse(where, arg, None)
 
-    return sa.func.percentile_cont(0.5).within_group(t.translate(arg))
+    if arg.dtype.is_numeric():
+        func = sa.func.percentile_cont
+    else:
+        func = sa.func.percentile_disc
+    return func(0.5).within_group(t.translate(arg))
 
 
 def _binary_variance_reduction(func):
diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py
index 942567316d890..691f4b4c7f5bc 100644
--- a/ibis/backends/tests/test_aggregation.py
+++ b/ibis/backends/tests/test_aggregation.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from datetime import date
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -23,11 +25,6 @@
 except ImportError:
     GoogleBadRequest = None
 
-try:
-    from polars.exceptions import ComputeError
-except ImportError:
-    ComputeError = None
-
 try:
     from clickhouse_connect.driver.exceptions import (
         DatabaseError as ClickhouseDatabaseError,
@@ -40,7 +37,6 @@
 except ImportError:
     Py4JError = None
 
-
 try:
     from pyexasol.exceptions import ExaQueryError
 except ImportError:
@@ -1210,6 +1206,62 @@ def test_median(alltypes, df):
     assert result == expected
 
 
+@pytest.mark.notimpl(
+    ["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
+)
+@pytest.mark.notyet(
+    ["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
+    raises=com.OperationNotDefinedError,
+)
+@pytest.mark.notyet(
+    ["clickhouse"],
+    raises=ClickhouseDatabaseError,
+    reason="doesn't support median of strings",
+)
+@pytest.mark.notyet(
+    ["oracle"],
+    raises=sa.exc.DatabaseError,
+    reason="doesn't support median of strings",
+)
+@pytest.mark.notyet(["dask"], raises=NotImplementedError)
+@pytest.mark.notyet(
+    ["snowflake"],
+    raises=sa.exc.ProgrammingError,
+    reason="doesn't support median of strings",
+)
+@pytest.mark.broken(["polars"], raises=AssertionError, reason="incorrect results")
+@pytest.mark.notimpl(
+    ["pandas"], raises=TypeError, reason="results aren't correctly typed"
+)
+def test_string_median(alltypes):
+    expr = alltypes.select(col=ibis.literal("a")).limit(5).col.median()
+    result = expr.execute()
+    assert result == "a"
+
+
+@pytest.mark.notimpl(
+    ["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
+)
+@pytest.mark.notyet(
+    ["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
+    raises=com.OperationNotDefinedError,
+)
+@pytest.mark.notyet(
+    ["snowflake"],
+    raises=sa.exc.ProgrammingError,
+    reason="doesn't support median of dates",
+)
+@pytest.mark.notyet(["dask"], raises=NotImplementedError)
+@pytest.mark.broken(["polars"], raises=AssertionError, reason="incorrect results")
+@pytest.mark.broken(
+    ["pandas"], raises=AssertionError, reason="possibly incorrect results"
+)
+def test_date_median(alltypes):
+    expr = alltypes.timestamp_col.date().median()
+    result = expr.execute()
+    assert result == date(2009, 12, 31)
+
+
 @pytest.mark.parametrize(
     ("result_fn", "expected_fn"),
     [
diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py
index 77b4435fd5861..a4fcdcc76548e 100644
--- a/ibis/expr/operations/reductions.py
+++ b/ibis/expr/operations/reductions.py
@@ -187,29 +187,35 @@ def dtype(self):
         return dt.higher_precedence(dtype, dt.float64)
 
 
-@public
-class Median(Filterable, Reduction):
-    arg: Column[dt.Numeric | dt.Boolean]
+class OrderStatistic(Filterable, Reduction):
+    arg: Column
 
     @attribute
     def dtype(self):
-        return dt.higher_precedence(self.arg.dtype, dt.float64)
+        dtype = self.arg.dtype
+        if dtype.is_numeric():
+            return dt.higher_precedence(dtype, dt.float64)
+        else:
+            return dtype
 
 
 @public
-class Quantile(Filterable, Reduction):
-    arg: Value
-    quantile: Value[dt.Numeric]
+class Median(OrderStatistic):
+    pass
 
-    dtype = dt.float64
+
+@public
+class Quantile(OrderStatistic):
+    quantile: Value[dt.Numeric]
 
 
 @public
-class MultiQuantile(Filterable, Reduction):
-    arg: Value
-    quantile: Value[dt.Array[dt.Float64]]
+class MultiQuantile(OrderStatistic):
+    quantile: Value[dt.Array[dt.Numeric]]
 
-    dtype = dt.Array(dt.float64)
+    @attribute
+    def dtype(self):
+        return dt.Array(super().dtype)
 
 
 @public
diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py
index 4959141de2028..d03ed2029c1f3 100644
--- a/ibis/expr/types/generic.py
+++ b/ibis/expr/types/generic.py
@@ -1571,6 +1571,48 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
             self, key=key, where=self._bind_reduction_filter(where)
         ).to_expr()
 
+    def median(self, where: ir.BooleanValue | None = None) -> Scalar:
+        """Return the median of the column.
+
+        Parameters
+        ----------
+        where
+            Optional boolean expression. If given, only the values where
+            `where` evaluates to true will be considered for the median.
+
+        Returns
+        -------
+        Scalar
+            Median of the column
+        """
+        return ops.Median(self, where=where).to_expr()
+
+    def quantile(
+        self,
+        quantile: float | ir.NumericValue | Sequence[ir.NumericValue | float],
+        where: ir.BooleanValue | None = None,
+    ) -> Scalar:
+        """Return value at the given quantile.
+
+        Parameters
+        ----------
+        quantile
+            `0 <= quantile <= 1`, or an array of such values
+            indicating the quantile or quantiles to compute
+        where
+            Boolean filter for input values
+
+        Returns
+        -------
+        Scalar
+            Quantile of the input
+        """
+        if isinstance(quantile, Sequence):
+            op = ops.MultiQuantile
+        else:
+            op = ops.Quantile
+        return op(self, quantile, where=where).to_expr()
+
     def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
         """Compute the number of distinct rows in an expression.
 
diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py
index 479823649d4e4..0cd7764e99e4e 100644
--- a/ibis/expr/types/numeric.py
+++ b/ibis/expr/types/numeric.py
@@ -755,47 +755,6 @@ class NumericScalar(Scalar, NumericValue):
 
 @public
 class NumericColumn(Column, NumericValue):
-    def median(self, where: ir.BooleanValue | None = None) -> NumericScalar:
-        """Return the median of the column.
-
-        Parameters
-        ----------
-        where
-            Optional boolean expression. If given, only the values where
-            `where` evaluates to true will be considered for the median.
-
-        Returns
-        -------
-        NumericScalar
-            Median of the column
-        """
-        return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr()
-
-    def quantile(
-        self,
-        quantile: Sequence[NumericValue | float],
-        where: ir.BooleanValue | None = None,
-    ) -> NumericScalar:
-        """Return value at the given quantile.
-
-        Parameters
-        ----------
-        quantile
-            `0 <= quantile <= 1`, the quantile(s) to compute
-        where
-            Boolean filter for input values
-
-        Returns
-        -------
-        NumericScalar
-            Quantile of the input
-        """
-        if isinstance(quantile, collections.abc.Sequence):
-            op = ops.MultiQuantile
-        else:
-            op = ops.Quantile
-        return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr()
-
     def std(
         self,
         where: ir.BooleanValue | None = None,
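
Reviewer note (not part of the patch): a brief, hypothetical usage sketch of the
expression API this change enables. The table and column names below are made up,
and executing the string/date cases depends on backend support as encoded in the
test markers above (e.g. DuckDB and Postgres gain support; several backends do not).

    import ibis

    # hypothetical in-memory table; "s" is a string column, "x" is numeric
    t = ibis.memtable(
        {
            "s": ["a", "a", "b", "c", "c"],
            "x": [1.0, 2.0, 3.0, 4.0, 5.0],
        }
    )

    t.x.median()                  # numeric median, unchanged behavior
    t.x.quantile(0.25)            # scalar quantile -> ops.Quantile
    t.x.quantile([0.25, 0.75])    # sequence of quantiles -> ops.MultiQuantile
    t.s.median()                  # newly allowed: median of a string column
    t.x.median(where=t.s == "a")  # optional filter, as before

    # Execution requires a backend whose translation rules were updated here,
    # e.g. DuckDB or Postgres:
    # con = ibis.duckdb.connect()
    # con.execute(t.s.median())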