Support median/quantile on non-numeric columns.

Notes (text before the first "diff --git" line is commentary, not patch data):
- ops.Median/Quantile/MultiQuantile keep the argument dtype for non-numeric
  input; numeric input is still promoted against float64.
- The duckdb, postgres and oracle translators switch to the discrete quantile
  function (quantile_disc / percentile_disc) when the argument is not numeric.
- median()/quantile() move from NumericColumn into ibis/expr/types/generic.py.
  They bind `where` through `self._bind_reduction_filter(where)`, exactly as
  the removed NumericColumn methods and the neighbouring argmin() do.

diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py
index a83929795007..7b675acd369c 100644
--- a/ibis/backends/duckdb/registry.py
+++ b/ibis/backends/duckdb/registry.py
@@ -489,8 +489,16 @@ def _array_remove(t, op):
         ops.JSONGetItem: fixed_arity(_json_get_item, 2),
         ops.RowID: lambda *_: sa.literal_column("rowid"),
         ops.StringToTimestamp: _strptime,
-        ops.Quantile: reduction(sa.func.quantile_cont),
-        ops.MultiQuantile: reduction(sa.func.quantile_cont),
+        ops.Quantile: lambda t, op: (
+            reduction(sa.func.quantile_cont)(t, op)
+            if op.arg.dtype.is_numeric()
+            else reduction(sa.func.quantile_disc)(t, op)
+        ),
+        ops.MultiQuantile: lambda t, op: (
+            reduction(sa.func.quantile_cont)(t, op)
+            if op.arg.dtype.is_numeric()
+            else reduction(sa.func.quantile_disc)(t, op)
+        ),
         ops.TypeOf: unary(sa.func.typeof),
         ops.IntervalAdd: fixed_arity(operator.add, 2),
         ops.IntervalSubtract: fixed_arity(operator.sub, 2),
diff --git a/ibis/backends/oracle/registry.py b/ibis/backends/oracle/registry.py
index a502fb6963f0..8c6b074bd21d 100644
--- a/ibis/backends/oracle/registry.py
+++ b/ibis/backends/oracle/registry.py
@@ -81,6 +81,19 @@ def _string_join(t, op):
     return sa.func.concat(*toolz.interpose(sep, values))
 
 
+def _median(t, op):
+    arg = op.arg
+    if (where := op.where) is not None:
+        arg = ops.IfElse(where, arg, None)
+
+    if arg.dtype.is_numeric():
+        return sa.func.median(t.translate(arg))
+    return sa.cast(
+        sa.func.percentile_disc(0.5).within_group(t.translate(arg)),
+        t.get_sqla_type(op.dtype),
+    )
+
+
 operation_registry.update(
     {
         ops.Log2: unary(lambda arg: sa.func.log(2, arg)),
@@ -96,7 +109,7 @@ def _string_join(t, op):
         ops.Covariance: _cov,
         ops.Correlation: _corr,
         ops.ApproxMedian: reduction(sa.func.approx_median),
-        ops.Median: reduction(sa.func.median),
+        ops.Median: _median,
         # Temporal
         ops.ExtractSecond: _second,
         # String
diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py
index 6bd3070422b3..bbca6e9c20db 100644
--- a/ibis/backends/postgres/registry.py
+++ b/ibis/backends/postgres/registry.py
@@ -457,9 +457,11 @@ def _quantile(t, op):
     arg = op.arg
     if (where := op.where) is not None:
         arg = ops.IfElse(where, arg, None)
-    return sa.func.percentile_cont(t.translate(op.quantile)).within_group(
-        t.translate(arg)
-    )
+    if arg.dtype.is_numeric():
+        func = sa.func.percentile_cont
+    else:
+        func = sa.func.percentile_disc
+    return func(t.translate(op.quantile)).within_group(t.translate(arg))
 
 
 def _median(t, op):
@@ -467,7 +469,11 @@ def _median(t, op):
     if (where := op.where) is not None:
         arg = ops.IfElse(where, arg, None)
 
-    return sa.func.percentile_cont(0.5).within_group(t.translate(arg))
+    if arg.dtype.is_numeric():
+        func = sa.func.percentile_cont
+    else:
+        func = sa.func.percentile_disc
+    return func(0.5).within_group(t.translate(arg))
 
 
 def _binary_variance_reduction(func):
diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py
index 942567316d89..120ba272332f 100644
--- a/ibis/backends/tests/test_aggregation.py
+++ b/ibis/backends/tests/test_aggregation.py
@@ -1,5 +1,8 @@
 from __future__ import annotations
 
+from datetime import date
+from operator import methodcaller
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -23,11 +26,6 @@
 except ImportError:
     GoogleBadRequest = None
 
-try:
-    from polars.exceptions import ComputeError
-except ImportError:
-    ComputeError = None
-
 try:
     from clickhouse_connect.driver.exceptions import (
         DatabaseError as ClickhouseDatabaseError,
@@ -40,12 +38,16 @@
 except ImportError:
     Py4JError = None
 
-
 try:
     from pyexasol.exceptions import ExaQueryError
 except ImportError:
     ExaQueryError = None
 
+try:
+    from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
+except ImportError:
+    PolarsInvalidOperationError = None
+
 
 @reduction(input_type=[dt.double], output_type=dt.double)
 def mean_udf(s):
@@ -1210,6 +1212,70 @@ def test_median(alltypes, df):
     assert result == expected
 
 
+@pytest.mark.notimpl(
+    ["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
+)
+@pytest.mark.notyet(
+    ["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
+    raises=com.OperationNotDefinedError,
+)
+@pytest.mark.notyet(
+    ["clickhouse"],
+    raises=ClickhouseDatabaseError,
+    reason="doesn't support median of strings",
+)
+@pytest.mark.notyet(
+    ["oracle"],
+    raises=sa.exc.DatabaseError,
+    reason="doesn't support median of strings",
+)
+@pytest.mark.notyet(["dask"], raises=NotImplementedError)
+@pytest.mark.notyet(
+    ["snowflake"],
+    raises=sa.exc.ProgrammingError,
+    reason="doesn't support median of strings",
+)
+@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
+@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
+@pytest.mark.notimpl(
+    ["pandas"], raises=TypeError, reason="results aren't correctly typed"
+)
+@pytest.mark.parametrize(
+    "func", [methodcaller("quantile", 0.5), methodcaller("median")]
+)
+def test_string_quantile(alltypes, func):
+    expr = func(alltypes.select(col=ibis.literal("a")).limit(5).col)
+    result = expr.execute()
+    assert result == "a"
+
+
+@pytest.mark.notimpl(
+    ["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
+)
+@pytest.mark.notyet(
+    ["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
+    raises=com.OperationNotDefinedError,
+)
+@pytest.mark.notyet(
+    ["snowflake"],
+    raises=sa.exc.ProgrammingError,
+    reason="doesn't support median of dates",
+)
+@pytest.mark.notyet(["dask"], raises=NotImplementedError)
+@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
+@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
+@pytest.mark.broken(
+    ["pandas"], raises=AssertionError, reason="possibly incorrect results"
+)
+@pytest.mark.parametrize(
+    "func", [methodcaller("quantile", 0.5), methodcaller("median")]
+)
+def test_date_quantile(alltypes, func):
+    expr = func(alltypes.timestamp_col.date())
+    result = expr.execute()
+    assert result == date(2009, 12, 31)
+
+
 @pytest.mark.parametrize(
     ("result_fn", "expected_fn"),
     [
diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py
index 77b4435fd586..e0b63041395e 100644
--- a/ibis/expr/operations/reductions.py
+++ b/ibis/expr/operations/reductions.py
@@ -187,29 +187,38 @@ def dtype(self):
         return dt.higher_precedence(dtype, dt.float64)
 
 
-@public
-class Median(Filterable, Reduction):
-    arg: Column[dt.Numeric | dt.Boolean]
+class QuantileBase(Filterable, Reduction):
+    arg: Column
 
     @attribute
     def dtype(self):
-        return dt.higher_precedence(self.arg.dtype, dt.float64)
+        dtype = self.arg.dtype
+        if dtype.is_numeric():
+            dtype = dt.higher_precedence(dtype, dt.float64)
+        return dtype
 
 
 @public
-class Quantile(Filterable, Reduction):
-    arg: Value
-    quantile: Value[dt.Numeric]
+class Median(QuantileBase):
+    pass
 
-    dtype = dt.float64
+
+@public
+class Quantile(QuantileBase):
+    quantile: Value[dt.Numeric]
 
 
 @public
 class MultiQuantile(Filterable, Reduction):
-    arg: Value
-    quantile: Value[dt.Array[dt.Float64]]
+    arg: Column
+    quantile: Value[dt.Array[dt.Numeric]]
 
-    dtype = dt.Array(dt.float64)
+    @attribute
+    def dtype(self):
+        dtype = self.arg.dtype
+        if dtype.is_numeric():
+            dtype = dt.higher_precedence(dtype, dt.float64)
+        return dt.Array(dtype)
 
 
 @public
diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py
index 4959141de202..d03ed2029c1f 100644
--- a/ibis/expr/types/generic.py
+++ b/ibis/expr/types/generic.py
@@ -1571,6 +1571,48 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
             self, key=key, where=self._bind_reduction_filter(where)
         ).to_expr()
 
+    def median(self, where: ir.BooleanValue | None = None) -> Scalar:
+        """Return the median of the column.
+
+        Parameters
+        ----------
+        where
+            Optional boolean expression. If given, only the values where
+            `where` evaluates to true will be considered for the median.
+
+        Returns
+        -------
+        Scalar
+            Median of the column
+        """
+        return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr()
+
+    def quantile(
+        self,
+        quantile: float | ir.NumericValue | Sequence[ir.NumericValue | float],
+        where: ir.BooleanValue | None = None,
+    ) -> Scalar:
+        """Return value at the given quantile.
+
+        Parameters
+        ----------
+        quantile
+            `0 <= quantile <= 1`, or an array of such values
+            indicating the quantile or quantiles to compute
+        where
+            Boolean filter for input values
+
+        Returns
+        -------
+        Scalar
+            Quantile of the input
+        """
+        if isinstance(quantile, Sequence):
+            op = ops.MultiQuantile
+        else:
+            op = ops.Quantile
+        return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr()
+
     def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
         """Compute the number of distinct rows in an expression.
 
diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py
index 479823649d4e..0cd7764e99e4 100644
--- a/ibis/expr/types/numeric.py
+++ b/ibis/expr/types/numeric.py
@@ -755,47 +755,6 @@ class NumericScalar(Scalar, NumericValue):
 
 @public
 class NumericColumn(Column, NumericValue):
-    def median(self, where: ir.BooleanValue | None = None) -> NumericScalar:
-        """Return the median of the column.
-
-        Parameters
-        ----------
-        where
-            Optional boolean expression. If given, only the values where
-            `where` evaluates to true will be considered for the median.
-
-        Returns
-        -------
-        NumericScalar
-            Median of the column
-        """
-        return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr()
-
-    def quantile(
-        self,
-        quantile: Sequence[NumericValue | float],
-        where: ir.BooleanValue | None = None,
-    ) -> NumericScalar:
-        """Return value at the given quantile.
-
-        Parameters
-        ----------
-        quantile
-            `0 <= quantile <= 1`, the quantile(s) to compute
-        where
-            Boolean filter for input values
-
-        Returns
-        -------
-        NumericScalar
-            Quantile of the input
-        """
-        if isinstance(quantile, collections.abc.Sequence):
-            op = ops.MultiQuantile
-        else:
-            op = ops.Quantile
-        return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr()
-
     def std(
         self,
         where: ir.BooleanValue | None = None,