From c53031c9069ea9f1e121a9ce3f139ba69eee7fdb Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 25 Apr 2023 09:47:16 -0400
Subject: [PATCH] feat(api): exact median

---
 ibis/backends/clickhouse/compiler/values.py |  1 +
 ibis/backends/duckdb/registry.py            |  1 +
 ibis/backends/polars/compiler.py            |  1 +
 ibis/backends/postgres/registry.py          |  5 +-
 ibis/backends/snowflake/registry.py         |  1 +
 ibis/backends/tests/test_aggregation.py     | 52 ++++++++++++++-------
 ibis/expr/operations/reductions.py          |  9 ++++
 ibis/expr/types/numeric.py                  | 16 +++++++
 8 files changed, 68 insertions(+), 18 deletions(-)

diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py
index 27f5404f1ef0..27b8072648e6 100644
--- a/ibis/backends/clickhouse/compiler/values.py
+++ b/ibis/backends/clickhouse/compiler/values.py
@@ -1016,6 +1016,7 @@ def formatter(op, **kw):
     ops.RandomScalar: "randCanonical",
     # Unary aggregates
     ops.ApproxMedian: "median",
+    ops.Median: "quantileExactExclusive",
     # TODO: there is also a `uniq` function which is the
     # recommended way to approximate cardinality
     ops.ApproxCountDistinct: "uniqHLL12",
diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py
index 606450528ed2..77dd98758e0f 100644
--- a/ibis/backends/duckdb/registry.py
+++ b/ibis/backends/duckdb/registry.py
@@ -418,6 +418,7 @@ def _map_merge(t, op):
         ops.MapValues: _map_values,
         ops.MapMerge: _map_merge,
         ops.Hash: unary(sa.func.hash),
+        ops.Median: reduction(sa.func.median),
     }
 )
 
diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py
index f1bf9b1a9cc7..7dada4c3a81e 100644
--- a/ibis/backends/polars/compiler.py
+++ b/ibis/backends/polars/compiler.py
@@ -642,6 +642,7 @@ def struct_column(op):
     ops.Sum: 'sum',
     ops.Variance: 'var',
     ops.CountDistinct: 'n_unique',
+    ops.Median: 'median',
 }
 
 for reduction in _reductions.keys():
diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py
index 0dcfb3a3acc5..a29b9cdff0b1 100644
--- a/ibis/backends/postgres/registry.py
+++ b/ibis/backends/postgres/registry.py
@@ -444,7 +444,7 @@ def _quantile(t, op):
     )
 
 
-def _approx_median(t, op):
+def _median(t, op):
     arg = op.arg
     if (where := op.where) is not None:
         arg = ops.Where(where, arg, None)
@@ -667,7 +667,8 @@ def _unnest(t, op):
         ops.Correlation: _corr,
         ops.BitwiseXor: _bitwise_op("#"),
         ops.Mode: _mode,
-        ops.ApproxMedian: _approx_median,
+        ops.ApproxMedian: _median,
+        ops.Median: _median,
         ops.Quantile: _quantile,
         ops.MultiQuantile: _quantile,
         ops.TimestampNow: lambda t, op: sa.literal_column(
diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py
index 8a864e76c766..cdca02514106 100644
--- a/ibis/backends/snowflake/registry.py
+++ b/ibis/backends/snowflake/registry.py
@@ -390,6 +390,7 @@ def _group_concat(t, op):
         ops.GroupConcat: _group_concat,
         ops.Hash: unary(sa.func.hash),
         ops.ApproxMedian: reduction(lambda x: sa.func.approx_percentile(x, 0.5)),
+        ops.Median: reduction(sa.func.median),
     }
 )
 
diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py
index fe9f631c7a16..10b54d251558 100644
--- a/ibis/backends/tests/test_aggregation.py
+++ b/ibis/backends/tests/test_aggregation.py
@@ -108,8 +108,8 @@ def mean_udf(s):
                     "datafusion",
                     "impala",
                     "mysql",
-                    "pyspark",
                     "mssql",
+                    "pyspark",
                     "trino",
                     "druid",
                 ],
@@ -437,20 +437,25 @@ def mean_and_std(v):
            lambda t, where: (t.int_col % 3).mode(where=where),
            lambda t, where: (t.int_col % 3)[where].mode().iloc[0],
            id='mode',
-            marks=pytest.mark.notyet(
-                [
-                    "bigquery",
-                    "clickhouse",
-                    "datafusion",
-                    "impala",
-                    "mysql",
-                    "pyspark",
-                    "mssql",
-                    "trino",
-                    "druid",
-                ],
-                raises=com.OperationNotDefinedError,
-            ),
+            marks=[
+                pytest.mark.notyet(
+                    [
+                        "bigquery",
+                        "clickhouse",
+                        "datafusion",
+                        "impala",
+                        "mysql",
+                        "pyspark",
+                        "mssql",
+                        "trino",
+                        "druid",
+                    ],
+                    raises=com.OperationNotDefinedError,
+                ),
+                pytest.mark.xfail_version(
+                    pyspark=["pyspark<3.4.0"], raises=AttributeError
+                ),
+            ],
         ),
         param(
             lambda t, where: t.double_col.argmin(t.int_col, where=where),
@@ -1031,7 +1036,7 @@ def test_corr_cov(
 @pytest.mark.broken(
     ["dask", "pandas"],
     raises=AttributeError,
-    reason="'Series' object has no attribute 'approxmedian'",
+    reason="'Series' object has no attribute 'approx_median'",
 )
 def test_approx_median(alltypes):
     expr = alltypes.double_col.approx_median()
@@ -1039,6 +1044,21 @@ def test_approx_median(alltypes):
     assert isinstance(result, float)
 
 
+@pytest.mark.notimpl(
+    ["datafusion", "druid", "sqlite"], raises=com.OperationNotDefinedError
+)
+@pytest.mark.notyet(
+    ["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
+    raises=com.OperationNotDefinedError,
+)
+@pytest.mark.notyet(["dask"], raises=NotImplementedError)
+def test_median(alltypes, df):
+    expr = alltypes.double_col.median()
+    result = expr.execute()
+    expected = df.double_col.median()
+    assert result == expected
+
+
 @mark.parametrize(
     ('result_fn', 'expected_fn'),
     [
diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py
index 426cd0f74ca9..01c9a1f77137 100644
--- a/ibis/expr/operations/reductions.py
+++ b/ibis/expr/operations/reductions.py
@@ -119,6 +119,15 @@ def output_dtype(self):
         return dt.higher_precedence(dtype, dt.float64)
 
 
+@public
+class Median(Filterable, Reduction):
+    arg = rlz.column(rlz.numeric)
+
+    @attribute.default
+    def output_dtype(self):
+        return dt.higher_precedence(self.arg.output_dtype, dt.float64)
+
+
 @public
 class Quantile(Filterable, Reduction):
     arg = rlz.any
diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py
index 7ac155acd1f1..f698e82d9bc7 100644
--- a/ibis/expr/types/numeric.py
+++ b/ibis/expr/types/numeric.py
@@ -313,6 +313,22 @@ class NumericScalar(Scalar, NumericValue):
 
 @public
 class NumericColumn(Column, NumericValue):
+    def median(self, where: ir.BooleanValue | None = None) -> NumericScalar:
+        """Return the median of the column.
+
+        Parameters
+        ----------
+        where
+            Optional boolean expression. If given, only the values where
+            `where` evaluates to true will be considered for the median.
+
+        Returns
+        -------
+        NumericScalar
+            Median of the column
+        """
+        return ops.Median(self, where=where).to_expr()
+
     def quantile(
         self,
         quantile: Sequence[NumericValue | float],
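
For reviewers who want to try the change locally, here is a minimal usage sketch. It is not part of the patch: it assumes the DuckDB backend (which this patch maps to DuckDB's median() aggregate via ops.Median), and the table and column names are made up for illustration.

# Usage sketch for the new exact-median API (hypothetical data; assumes
# ibis with the DuckDB backend installed).
import ibis

con = ibis.duckdb.connect()  # in-memory DuckDB database

# Illustrative in-memory table: a numeric column plus a boolean filter column.
t = ibis.memtable({"amount": [1.0, 2.0, 3.0, 100.0], "ok": [True, True, False, True]})

exact = t.amount.median()               # exact median, lowered to DuckDB's median()
filtered = t.amount.median(where=t.ok)  # only rows where `ok` is true contribute

print(con.execute(exact))     # 2.5
print(con.execute(filtered))  # 2.0

By contrast, t.amount.approx_median() still lowers to the approximate aggregates (for example ClickHouse's median, versus quantileExactExclusive for the new op above), so callers can trade exactness for speed explicitly.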