From 51335edfa78bcd72d6ceacd24137cef6d253bd64 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 26 Sep 2024 05:50:58 -0500 Subject: [PATCH] fix(sql): standardize NULL handling of `argmin`/`argmax` (#10227) Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> --- ibis/backends/polars/compiler.py | 19 ++++------ ibis/backends/sql/compilers/base.py | 2 -- .../sql/compilers/bigquery/__init__.py | 2 ++ ibis/backends/sql/compilers/clickhouse.py | 14 ++++++-- ibis/backends/sql/compilers/datafusion.py | 10 ++++-- ibis/backends/sql/compilers/druid.py | 2 -- ibis/backends/sql/compilers/duckdb.py | 8 +++++ ibis/backends/sql/compilers/exasol.py | 2 -- ibis/backends/sql/compilers/flink.py | 2 -- ibis/backends/sql/compilers/impala.py | 2 -- ibis/backends/sql/compilers/mssql.py | 2 -- ibis/backends/sql/compilers/mysql.py | 2 -- ibis/backends/sql/compilers/oracle.py | 2 -- ibis/backends/sql/compilers/postgres.py | 14 ++++---- ibis/backends/sql/compilers/pyspark.py | 2 ++ ibis/backends/sql/compilers/snowflake.py | 2 ++ ibis/backends/sql/compilers/sqlite.py | 7 +--- ibis/backends/sql/compilers/trino.py | 2 ++ ibis/backends/tests/test_aggregation.py | 36 +++++++++++++++++-- ibis/expr/types/generic.py | 6 ++++ 20 files changed, 89 insertions(+), 49 deletions(-) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 983256b518ce..f968ba4c272f 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -1256,20 +1256,15 @@ def execute_hash(op, **kw): def _arg_min_max(op, func, **kw): - key = op.key - arg = op.arg - - if (op_where := op.where) is not None: - key = ops.IfElse(op_where, key, None) - arg = ops.IfElse(op_where, arg, None) + key = translate(op.key, **kw) + arg = translate(op.arg, **kw) - translate_arg = translate(arg, **kw) - translate_key = translate(key, **kw) + if op.where is not None: + where = translate(op.where, **kw) + arg = arg.filter(where) + key = key.filter(where) - not_null_mask = translate_arg.is_not_null() & translate_key.is_not_null() - return translate_arg.filter(not_null_mask).get( - func(translate_key.filter(not_null_mask)) - ) + return arg.get(func(key)) @translate.register(ops.ArgMax) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 921a568d3041..a1becd5bfdab 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -306,8 +306,6 @@ class SQLGlotCompiler(abc.ABC): ops.All: "bool_and", ops.Any: "bool_or", ops.ApproxCountDistinct: "approx_distinct", - ops.ArgMax: "max_by", - ops.ArgMin: "min_by", ops.ArrayContains: "array_contains", ops.ArrayFlatten: "flatten", ops.ArrayLength: "array_size", diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index 47ee22cde24a..190fa28ba8f0 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -200,6 +200,8 @@ class BigQueryCompiler(SQLGlotCompiler): ops.TimeFromHMS: "time_from_parts", ops.TimestampNow: "current_timestamp", ops.ExtractHost: "net.host", + ops.ArgMin: "min_by", + ops.ArgMax: "max_by", } def to_sqlglot( diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 185b44052cdd..1c683b511514 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -62,8 +62,6 @@ class ClickHouseCompiler(SQLGlotCompiler): ops.ApproxCountDistinct: "uniqHLL12", ops.ApproxMedian: "median", ops.Arbitrary: "any", - ops.ArgMax: "argMax", - ops.ArgMin: "argMin", ops.ArrayContains: "has", ops.ArrayFlatten: "arrayFlatten", ops.ArrayIntersect: "arrayIntersect", @@ -673,6 +671,18 @@ def visit_Last(self, op, *, arg, where, order_by, include_null): ) return self.agg.anyLast(arg, where=where, order_by=order_by) + def visit_ArgMin(self, op, *, arg, key, where): + return sge.Dot( + this=self.agg.argMin(self.f.tuple(arg), key, where=where), + expression=sge.convert(1), + ) + + def visit_ArgMax(self, op, *, arg, key, where): + return sge.Dot( + this=self.agg.argMax(self.f.tuple(arg), key, where=where), + expression=sge.convert(1), + ) + def visit_CountDistinctStar( self, op: ops.CountDistinctStar, *, where, **_: Any ) -> str: diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index b527663a6394..8cecd30c02d1 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -30,8 +30,6 @@ class DataFusionCompiler(SQLGlotCompiler): post_rewrites = (split_select_distinct_with_order_by,) UNSUPPORTED_OPS = ( - ops.ArgMax, - ops.ArgMin, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayMap, @@ -457,6 +455,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null): where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last_value(arg, where=where, order_by=order_by) + def visit_ArgMin(self, op, *, arg, key, where): + return self.agg.first_value(arg, where=where, order_by=[sge.Ordered(this=key)]) + + def visit_ArgMax(self, op, *, arg, key, where): + return self.agg.first_value( + arg, where=where, order_by=[sge.Ordered(this=key, desc=True)] + ) + def visit_Aggregate(self, op, *, parent, groups, metrics): """Support `GROUP BY` expressions in `SELECT` since DataFusion does not.""" quoted = self.quoted diff --git a/ibis/backends/sql/compilers/druid.py b/ibis/backends/sql/compilers/druid.py index 4e2710b39992..6548265b95e3 100644 --- a/ibis/backends/sql/compilers/druid.py +++ b/ibis/backends/sql/compilers/druid.py @@ -25,8 +25,6 @@ class DruidCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.ApproxMedian, - ops.ArgMax, - ops.ArgMin, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index 30a23613e52c..acb365ba607d 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -543,6 +543,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null): where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last(arg, where=where, order_by=order_by) + def visit_ArgMin(self, op, *, arg, key, where): + return self.agg.first(arg, where=where, order_by=[sge.Ordered(this=key)]) + + def visit_ArgMax(self, op, *, arg, key, where): + return self.agg.first( + arg, where=where, order_by=[sge.Ordered(this=key, desc=True)] + ) + def visit_Quantile(self, op, *, arg, quantile, where): suffix = "cont" if op.arg.dtype.is_numeric() else "disc" funcname = f"percentile_{suffix}" diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py index 38bb88174eb3..f45a52cbffbb 100644 --- a/ibis/backends/sql/compilers/exasol.py +++ b/ibis/backends/sql/compilers/exasol.py @@ -32,8 +32,6 @@ class ExasolCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.AnalyticVectorizedUDF, - ops.ArgMax, - ops.ArgMin, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py index 9d9b0a45f69c..2cbb7163a034 100644 --- a/ibis/backends/sql/compilers/flink.py +++ b/ibis/backends/sql/compilers/flink.py @@ -69,8 +69,6 @@ class FlinkCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.AnalyticVectorizedUDF, ops.ApproxMedian, - ops.ArgMax, - ops.ArgMin, ops.ArrayFlatten, ops.ArrayStringJoin, ops.Correlation, diff --git a/ibis/backends/sql/compilers/impala.py b/ibis/backends/sql/compilers/impala.py index 06269cc4b0de..fb87f14e9cdb 100644 --- a/ibis/backends/sql/compilers/impala.py +++ b/ibis/backends/sql/compilers/impala.py @@ -30,8 +30,6 @@ class ImpalaCompiler(SQLGlotCompiler): } UNSUPPORTED_OPS = ( - ops.ArgMax, - ops.ArgMin, ops.ArrayPosition, ops.Array, ops.Covariance, diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index 900877c5117a..aaee60a20cda 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -82,8 +82,6 @@ class MSSQLCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.ApproxMedian, - ops.ArgMax, - ops.ArgMin, ops.Array, ops.ArrayDistinct, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py index 9c2172d48237..ee96b0c95b7d 100644 --- a/ibis/backends/sql/compilers/mysql.py +++ b/ibis/backends/sql/compilers/mysql.py @@ -65,8 +65,6 @@ def POS_INF(self): NEG_INF = POS_INF UNSUPPORTED_OPS = ( ops.ApproxMedian, - ops.ArgMax, - ops.ArgMin, ops.Array, ops.ArrayFlatten, ops.ArrayMap, diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index 737a7515001a..5d3daf74c6f9 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -51,8 +51,6 @@ class OracleCompiler(SQLGlotCompiler): } UNSUPPORTED_OPS = ( - ops.ArgMax, - ops.ArgMin, ops.Array, ops.ArrayFlatten, ops.ArrayMap, diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index d9b2016e60da..51fe532cb905 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -192,23 +192,21 @@ def visit_Mode(self, op, *, arg, where): expr = sge.Filter(this=expr, expression=sge.Where(this=where)) return expr - def visit_ArgMinMax(self, op, *, arg, key, where, desc: bool): - conditions = [arg.is_(sg.not_(NULL)), key.is_(sg.not_(NULL))] - - if where is not None: - conditions.append(where) + def _argminmax(self, op, *, arg, key, where, desc: bool): + cond = key.is_(sg.not_(NULL)) + where = cond if where is None else sge.And(this=cond, expression=where) agg = self.agg.array_agg( sge.Ordered(this=sge.Order(this=arg, expressions=[key]), desc=desc), - where=sg.and_(*conditions), + where=where, ) return sge.paren(agg, copy=False)[0] def visit_ArgMin(self, op, *, arg, key, where): - return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=False) + return self._argminmax(op, arg=arg, key=key, where=where, desc=False) def visit_ArgMax(self, op, *, arg, key, where): - return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=True) + return self._argminmax(op, arg=arg, key=key, where=where, desc=True) def visit_Sum(self, op, *, arg, where): arg = ( diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index b42ac174fa7a..5587a5186a76 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -70,6 +70,8 @@ class PySparkCompiler(SQLGlotCompiler): } SIMPLE_OPS = { + ops.ArgMax: "max_by", + ops.ArgMin: "min_by", ops.ArrayDistinct: "array_distinct", ops.ArrayFlatten: "flatten", ops.ArrayIntersect: "array_intersect", diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index b0e5e3913cae..c01c6b79d885 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -106,6 +106,8 @@ class SnowflakeCompiler(SQLGlotCompiler): SIMPLE_OPS = { ops.All: "min", ops.Any: "max", + ops.ArgMax: "max_by", + ops.ArgMin: "min_by", ops.ArrayDistinct: "array_distinct", ops.ArrayFlatten: "array_flatten", ops.ArrayIndex: "get", diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index 304438d81b09..aec56809bcad 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -206,12 +206,7 @@ def visit_ArgMax(self, *args, **kwargs): return self._visit_arg_reduction("max", *args, **kwargs) def _visit_arg_reduction(self, func, op, *, arg, key, where): - cond = arg.is_(sg.not_(NULL)) - - if op.where is not None: - cond = sg.and_(cond, where) - - agg = self.agg[func](key, where=cond) + agg = self.agg[func](key, where=where) return self.f.anon.json_extract(self.f.json_array(arg, agg), "$[0]") def visit_UnwrapJSONString(self, op, *, arg): diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 8cd27c623dab..a821bc43535a 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -60,6 +60,8 @@ class TrinoCompiler(SQLGlotCompiler): SIMPLE_OPS = { ops.Arbitrary: "any_value", + ops.ArgMax: "max_by", + ops.ArgMin: "min_by", ops.Pi: "pi", ops.E: "e", ops.RegexReplace: "regexp_replace", diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 62b2142d19a3..2ff92c14f361 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -123,7 +123,6 @@ def mean_udf(s): ] argidx_not_grouped_marks = [ - "datafusion", "impala", "mysql", "mssql", @@ -411,7 +410,6 @@ def mean_and_std(v): [ "impala", "mysql", - "datafusion", "mssql", "druid", "oracle", @@ -431,7 +429,6 @@ def mean_and_std(v): [ "impala", "mysql", - "datafusion", "mssql", "druid", "oracle", @@ -689,6 +686,39 @@ def test_first_last_ordered(alltypes, method, filtered, include_null): assert res == sol +@pytest.mark.notimpl( + [ + "druid", + "exasol", + "flink", + "impala", + "mssql", + "mysql", + "oracle", + ], + raises=com.OperationNotDefinedError, +) +@pytest.mark.parametrize("method", ["argmin", "argmax"]) +@pytest.mark.parametrize("filtered", [True, False], ids=["filtered", "unfiltered"]) +@pytest.mark.parametrize("null_result", [True, False], ids=["null", "non-null"]) +def test_argmin_argmax(alltypes, method, filtered, null_result): + t = alltypes.mutate(by_col=_.int_col.nullif(0).nullif(9), val_col=10 * _.int_col) + + if filtered: + where = _.int_col != (1 if method == "argmin" else 8) + sol = 20 if method == "argmin" else 70 + else: + where = None + sol = 10 if method == "argmin" else 80 + + if null_result: + t = t.mutate(val_col=_.val_col.nullif(sol)) + + expr = getattr(t.val_col, method)("by_col", where=where) + res = expr.execute() + assert pd.isna(res) if null_result else res == sol + + @pytest.mark.notimpl( [ "impala", diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 8a62ff14a88c..ac1a57f94c91 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1766,6 +1766,9 @@ def min(self, where: ir.BooleanValue | None = None) -> Scalar: def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: """Return the value of `self` that maximizes `key`. + If more than one value maximizes `key`, the returned value is backend + specific. The result may be `NULL`. + Parameters ---------- key @@ -1801,6 +1804,9 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: """Return the value of `self` that minimizes `key`. + If more than one value minimizes `key`, the returned value is backend + specific. The result may be `NULL`. + Parameters ---------- key