diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py
index a83929795007..7b675acd369c 100644
--- a/ibis/backends/duckdb/registry.py
+++ b/ibis/backends/duckdb/registry.py
@@ -489,8 +489,16 @@ def _array_remove(t, op):
         ops.JSONGetItem: fixed_arity(_json_get_item, 2),
         ops.RowID: lambda *_: sa.literal_column("rowid"),
         ops.StringToTimestamp: _strptime,
-        ops.Quantile: reduction(sa.func.quantile_cont),
-        ops.MultiQuantile: reduction(sa.func.quantile_cont),
+        ops.Quantile: lambda t, op: (
+            reduction(sa.func.quantile_cont)(t, op)
+            if op.arg.dtype.is_numeric()
+            else reduction(sa.func.quantile_disc)(t, op)
+        ),
+        ops.MultiQuantile: lambda t, op: (
+            reduction(sa.func.quantile_cont)(t, op)
+            if op.arg.dtype.is_numeric()
+            else reduction(sa.func.quantile_disc)(t, op)
+        ),
         ops.TypeOf: unary(sa.func.typeof),
         ops.IntervalAdd: fixed_arity(operator.add, 2),
         ops.IntervalSubtract: fixed_arity(operator.sub, 2),
diff --git a/ibis/backends/oracle/registry.py b/ibis/backends/oracle/registry.py
index a502fb6963f0..8c6b074bd21d 100644
--- a/ibis/backends/oracle/registry.py
+++ b/ibis/backends/oracle/registry.py
@@ -81,6 +81,19 @@ def _string_join(t, op):
     return sa.func.concat(*toolz.interpose(sep, values))
 
 
+def _median(t, op):
+    arg = op.arg
+    if (where := op.where) is not None:
+        arg = ops.IfElse(where, arg, None)
+
+    if arg.dtype.is_numeric():
+        return sa.func.median(t.translate(arg))
+    return sa.cast(
+        sa.func.percentile_disc(0.5).within_group(t.translate(arg)),
+        t.get_sqla_type(op.dtype),
+    )
+
+
 operation_registry.update(
     {
         ops.Log2: unary(lambda arg: sa.func.log(2, arg)),
@@ -96,7 +109,7 @@ def _string_join(t, op):
         ops.Covariance: _cov,
         ops.Correlation: _corr,
         ops.ApproxMedian: reduction(sa.func.approx_median),
-        ops.Median: reduction(sa.func.median),
+        ops.Median: _median,
         # Temporal
         ops.ExtractSecond: _second,
         # String
diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py
index 16b3d3cc5722..3d2168e613b5 100644
--- a/ibis/backends/polars/compiler.py
+++ b/ibis/backends/polars/compiler.py
@@ -673,6 +673,7 @@ def struct_column(op, **kw):
     ops.Mean: "mean",
     ops.Median: "median",
     ops.Min: "min",
+    ops.Mode: "mode",
     ops.StandardDev: "std",
     ops.Sum: "sum",
     ops.Variance: "var",
@@ -682,22 +683,36 @@ def struct_column(op, **kw):
 
 @translate.register(reduction)
 def reduction(op, **kw):
-    arg = translate(op.arg, **kw)
+    args = [
+        translate(arg, **kw)
+        for name, arg in zip(op.argnames, op.args)
+        if name not in ("where", "how")
+    ]
+
     agg = _reductions[type(op)]
-    filt = arg.is_not_null()
+
+    predicates = [arg.is_not_null() for arg in args]
     if (where := op.where) is not None:
-        filt &= translate(where, **kw)
-    arg = arg.filter(filt)
-    method = getattr(arg, agg)
-    return method().cast(dtype_to_polars(op.dtype))
+        predicates.append(translate(where, **kw))
 
+    first, *rest = args
+    method = operator.methodcaller(agg, *rest)
+    return method(first.filter(reduce(operator.and_, predicates))).cast(
+        dtype_to_polars(op.dtype)
+    )
 
-@translate.register(ops.Mode)
-def mode(op, **kw):
+
+@translate.register(ops.Quantile)
+def execute_quantile(op, **kw):
     arg = translate(op.arg, **kw)
+    quantile = translate(op.quantile, **kw)
+    filt = arg.is_not_null() & quantile.is_not_null()
     if (where := op.where) is not None:
-        arg = arg.filter(translate(where, **kw))
-    return arg.mode().min()
+        filt &= translate(where, **kw)
+
+    # we can't throw quantile into the _reductions mapping because Polars'
+    # default interpolation of "nearest" doesn't match the rest of our backends
+    return arg.filter(filt).quantile(quantile, interpolation="linear")
 
 
 @translate.register(ops.Correlation)
diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py
index 6bd3070422b3..bbca6e9c20db 100644
--- a/ibis/backends/postgres/registry.py
+++ b/ibis/backends/postgres/registry.py
@@ -457,9 +457,11 @@ def _quantile(t, op):
     arg = op.arg
     if (where := op.where) is not None:
         arg = ops.IfElse(where, arg, None)
-    return sa.func.percentile_cont(t.translate(op.quantile)).within_group(
-        t.translate(arg)
-    )
+    if arg.dtype.is_numeric():
+        func = sa.func.percentile_cont
+    else:
+        func = sa.func.percentile_disc
+    return func(t.translate(op.quantile)).within_group(t.translate(arg))
 
 
 def _median(t, op):
@@ -467,7 +469,11 @@ def _median(t, op):
     if (where := op.where) is not None:
         arg = ops.IfElse(where, arg, None)
 
-    return sa.func.percentile_cont(0.5).within_group(t.translate(arg))
+    if arg.dtype.is_numeric():
+        func = sa.func.percentile_cont
+    else:
+        func = sa.func.percentile_disc
+    return func(0.5).within_group(t.translate(arg))
 
 
 def _binary_variance_reduction(func):
diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py
index 942567316d89..dbbd31533419 100644
--- a/ibis/backends/tests/test_aggregation.py
+++ b/ibis/backends/tests/test_aggregation.py
@@ -1,5 +1,8 @@
 from __future__ import annotations
 
+from datetime import date
+from operator import methodcaller
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -23,11 +26,6 @@
 except ImportError:
     GoogleBadRequest = None
 
-try:
-    from polars.exceptions import ComputeError
-except ImportError:
-    ComputeError = None
-
 try:
     from clickhouse_connect.driver.exceptions import (
         DatabaseError as ClickhouseDatabaseError,
@@ -40,12 +38,16 @@
 except ImportError:
     Py4JError = None
 
-
 try:
     from pyexasol.exceptions import ExaQueryError
 except ImportError:
     ExaQueryError = None
 
+try:
+    from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
+except ImportError:
+    PolarsInvalidOperationError = None
+
 
 @reduction(input_type=[dt.double], output_type=dt.double)
 def mean_udf(s):
@@ -899,7 +901,6 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond):
         "impala",
         "mssql",
         "mysql",
-        "polars",
         "sqlite",
         "druid",
         "oracle",
@@ -1210,6 +1211,102 @@ def test_median(alltypes, df):
     assert result == expected
 
 
+@pytest.mark.notimpl(
+    ["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
+)
+@pytest.mark.notyet(
+    ["impala", "mysql", "mssql", "trino", "exasol", "flink"],
+    raises=com.OperationNotDefinedError,
+)
+@pytest.mark.notyet(
+    ["clickhouse"],
+    raises=ClickhouseDatabaseError,
+    reason="doesn't support median of strings",
+)
+@pytest.mark.notyet(
+    ["oracle"], raises=sa.exc.DatabaseError, reason="doesn't support median of strings"
+)
+@pytest.mark.broken(
+    ["pyspark"], raises=AssertionError, reason="pyspark returns null for string median"
+)
+@pytest.mark.notimpl(["dask"], raises=(AssertionError, NotImplementedError, TypeError))
+@pytest.mark.notyet(
+    ["snowflake"],
+    raises=sa.exc.ProgrammingError,
+    reason="doesn't support median of strings",
+)
+@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
+@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
+@pytest.mark.notimpl(
+    ["pandas"], raises=TypeError, reason="results aren't correctly typed"
+)
+@pytest.mark.parametrize(
+    "func",
+    [
+        param(
+            methodcaller("quantile", 0.5),
+            id="quantile",
+            marks=[
+                pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
+            ],
+        ),
+        param(
+            methodcaller("median"),
+            id="median",
+            marks=[
+                pytest.mark.notimpl(["pyspark"], raises=com.OperationNotDefinedError)
+            ],
+        ),
+    ],
+)
+def test_string_quantile(alltypes, func):
+    expr = func(alltypes.select(col=ibis.literal("a")).limit(5).col)
+    result = expr.execute()
+    assert result == "a"
+
+
+@pytest.mark.notimpl(["bigquery", "sqlite"], raises=com.OperationNotDefinedError)
+@pytest.mark.notyet(
+    ["impala", "mysql", "mssql", "trino", "exasol", "flink"],
+    raises=com.OperationNotDefinedError,
+)
+@pytest.mark.broken(["druid"], raises=AttributeError)
+@pytest.mark.notyet(
+    ["snowflake"],
+    raises=sa.exc.ProgrammingError,
+    reason="doesn't support median of dates",
+)
+@pytest.mark.notimpl(["dask"], raises=(AssertionError, NotImplementedError, TypeError))
+@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
+@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
+@pytest.mark.broken(
+    ["pandas"], raises=AssertionError, reason="possibly incorrect results"
+)
+@pytest.mark.parametrize(
+    "func",
+    [
+        param(
+            methodcaller("quantile", 0.5),
+            id="quantile",
+            marks=[
+                pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
+            ],
+        ),
+        param(
+            methodcaller("median"),
+            id="median",
+            marks=[
+                pytest.mark.notimpl(["pyspark"], raises=com.OperationNotDefinedError)
+            ],
+        ),
+    ],
+)
+def test_date_quantile(alltypes, func):
+    expr = func(alltypes.timestamp_col.date())
+    result = expr.execute()
+    assert result == date(2009, 12, 31)
+
+
 @pytest.mark.parametrize(
     ("result_fn", "expected_fn"),
     [
diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py
index 77b4435fd586..e0b63041395e 100644
--- a/ibis/expr/operations/reductions.py
+++ b/ibis/expr/operations/reductions.py
@@ -187,29 +187,38 @@ def dtype(self):
         return dt.higher_precedence(dtype, dt.float64)
 
 
-@public
-class Median(Filterable, Reduction):
-    arg: Column[dt.Numeric | dt.Boolean]
+class QuantileBase(Filterable, Reduction):
+    arg: Column
 
     @attribute
     def dtype(self):
-        return dt.higher_precedence(self.arg.dtype, dt.float64)
+        dtype = self.arg.dtype
+        if dtype.is_numeric():
+            dtype = dt.higher_precedence(dtype, dt.float64)
+        return dtype
 
 
 @public
-class Quantile(Filterable, Reduction):
-    arg: Value
-    quantile: Value[dt.Numeric]
+class Median(QuantileBase):
+    pass
 
-    dtype = dt.float64
+
+@public
+class Quantile(QuantileBase):
+    quantile: Value[dt.Numeric]
 
 
 @public
 class MultiQuantile(Filterable, Reduction):
-    arg: Value
-    quantile: Value[dt.Array[dt.Float64]]
+    arg: Column
+    quantile: Value[dt.Array[dt.Numeric]]
 
-    dtype = dt.Array(dt.float64)
+    @attribute
+    def dtype(self):
+        dtype = self.arg.dtype
+        if dtype.is_numeric():
+            dtype = dt.higher_precedence(dtype, dt.float64)
+        return dt.Array(dtype)
 
 
 @public
diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py
index 4959141de202..70a6c62eaa4e 100644
--- a/ibis/expr/types/generic.py
+++ b/ibis/expr/types/generic.py
@@ -1571,6 +1571,131 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
             self, key=key, where=self._bind_reduction_filter(where)
         ).to_expr()
 
+    def median(self, where: ir.BooleanValue | None = None) -> Scalar:
+        """Return the median of the column.
+
+        Parameters
+        ----------
+        where
+            Optional boolean expression. If given, only the values where
+            `where` evaluates to true will be considered for the median.
+
+        Returns
+        -------
+        Scalar
+            Median of the column
+
+        Examples
+        --------
+        >>> import ibis
+        >>> ibis.options.interactive = True
+        >>> t = ibis.examples.penguins.fetch()
+
+        Compute the median of `bill_depth_mm`
+
+        >>> t.bill_depth_mm.median()
+        17.3
+        >>> t.group_by(t.species).agg(
+        ...     median_bill_depth=t.bill_depth_mm.median()
+        ... ).order_by(ibis.desc("median_bill_depth"))
+        ┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
+        ┃ species   ┃ median_bill_depth ┃
+        ┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
+        │ string    │ float64           │
+        ├───────────┼───────────────────┤
+        │ Chinstrap │             18.45 │
+        │ Adelie    │             18.40 │
+        │ Gentoo    │             15.00 │
+        └───────────┴───────────────────┘
+
+        In addition to numeric types, any orderable non-numeric types such as
+        strings and dates work with `median`.
+
+        >>> t.group_by(t.island).agg(
+        ...     median_species=t.species.median()
+        ... ).order_by(ibis.desc("median_species"))
+        ┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
+        ┃ island    ┃ median_species ┃
+        ┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
+        │ string    │ string         │
+        ├───────────┼────────────────┤
+        │ Biscoe    │ Gentoo         │
+        │ Dream     │ Chinstrap      │
+        │ Torgersen │ Adelie         │
+        └───────────┴────────────────┘
+        """
+        return ops.Median(self, where=where).to_expr()
+
+    def quantile(
+        self,
+        quantile: float | ir.NumericValue | Sequence[ir.NumericValue | float],
+        where: ir.BooleanValue | None = None,
+    ) -> Scalar:
+        """Return value at the given quantile.
+
+        The output of this method is a continuous quantile if the input is
+        numeric, otherwise the output is a discrete quantile.
+
+        Parameters
+        ----------
+        quantile
+            `0 <= quantile <= 1`, or an array of such values
+            indicating the quantile or quantiles to compute
+        where
+            Boolean filter for input values
+
+        Returns
+        -------
+        Scalar
+            Quantile of the input
+
+        Examples
+        --------
+        >>> import ibis
+        >>> ibis.options.interactive = True
+        >>> t = ibis.examples.penguins.fetch()
+
+        Compute the 99th percentile of `bill_depth`
+
+        >>> t.bill_depth_mm.quantile(0.99)
+        21.1
+        >>> t.group_by(t.species).agg(
+        ...     p99_bill_depth=t.bill_depth_mm.quantile(0.99)
+        ... ).order_by(ibis.desc("p99_bill_depth"))
+        ┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
+        ┃ species   ┃ p99_bill_depth ┃
+        ┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
+        │ string    │ float64        │
+        ├───────────┼────────────────┤
+        │ Adelie    │         21.200 │
+        │ Chinstrap │         20.733 │
+        │ Gentoo    │         17.256 │
+        └───────────┴────────────────┘
+
+        In addition to numeric types, any orderable non-numeric types such as
+        strings and dates work with `quantile`.
+
+        Let's compute the 99th percentile of the `species` column
+
+        >>> t.group_by(t.island).agg(
+        ...     p99_species=t.species.quantile(0.99)
+        ... ).order_by(ibis.desc("p99_species"))
+        ┏━━━━━━━━━━━┳━━━━━━━━━━━━━┓
+        ┃ island    ┃ p99_species ┃
+        ┡━━━━━━━━━━━╇━━━━━━━━━━━━━┩
+        │ string    │ string      │
+        ├───────────┼─────────────┤
+        │ Biscoe    │ Gentoo      │
+        │ Dream     │ Chinstrap   │
+        │ Torgersen │ Adelie      │
+        └───────────┴─────────────┘
+        """
+        if isinstance(quantile, Sequence):
+            op = ops.MultiQuantile
+        else:
+            op = ops.Quantile
+        return op(self, quantile, where=where).to_expr()
+
     def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
         """Compute the number of distinct rows in an expression.
 
diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py
index 479823649d4e..0cd7764e99e4 100644
--- a/ibis/expr/types/numeric.py
+++ b/ibis/expr/types/numeric.py
@@ -755,47 +755,6 @@ class NumericScalar(Scalar, NumericValue):
 
 @public
 class NumericColumn(Column, NumericValue):
-    def median(self, where: ir.BooleanValue | None = None) -> NumericScalar:
-        """Return the median of the column.
-
-        Parameters
-        ----------
-        where
-            Optional boolean expression. If given, only the values where
-            `where` evaluates to true will be considered for the median.
-
-        Returns
-        -------
-        NumericScalar
-            Median of the column
-        """
-        return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr()
-
-    def quantile(
-        self,
-        quantile: Sequence[NumericValue | float],
-        where: ir.BooleanValue | None = None,
-    ) -> NumericScalar:
-        """Return value at the given quantile.
-
-        Parameters
-        ----------
-        quantile
-            `0 <= quantile <= 1`, the quantile(s) to compute
-        where
-            Boolean filter for input values
-
-        Returns
-        -------
-        NumericScalar
-            Quantile of the input
-        """
-        if isinstance(quantile, collections.abc.Sequence):
-            op = ops.MultiQuantile
-        else:
-            op = ops.Quantile
-        return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr()
-
     def std(
         self,
         where: ir.BooleanValue | None = None,
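
Editor's note (not part of the patch): a minimal usage sketch of the behaviour this diff enables, mirroring the doctest examples added to ibis/expr/types/generic.py and the new backend tests. It assumes the bundled penguins example dataset can be fetched and that the default interactive backend supports the new discrete-quantile paths; exact results depend on the backend.

# Minimal sketch, assuming ibis with its examples extra is installed.
import ibis

ibis.options.interactive = True
t = ibis.examples.penguins.fetch()

# Numeric input: a continuous quantile (quantile_cont / percentile_cont).
t.bill_depth_mm.quantile(0.99)

# Orderable non-numeric input: a discrete quantile (quantile_disc /
# percentile_disc), now possible because median/quantile moved out of
# NumericColumn into ibis/expr/types/generic.py.
t.species.median()
t.group_by(t.island).agg(p99_species=t.species.quantile(0.99))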