From 386c5c08c4b33380571bedda69e40623be62a46f Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 16 Aug 2024 14:16:18 -0500 Subject: [PATCH 1/2] refactor: make approximate ops subclasses of their non-approximate variants --- .../clickhouse/tests/test_aggregations.py | 2 +- ibis/expr/operations/reductions.py | 29 +++++++------------ ibis/expr/types/generic.py | 12 ++++---- 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/ibis/backends/clickhouse/tests/test_aggregations.py b/ibis/backends/clickhouse/tests/test_aggregations.py index e7c2ce1dfe1b..6f376d263bab 100644 --- a/ibis/backends/clickhouse/tests/test_aggregations.py +++ b/ibis/backends/clickhouse/tests/test_aggregations.py @@ -62,7 +62,7 @@ def test_reduction_invalid_where(alltypes, reduction): ), ( lambda t, cond: t.int_col.approx_median(), - lambda df, cond: np.int32(df.int_col.median()), + lambda df, cond: df.int_col.median(), ), ( lambda t, cond: t.double_col.min(), diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 9161ee563564..443525b624d5 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -210,6 +210,11 @@ class Median(QuantileBase): """Compute the median of a column.""" +@public +class ApproxMedian(Median): + """Compute the approximate median of a column.""" + + @public class Quantile(QuantileBase): """Compute the quantile of a column.""" @@ -325,25 +330,6 @@ class ArgMin(Filterable, Reduction): dtype = rlz.dtype_like("arg") -@public -class ApproxCountDistinct(Filterable, Reduction): - """Approximate number of unique values.""" - - arg: Column - - # Impala 2.0 and higher returns a DOUBLE - dtype = dt.int64 - - -@public -class ApproxMedian(Filterable, Reduction): - """Compute the approximate median of a set of comparable values.""" - - arg: Column - - dtype = rlz.dtype_like("arg") - - @public class GroupConcat(Filterable, Reduction): """Concatenate strings in a group with a given separator character.""" @@ -364,6 +350,11 @@ class CountDistinct(Filterable, Reduction): dtype = dt.int64 +@public +class ApproxCountDistinct(CountDistinct): + """Approximate number of unique values.""" + + @public class ArrayCollect(Filterable, Reduction): """Collect values into an array.""" diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 7a2cfcce3aa5..580c8dce43a6 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1644,13 +1644,13 @@ def approx_median(self, where: ir.BooleanValue | None = None) -> Scalar: >>> ibis.options.interactive = True >>> t = ibis.examples.penguins.fetch() >>> t.body_mass_g.approx_median() - ┌──────┐ - │ 4030 │ - └──────┘ + ┌────────┐ + │ 4030.0 │ + └────────┘ >>> t.body_mass_g.approx_median(where=t.species == "Chinstrap") - ┌──────┐ - │ 3700 │ - └──────┘ + ┌────────┐ + │ 3700.0 │ + └────────┘ """ return ops.ApproxMedian(self, where=self._bind_to_parent_table(where)).to_expr() From 59c6b7841aafbd6a6833c6b18155272ac7e51e25 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Tue, 20 Aug 2024 16:37:22 -0500 Subject: [PATCH 2/2] feat(api): add `approx_quantiles` for computing approximate quantiles --- .../test_approx_quantiles/array/out.sql | 7 ++ .../complete-array/out.sql | 3 + .../test_approx_quantiles/scalar/out.sql | 3 + .../shuffled-array/out.sql | 7 ++ .../tricky-scalar/out.sql | 3 + .../bigquery/tests/unit/test_compiler.py | 16 ++++ ibis/backends/polars/compiler.py | 1 + .../sql/compilers/bigquery/__init__.py | 37 +++++++++ ibis/backends/sql/compilers/clickhouse.py | 22 ++++-- ibis/backends/sql/compilers/datafusion.py | 1 + ibis/backends/sql/compilers/duckdb.py | 7 ++ ibis/backends/sql/compilers/exasol.py | 2 + ibis/backends/sql/compilers/mssql.py | 8 ++ ibis/backends/sql/compilers/oracle.py | 9 +++ ibis/backends/sql/compilers/postgres.py | 2 + ibis/backends/sql/compilers/pyspark.py | 9 +++ ibis/backends/sql/compilers/risingwave.py | 15 +++- ibis/backends/sql/compilers/snowflake.py | 6 ++ ibis/backends/sql/compilers/trino.py | 7 ++ ibis/backends/tests/test_aggregation.py | 77 +++++++++++++++---- ibis/expr/operations/reductions.py | 14 ++++ ibis/expr/tests/test_reductions.py | 10 +++ ibis/expr/types/numeric.py | 56 +++++++++++++- 23 files changed, 298 insertions(+), 24 deletions(-) create mode 100644 ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/array/out.sql create mode 100644 ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/complete-array/out.sql create mode 100644 ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/scalar/out.sql create mode 100644 ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/shuffled-array/out.sql create mode 100644 ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/tricky-scalar/out.sql diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/array/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/array/out.sql new file mode 100644 index 000000000000..b56b5c578141 --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/array/out.sql @@ -0,0 +1,7 @@ +SELECT + [ + approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[1], + approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[2], + approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[3] + ] AS `qs` +FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/complete-array/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/complete-array/out.sql new file mode 100644 index 000000000000..a8dc4b0cb6fb --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/complete-array/out.sql @@ -0,0 +1,3 @@ +SELECT + approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS) AS `qs` +FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/scalar/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/scalar/out.sql new file mode 100644 index 000000000000..e8446de495c7 --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/scalar/out.sql @@ -0,0 +1,3 @@ +SELECT + approx_quantiles(`t0`.`double_col`, 2 IGNORE NULLS)[1] AS `qs` +FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/shuffled-array/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/shuffled-array/out.sql new file mode 100644 index 000000000000..3170ecb668cf --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/shuffled-array/out.sql @@ -0,0 +1,7 @@ +SELECT + [ + approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[2], + approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[1], + approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[3] + ] AS `qs` +FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/tricky-scalar/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/tricky-scalar/out.sql new file mode 100644 index 000000000000..d84b2a089c03 --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_quantiles/tricky-scalar/out.sql @@ -0,0 +1,3 @@ +SELECT + approx_quantiles(`t0`.`double_col`, 100000 IGNORE NULLS)[33333] AS `qs` +FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/test_compiler.py b/ibis/backends/bigquery/tests/unit/test_compiler.py index 1fd79714c5d9..eb28db2b840f 100644 --- a/ibis/backends/bigquery/tests/unit/test_compiler.py +++ b/ibis/backends/bigquery/tests/unit/test_compiler.py @@ -677,3 +677,19 @@ def test_time_from_hms_with_micros(snapshot): literal = ibis.literal(datetime.time(12, 34, 56)) result = ibis.to_sql(literal, dialect="bigquery") snapshot.assert_match(result, "no_micros.sql") + + +@pytest.mark.parametrize( + "quantiles", + [ + param(0.5, id="scalar"), + param(1 / 3, id="tricky-scalar"), + param([0.25, 0.5, 0.75], id="array"), + param([0.5, 0.25, 0.75], id="shuffled-array"), + param([0, 0.25, 0.5, 0.75, 1], id="complete-array"), + ], +) +def test_approx_quantiles(alltypes, quantiles, snapshot): + query = alltypes.double_col.approx_quantile(quantiles).name("qs") + result = ibis.to_sql(query, dialect="bigquery") + snapshot.assert_match(result, "out.sql") diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index b270774c7627..ac24fbacfd5e 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -785,6 +785,7 @@ def execute_mode(op, **kw): @translate.register(ops.Quantile) +@translate.register(ops.ApproxQuantile) def execute_quantile(op, **kw): arg = translate(op.arg, **kw) quantile = translate(op.quantile, **kw) diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index 003fb876d001..97ccdc1d6511 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -2,6 +2,8 @@ from __future__ import annotations +import decimal +import math import re from typing import TYPE_CHECKING, Any @@ -392,6 +394,41 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by): return sge.GroupConcat(this=arg, separator=sep) + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if not isinstance(op.quantile, ops.Literal): + raise com.UnsupportedOperationError( + "quantile must be a literal in BigQuery" + ) + + # BigQuery syntax is `APPROX_QUANTILES(col, resolution)` to return + # `resolution + 1` quantiles array. To handle this, we compute the + # resolution ourselves then restructure the output array as needed. + # To avoid excessive resolution we arbitrarily cap it at 100,000 - + # since these are approximate quantiles anyway this seems fine. + quantiles = util.promote_list(op.quantile.value) + fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles] + resolution = min(math.lcm(*(den for _, den in fracs)), 100_000) + indices = [(num * resolution) // den for num, den in fracs] + + if where is not None: + arg = self.if_(where, arg, NULL) + + if not op.arg.dtype.is_floating(): + arg = self.cast(arg, dt.float64) + + array = self.f.approx_quantiles( + arg, sge.IgnoreNulls(this=sge.convert(resolution)) + ) + if isinstance(op, ops.ApproxQuantile): + return array[indices[0]] + + if indices == list(range(resolution + 1)): + return array + else: + return sge.Array(expressions=[array[i] for i in indices]) + + visit_ApproxMultiQuantile = visit_ApproxQuantile + def visit_FloorDivide(self, op, *, left, right): return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype) diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 5ea7d7fab26d..f00518e1e078 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -188,18 +188,24 @@ def visit_CountStar(self, op, *, where, arg): return self.f.countIf(where) return sge.Count(this=STAR) - def visit_Quantile(self, op, *, arg, quantile, where): - if where is None: - return self.agg.quantile(arg, quantile, where=where) - - func = "quantile" + "s" * isinstance(op, ops.MultiQuantile) + def _visit_quantile(self, func, arg, quantile, where): return sge.ParameterizedAgg( - this=f"{func}If", + this=f"{func}If" if where is not None else func, expressions=util.promote_list(quantile), - params=[arg, where], + params=[arg, where] if where is not None else [arg], ) - visit_MultiQuantile = visit_Quantile + def visit_Quantile(self, op, *, arg, quantile, where): + return self._visit_quantile("quantile", arg, quantile, where) + + def visit_MultiQuantile(self, op, *, arg, quantile, where): + return self._visit_quantile("quantiles", arg, quantile, where) + + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + return self._visit_quantile("quantileTDigest", arg, quantile, where) + + def visit_ApproxMultiQuantile(self, op, *, arg, quantile, where): + return self._visit_quantile("quantilesTDigest", arg, quantile, where) def visit_Correlation(self, op, *, left, right, how, where): if how == "pop": diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index c3da422ff52b..2bd7d2f0dc4a 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -51,6 +51,7 @@ class DataFusionCompiler(SQLGlotCompiler): ) SIMPLE_OPS = { + ops.ApproxQuantile: "approx_percentile_cont", ops.ApproxMedian: "approx_median", ops.ArrayRemove: "array_remove_all", ops.BitAnd: "bit_and", diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index 2ff6e99d41c7..fd47a9121fe9 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -532,6 +532,13 @@ def visit_Quantile(self, op, *, arg, quantile, where): def visit_MultiQuantile(self, op, *, arg, quantile, where): return self.visit_Quantile(op, arg=arg, quantile=quantile, where=where) + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if not op.arg.dtype.is_floating(): + arg = self.cast(arg, dt.float64) + return self.agg.approx_quantile(arg, quantile, where=where) + + visit_ApproxMultiQuantile = visit_ApproxQuantile + def visit_HexDigest(self, op, *, arg, how): if how in ("md5", "sha256"): return getattr(self.f, how)(arg) diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py index 8b3bd7fb713c..38bb88174eb3 100644 --- a/ibis/backends/sql/compilers/exasol.py +++ b/ibis/backends/sql/compilers/exasol.py @@ -223,6 +223,8 @@ def visit_Quantile(self, op, *, arg, quantile, where): expression=sge.Order(expressions=[sge.Ordered(this=arg)]), ) + visit_ApproxQuantile = visit_Quantile + def visit_TimestampTruncate(self, op, *, arg, unit): short_name = unit.short unit_mapping = {"W": "IW"} diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index 82670d7b1b8e..02cb368ce19e 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -219,6 +219,14 @@ def visit_CountDistinct(self, op, *, arg, where): arg = self.if_(where, arg, NULL) return self.f.count(sge.Distinct(expressions=[arg])) + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if where is not None: + arg = self.if_(where, arg, NULL) + return sge.WithinGroup( + this=self.f.approx_percentile_cont(quantile), + expression=sge.Order(expressions=[sge.Ordered(this=arg, nulls_first=True)]), + ) + def visit_DayOfWeekIndex(self, op, *, arg): return self.f.datepart(self.v.weekday, arg) - 1 diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index 18d8a56a43c0..015c5727994f 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -308,6 +308,15 @@ def visit_Quantile(self, op, *, arg, quantile, where): ) return expr + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if where is not None: + arg = self.if_(where, arg) + + return sge.WithinGroup( + this=self.f.approx_percentile(quantile), + expression=sge.Order(expressions=[sge.Ordered(this=arg)]), + ) + def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg) diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index c3a7fe5f4565..250abeba18dd 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -238,6 +238,8 @@ def visit_Quantile(self, op, *, arg, quantile, where): return expr visit_MultiQuantile = visit_Quantile + visit_ApproxQuantile = visit_Quantile + visit_ApproxMultiQuantile = visit_Quantile def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index 7c5eed5c6885..8a1985b17b7b 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -290,6 +290,15 @@ def visit_Quantile(self, op, *, arg, quantile, where): visit_MultiQuantile = visit_Quantile + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if not op.arg.dtype.is_floating(): + arg = self.cast(arg, dt.float64) + if where is not None: + arg = self.if_(where, arg, NULL) + return self.f.approx_percentile(arg, quantile) + + visit_ApproxMultiQuantile = visit_ApproxQuantile + def visit_Correlation(self, op, *, left, right, how, where): if (left_type := op.left.dtype).is_boolean(): left = self.cast(left, dt.Int32(nullable=left_type.nullable)) diff --git a/ibis/backends/sql/compilers/risingwave.py b/ibis/backends/sql/compilers/risingwave.py index 5d051b789417..0091cb5408c6 100644 --- a/ibis/backends/sql/compilers/risingwave.py +++ b/ibis/backends/sql/compilers/risingwave.py @@ -6,7 +6,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.sql.compilers import PostgresCompiler -from ibis.backends.sql.compilers.base import ALL_OPERATIONS +from ibis.backends.sql.compilers.base import ALL_OPERATIONS, NULL from ibis.backends.sql.datatypes import RisingWaveType from ibis.backends.sql.dialects import RisingWave @@ -22,6 +22,8 @@ class RisingWaveCompiler(PostgresCompiler): ops.DateFromYMD, ops.Mode, ops.RandomUUID, + ops.MultiQuantile, + ops.ApproxMultiQuantile, *( op for op in ALL_OPERATIONS @@ -65,6 +67,17 @@ def visit_Correlation(self, op, *, left, right, how, where): op, left=left, right=right, how=how, where=where ) + def visit_Quantile(self, op, *, arg, quantile, where): + if where is not None: + arg = self.if_(where, arg, NULL) + suffix = "cont" if op.arg.dtype.is_numeric() else "disc" + return sge.WithinGroup( + this=self.f[f"percentile_{suffix}"](quantile), + expression=sge.Order(expressions=[sge.Ordered(this=arg)]), + ) + + visit_ApproxQuantile = visit_Quantile + def visit_TimestampTruncate(self, op, *, arg, unit): unit_mapping = { "Y": "year", diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index 0e4d944fddd2..ba404ab4957a 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -608,6 +608,12 @@ def visit_Quantile(self, op, *, arg, quantile, where): quantile = self.f.percentile_cont(quantile) return sge.WithinGroup(this=quantile, expression=order_by) + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if where is not None: + arg = self.if_(where, arg, NULL) + + return self.f.approx_percentile(arg, quantile) + def visit_CountStar(self, op, *, arg, where): if where is None: return super().visit_CountStar(op, arg=arg, where=where) diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 07173f417cc8..93e4b6d2cd9d 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -133,6 +133,13 @@ def visit_Correlation(self, op, *, left, right, how, where): return self.agg.corr(left, right, where=where) + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if not op.arg.dtype.is_floating(): + arg = self.cast(arg, dt.float64) + return self.agg.approx_quantile(arg, quantile, where=where) + + visit_ApproxMultiQuantile = visit_ApproxQuantile + def visit_BitXor(self, op, *, arg, where): a, b = map(sg.to_identifier, "ab") input_fn = combine_fn = sge.Lambda( diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index cdaa72d5fb6e..a43043a71168 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -857,7 +857,7 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): raises=com.UnsupportedBackendType, ), pytest.mark.notyet( - ["snowflake"], + ["snowflake", "risingwave"], reason="backend doesn't implement array of quantiles as input", raises=com.OperationNotDefinedError, ), @@ -876,11 +876,6 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): reason="backend doesn't implement approximate quantiles yet", raises=com.OperationNotDefinedError, ), - pytest.mark.notimpl( - ["risingwave"], - reason="Invalid input syntax: direct arg in `percentile_cont` must be castable to float64", - raises=PsycoPg2InternalError, - ), ], ), ], @@ -894,14 +889,7 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): lambda t: t.string_col.isin(["1", "7"]), id="is_in", marks=[ - pytest.mark.notimpl( - ["datafusion"], raises=com.OperationNotDefinedError - ), - pytest.mark.notimpl( - "risingwave", - raises=PsycoPg2InternalError, - reason="probably incorrect filter syntax but not sure", - ), + pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) ], ), ], @@ -920,6 +908,67 @@ def test_quantile( assert pytest.approx(result) == expected +@pytest.mark.parametrize( + "filtered", + [ + False, + param( + True, + marks=[ + pytest.mark.notyet( + ["datafusion"], + raises=Exception, + reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + strict=False, + ) + ], + ), + ], +) +@pytest.mark.parametrize( + "multi", + [ + False, + param( + True, + marks=[ + pytest.mark.notimpl( + ["datafusion", "oracle", "snowflake", "polars", "risingwave"], + raises=com.OperationNotDefinedError, + reason="multi-quantile not yet implemented", + ), + pytest.mark.notyet( + ["mssql", "exasol"], + raises=com.UnsupportedBackendType, + reason="array types not supported", + ), + ], + ), + ], +) +@pytest.mark.notyet( + ["druid", "flink", "impala", "mysql", "sqlite"], + raises=(com.OperationNotDefinedError, com.UnsupportedBackendType), + reason="quantiles (approximate or otherwise) not supported", +) +def test_approx_quantile(con, filtered, multi): + t = ibis.memtable({"x": [0, 25, 25, 50, 75, 75, 100, 125, 125, 150, 175, 175, 200]}) + where = t.x <= 100 if filtered else None + q = [0.25, 0.75] if multi else 0.25 + res = con.execute(t.x.approx_quantile(q, where=where)) + if multi: + assert isinstance(res, list) + assert all(pd.api.types.is_float(r) for r in res) + sol = [25, 75] if filtered else [50, 150] + else: + assert pd.api.types.is_float(res) + sol = 25 if filtered else 50 + + # Give pretty wide bounds for approximation - we're mostly testing that + # the call is valid and the filtering logic is applied properly. + assert res == pytest.approx(sol, abs=10) + + @pytest.mark.parametrize( ("result_fn", "expected_fn"), [ diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 443525b624d5..6d780e847e21 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -222,6 +222,13 @@ class Quantile(QuantileBase): quantile: Value[dt.Numeric] +@public +class ApproxQuantile(Quantile): + """Compute the approximate quantile of a column.""" + + arg: Column[dt.Numeric] + + @public class MultiQuantile(Filterable, Reduction): """Compute multiple quantiles of a column.""" @@ -237,6 +244,13 @@ def dtype(self): return dt.Array(dtype) +@public +class ApproxMultiQuantile(MultiQuantile): + """Compute multiple approximate quantiles of a column.""" + + arg: Column[dt.Numeric] + + class VarianceBase(Filterable, Reduction): """Base class for variance and standard deviation.""" diff --git a/ibis/expr/tests/test_reductions.py b/ibis/expr/tests/test_reductions.py index 701a79161508..1615a4c8f74e 100644 --- a/ibis/expr/tests/test_reductions.py +++ b/ibis/expr/tests/test_reductions.py @@ -64,6 +64,16 @@ ops.ArrayCollect, id="collect", ), + param( + lambda t, where: t.int_col.approx_quantile(0.5, where=where), + ops.ApproxQuantile, + id="approx_quantile", + ), + param( + lambda t, where: t.int_col.approx_quantile([0.25, 0.5, 0.75], where=where), + ops.ApproxMultiQuantile, + id="approx_multi_quantile", + ), ], ) @pytest.mark.parametrize( diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index aea1e900d39c..33f615eb3dfd 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from collections.abc import Sequence from typing import TYPE_CHECKING, Literal from public import public @@ -13,7 +14,7 @@ from ibis.util import deprecated if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Iterable import ibis.expr.types as ir @@ -1011,6 +1012,59 @@ def histogram( return ((self - base) / binwidth).floor().clip(-1, nbins - 1) + def approx_quantile( + self, + quantile: float | ir.NumericValue | Sequence[ir.NumericValue | float], + where: ir.BooleanValue | None = None, + ) -> NumericScalar: + """Compute one or more approximate quantiles of a column. + + ::: {.callout-note} + ## The result may or may not be exact + + Whether the result is an approximation depends on the backend. + ::: + + Parameters + ---------- + quantile + `0 <= quantile <= 1`, or an array of such values + indicating the quantile or quantiles to compute + where + Boolean filter for input values + + Returns + ------- + Scalar + Quantile of the input + + Examples + -------- + >>> import ibis + >>> ibis.options.interactive = True + >>> t = ibis.examples.penguins.fetch() + + Compute the approximate 0.50 quantile of `bill_depth_mm`. + + >>> t.bill_depth_mm.approx_quantile(0.50) + ┌────────┐ + │ 17.318 │ + └────────┘ + + Compute multiple approximate quantiles in one call - in this case the + result is an array. + + >>> t.bill_depth_mm.approx_quantile([0.25, 0.75]) + ┌────────────────────────┐ + │ [15.565625, 18.671875] │ + └────────────────────────┘ + """ + if isinstance(quantile, Sequence): + op = ops.ApproxMultiQuantile + else: + op = ops.ApproxQuantile + return op(self, quantile, where=self._bind_to_parent_table(where)).to_expr() + @public class IntegerValue(NumericValue):