diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index f3c1cdaf4d12..5e7a79453701 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -319,6 +319,17 @@ def visit(cls, op: ops.StandardDev, arg, where, how): ddof = {"pop": 0, "sample": 1}[how] return cls.agg(lambda x: x.std(ddof=ddof), arg, where) + @classmethod + def visit(cls, op: ops.ArrayCollect, arg, where, order_by, ignore_null): + if order_by: + raise UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) + return cls.agg( + (lambda x: x.dropna().tolist() if ignore_null else x.tolist()), arg, where + ) + @classmethod def visit(cls, op: ops.Correlation, left, right, where, how): if where is None: diff --git a/ibis/backends/pandas/kernels.py b/ibis/backends/pandas/kernels.py index 61dc278c5562..c6feb6c3e886 100644 --- a/ibis/backends/pandas/kernels.py +++ b/ibis/backends/pandas/kernels.py @@ -291,7 +291,6 @@ def last(arg): ops.Arbitrary: first, ops.CountDistinct: lambda x: x.nunique(), ops.ApproxCountDistinct: lambda x: x.nunique(), - ops.ArrayCollect: lambda x: x.dropna().tolist(), } diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index ee7ba1f03b69..5351b6cf3a2b 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -989,7 +989,7 @@ def array_column(op, **kw): def array_collect(op, in_group_by=False, **kw): arg = translate(op.arg, **kw) - predicate = arg.is_not_null() + predicate = arg.is_not_null() if op.ignore_null else True if op.where is not None: predicate &= translate(op.where, **kw) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 6e201b52fa71..418268b8b98c 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -313,7 +313,6 @@ class SQLGlotCompiler(abc.ABC): ops.ApproxCountDistinct: "approx_distinct", ops.ArgMax: "max_by", ops.ArgMin: "min_by", - ops.ArrayCollect: "array_agg", ops.ArrayContains: "array_contains", ops.ArrayFlatten: "flatten", ops.ArrayLength: "array_size", diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index e0c225d45323..b5818e35468c 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -440,10 +440,11 @@ def visit_StringToTimestamp(self, op, *, arg, format_str): return self.f.parse_timestamp(format_str, arg, timezone) return self.f.parse_datetime(format_str, arg) - def visit_ArrayCollect(self, op, *, arg, where, order_by): - return sge.IgnoreNulls( - this=self.agg.array_agg(arg, where=where, order_by=order_by) - ) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + out = self.agg.array_agg(arg, where=where, order_by=order_by) + if ignore_null: + out = sge.IgnoreNulls(this=out) + return out def _neg_idx_to_pos(self, arg, idx): return self.if_(idx < 0, self.f.array_length(arg) + idx, idx) diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index ec94f7f4ebc0..00236f96df99 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -61,7 +61,6 @@ class ClickHouseCompiler(SQLGlotCompiler): ops.Arbitrary: "any", ops.ArgMax: "argMax", ops.ArgMin: "argMin", - ops.ArrayCollect: "groupArray", ops.ArrayContains: "has", ops.ArrayFlatten: "arrayFlatten", ops.ArrayIntersect: "arrayIntersect", @@ -604,6 +603,13 @@ def visit_ArrayUnion(self, op, *, left, right): def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str: return self.f.arrayZip(*arg) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the pyspark backend" + ) + return self.agg.groupArray(arg, where=where, order_by=order_by) + def visit_CountDistinctStar( self, op: ops.CountDistinctStar, *, where, **_: Any ) -> str: diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index e3c4ba65478d..acc6dfc22431 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -14,7 +14,6 @@ from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DataFusionType from ibis.backends.sql.dialects import DataFusion -from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect from ibis.common.temporal import IntervalUnit, TimestampUnit from ibis.expr.operations.udf import InputType @@ -25,11 +24,6 @@ class DataFusionCompiler(SQLGlotCompiler): dialect = DataFusion type_mapper = DataFusionType - rewrites = ( - exclude_nulls_from_array_collect, - *SQLGlotCompiler.rewrites, - ) - agg = AggGen(supports_filter=True, supports_order_by=True) UNSUPPORTED_OPS = ( @@ -331,6 +325,12 @@ def visit_ArrayRepeat(self, op, *, arg, times): def visit_ArrayPosition(self, op, *, arg, other): return self.f.coalesce(self.f.array_position(arg, other), 0) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_Covariance(self, op, *, left, right, how, where): x = op.left if x.dtype.is_boolean(): diff --git a/ibis/backends/sql/compilers/druid.py b/ibis/backends/sql/compilers/druid.py index d1571363e14e..6479628d9acb 100644 --- a/ibis/backends/sql/compilers/druid.py +++ b/ibis/backends/sql/compilers/druid.py @@ -27,7 +27,6 @@ class DruidCompiler(SQLGlotCompiler): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index 1724885ec0bb..7498ab1f0277 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -14,7 +14,6 @@ from ibis import util from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DuckDBType -from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect from ibis.util import gen_name if TYPE_CHECKING: @@ -43,11 +42,6 @@ class DuckDBCompiler(SQLGlotCompiler): agg = AggGen(supports_filter=True, supports_order_by=True) - rewrites = ( - exclude_nulls_from_array_collect, - *SQLGlotCompiler.rewrites, - ) - LOWERED_OPS = { ops.Sample: None, ops.StringSlice: None, @@ -154,6 +148,12 @@ def visit_ArrayDistinct(self, op, *, arg): ), ) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_ArrayIndex(self, op, *, arg, index): return self.f.list_extract(arg, index + self.cast(index >= 0, op.index.dtype)) diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py index 87f5aaa543d3..bdf690a7ff2b 100644 --- a/ibis/backends/sql/compilers/exasol.py +++ b/ibis/backends/sql/compilers/exasol.py @@ -35,7 +35,6 @@ class ExasolCompiler(SQLGlotCompiler): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py index 4e4fb1586415..fb1fad34cfff 100644 --- a/ibis/backends/sql/compilers/flink.py +++ b/ibis/backends/sql/compilers/flink.py @@ -71,7 +71,6 @@ class FlinkCompiler(SQLGlotCompiler): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.ArrayFlatten, ops.ArraySort, ops.ArrayStringJoin, diff --git a/ibis/backends/sql/compilers/impala.py b/ibis/backends/sql/compilers/impala.py index f73a38751d08..bae861126d16 100644 --- a/ibis/backends/sql/compilers/impala.py +++ b/ibis/backends/sql/compilers/impala.py @@ -26,7 +26,6 @@ class ImpalaCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.ArrayPosition, ops.Array, ops.Covariance, diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index 0c6fe9a567ac..d7b815252e9e 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -75,7 +75,6 @@ class MSSQLCompiler(SQLGlotCompiler): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.Array, ops.ArrayDistinct, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py index 5244e2642b52..56c6f799c89a 100644 --- a/ibis/backends/sql/compilers/mysql.py +++ b/ibis/backends/sql/compilers/mysql.py @@ -67,7 +67,6 @@ def POS_INF(self): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.Array, ops.ArrayFlatten, ops.ArrayMap, diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index ee98dda5e842..98bea4d2d90b 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -51,7 +51,6 @@ class OracleCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.Array, ops.ArrayFlatten, ops.ArrayMap, diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index 9d47044aa835..f145ac779111 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -17,7 +17,6 @@ from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import PostgresType from ibis.backends.sql.dialects import Postgres -from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect from ibis.common.exceptions import InvalidDecoratorError from ibis.util import gen_name @@ -43,8 +42,6 @@ class PostgresCompiler(SQLGlotCompiler): dialect = Postgres type_mapper = PostgresType - rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites) - agg = AggGen(supports_filter=True, supports_order_by=True) NAN = sge.Literal.number("'NaN'::double precision") @@ -358,6 +355,12 @@ def visit_ArrayIntersect(self, op, *, left, right): ) ) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_Log2(self, op, *, arg): return self.cast( self.f.log( diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index ee6060e4bfcc..9a4c292d5d11 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -397,6 +397,13 @@ def visit_ArrayContains(self, op, *, arg, other): def visit_ArrayStringJoin(self, op, *, arg, sep): return self.f.concat_ws(sep, arg) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the pyspark backend" + ) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise com.UnsupportedOperationError( diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index 52cffb0a3697..f6fcad7b26d4 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -455,7 +455,12 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short]) - def _array_collect(self, *, arg, where, order_by): + def _array_collect(self, *, arg, where, order_by, ignore_null=True): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the snowflake backend" + ) + if where is not None: arg = self.if_(where, arg, NULL) @@ -466,8 +471,10 @@ def _array_collect(self, *, arg, where, order_by): return out - def visit_ArrayCollect(self, op, *, arg, where, order_by): - return self._array_collect(arg=arg, where=where, order_by=order_by) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + return self._array_collect( + arg=arg, where=where, order_by=order_by, ignore_null=ignore_null + ) def visit_First(self, op, *, arg, where, order_by): out = self._array_collect(arg=arg, where=where, order_by=order_by) diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index 88f86c5bb979..9f0d4bf7c1dc 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -42,7 +42,6 @@ class SQLiteCompiler(SQLGlotCompiler): ops.Array, ops.ArrayConcat, ops.ArrayStringJoin, - ops.ArrayCollect, ops.ArrayContains, ops.ArrayFlatten, ops.ArrayLength, diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 90ee114c893c..f35653c79712 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -22,7 +22,6 @@ from ibis.backends.sql.datatypes import TrinoType from ibis.backends.sql.dialects import Trino from ibis.backends.sql.rewrites import ( - exclude_nulls_from_array_collect, exclude_unsupported_window_frame_from_ops, ) from ibis.util import gen_name @@ -37,7 +36,6 @@ class TrinoCompiler(SQLGlotCompiler): agg = AggGen(supports_filter=True, supports_order_by=True) rewrites = ( - exclude_nulls_from_array_collect, exclude_unsupported_window_frame_from_ops, *SQLGlotCompiler.rewrites, ) @@ -178,6 +176,12 @@ def visit_ArrayContains(self, op, *, arg, other): NULL, ) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_JSONGetItem(self, op, *, arg, index): fmt = "%d" if op.index.dtype.is_integer() else '"%s"' return self.f.json_extract(arg, self.f.format(f"$[{fmt}]", index)) diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py index 511052309305..7e01999144ec 100644 --- a/ibis/backends/sql/rewrites.py +++ b/ibis/backends/sql/rewrites.py @@ -386,14 +386,6 @@ def exclude_unsupported_window_frame_from_ops(_, **kwargs): return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)) -@replace(p.ArrayCollect) -def exclude_nulls_from_array_collect(_, **kwargs): - where = ops.NotNull(_.arg) - if _.where is not None: - where = ops.And(where, _.where) - return _.copy(where=where) - - # Rewrite rules for lowering a high-level operation into one composed of more # primitive operations. diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index c2caf15dfa8b..836cef629b80 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -530,28 +530,6 @@ def mean_and_std(v): lambda t, where: len(t[where]), id="count_star", ), - param( - lambda t, where: t.string_col.nullif("3").collect(where=where), - lambda t, where: t.string_col[t.string_col != "3"][where].tolist(), - id="collect", - marks=[ - pytest.mark.notimpl( - ["impala", "mysql", "sqlite", "mssql", "druid", "oracle", "exasol"], - raises=com.OperationNotDefinedError, - ), - pytest.mark.notimpl( - ["dask"], - raises=(AttributeError, TypeError), - reason=( - "For 'is_in' case: 'Series' object has no attribute 'arraycollect'" - "For 'no_cond' case: TypeError: Object " - " is not " - "callable or a string" - ), - ), - pytest.mark.notyet(["flink"], raises=com.OperationNotDefinedError), - ], - ), ], ) @pytest.mark.parametrize( @@ -1398,6 +1376,59 @@ def test_collect_ordered(alltypes, df, filtered): assert result == expected +@pytest.mark.notimpl( + ["druid", "exasol", "flink", "impala", "mssql", "mysql", "oracle", "sqlite"], + raises=com.OperationNotDefinedError, +) +@pytest.mark.notimpl( + ["dask"], raises=AttributeError, reason="Dask doesn't implement tolist()" +) +@pytest.mark.parametrize( + "filtered", + [ + param( + True, + marks=[ + pytest.mark.notyet( + ["datafusion"], + raises=Exception, + reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + ) + ], + ), + False, + ], +) +@pytest.mark.parametrize( + "ignore_null", + [ + True, + param( + False, + marks=[ + pytest.mark.notimpl( + ["clickhouse", "pyspark", "snowflake"], + raises=com.UnsupportedOperationError, + reason="`ignore_null=False` is not supported", + ) + ], + ), + ], +) +def test_collect(alltypes, df, filtered, ignore_null): + ibis_cond = (_.id % 13 == 0) if filtered else None + pd_cond = (df.id % 13 == 0) if filtered else slice(None) + res = ( + alltypes.string_col.nullif("3") + .collect(where=ibis_cond, ignore_null=ignore_null) + .length() + .execute() + ) + vals = df.string_col[(df.string_col != "3")] if ignore_null else df.string_col + sol = len(vals[pd_cond]) + assert res == sol + + @pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError) def test_topk_op(alltypes, df): # TopK expression will order rows by "count" but each backend diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 0e093f536b45..f5b1128a8694 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -368,6 +368,7 @@ class ArrayCollect(Filterable, Reduction): arg: Column order_by: VarTuple[SortKey] = () + ignore_null: bool = True @attribute def dtype(self): diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 57d4c0cf8c48..c1cf5c07a0df 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1017,7 +1017,10 @@ def cases( return builder.else_(default).end() def collect( - self, where: ir.BooleanValue | None = None, order_by: Any = None + self, + where: ir.BooleanValue | None = None, + order_by: Any = None, + ignore_null: bool = True, ) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. @@ -1032,6 +1035,9 @@ def collect( An ordering key (or keys) to use to order the rows before aggregating. If not provided, the order of the items in the result is undefined and backend specific. + ignore_null + Whether to ignore null values when performing this aggregation. Set + to `False` to include nulls in the result. Returns ------- @@ -1092,6 +1098,7 @@ def collect( self, where=self._bind_to_parent_table(where), order_by=self._bind_order_by(order_by), + ignore_null=ignore_null, ).to_expr() def identical_to(self, other: Value) -> ir.BooleanValue: