diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index b40c85b0fa4af..a9f7af348be16 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -479,16 +479,24 @@ def visit_StringToTimestamp(self, op, *, arg, format_str): return self.f.parse_timestamp(format_str, arg, timezone) return self.f.parse_datetime(format_str, arg) - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): - if where is not None and include_null: - raise com.UnsupportedOperationError( - "Combining `include_null=True` and `where` is not supported " - "by bigquery" - ) - out = self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): + if where is not None: + if include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported by bigquery" + ) + if distinct: + raise com.UnsupportedOperationError( + "Combining `distinct=True` and `where` is not supported by bigquery" + ) + arg = compiler.if_(where, arg, NULL) + if distinct: + arg = sge.Distinct(expressions=[arg]) + if order_by: + arg = sge.Order(this=arg, expressions=order_by) if not include_null: - out = sge.IgnoreNulls(this=out) - return out + arg = sge.IgnoreNulls(this=arg) + return self.f.array_agg(arg) def _neg_idx_to_pos(self, arg, idx): return self.if_(idx < 0, self.f.array_length(arg) + idx, idx) diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index ea15150b279b2..f7fa8af6fd87d 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -611,12 +611,13 @@ def visit_ArrayUnion(self, op, *, left, right): def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str: return self.f.arrayZip(*arg) - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): if include_null: raise com.UnsupportedOperationError( "`include_null=True` is not supported by the clickhouse backend" ) - return self.agg.groupArray(arg, where=where, order_by=order_by) + func = self.agg.groupUniqArray if distinct else self.agg.groupArray + return func(arg, where=where, order_by=order_by) def visit_First(self, op, *, arg, where, order_by, include_null): if include_null: diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index 155bf45d65847..b60935f6cdd90 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -327,7 +327,11 @@ def visit_ArrayRepeat(self, op, *, arg, times): def visit_ArrayPosition(self, op, *, arg, other): return self.f.coalesce(self.f.array_position(arg, other), 0) - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): + if distinct: + raise com.UnsupportedOperationError( + "`collect` with `distinct=True` is not supported" + ) if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index c1011ea9f932d..76082a5c29eb2 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -156,10 +156,12 @@ def visit_ArrayPosition(self, op, *, arg, other): self.f.coalesce(self.f.list_indexof(arg, other), 0), ) - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) + if distinct: + arg = sge.Distinct(expressions=[arg]) return self.agg.array_agg(arg, where=where, order_by=order_by) def visit_ArrayIndex(self, op, *, arg, index): diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py index 96c31e61f5473..ff50258b327cd 100644 --- a/ibis/backends/sql/compilers/flink.py +++ b/ibis/backends/sql/compilers/flink.py @@ -572,20 +572,24 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right): def visit_StructColumn(self, op, *, names, values): return self.cast(sge.Struct(expressions=list(values)), op.dtype) - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): if order_by: raise com.UnsupportedOperationError( "ordering of order-sensitive aggregations via `order_by` is " "not supported for this backend" ) - # the only way to get filtering *and* respecting nulls is to use - # `FILTER` syntax, but it's broken in various ways for other aggregates - out = self.f.array_agg(arg) if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) + out = self.f.array_agg(arg) if where is not None: out = sge.Filter(this=out, expression=sge.Where(this=where)) + if distinct: + # TODO: Flink supposedly supports `ARRAY_AGG(DISTINCT ...)`, but it + # doesn't work with filtering (either `include_null=False` or + # additional filtering). Their `array_distinct` function does maintain + # ordering though, so we can use it here. + out = self.f.array_distinct(out) return out def visit_Strip(self, op, *, arg): diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index b224d1d180e12..7197d6fd03df9 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -372,10 +372,12 @@ def visit_ArrayIntersect(self, op, *, left, right): ) ) - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) + if distinct: + arg = sge.Distinct(expressions=[arg]) return self.agg.array_agg(arg, where=where, order_by=order_by) def visit_First(self, op, *, arg, where, order_by, include_null): diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index 36599f8d21e5c..374d72ce143d9 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -432,12 +432,16 @@ def visit_ArrayContains(self, op, *, arg, other): def visit_ArrayStringJoin(self, op, *, arg, sep): return self.f.concat_ws(sep, arg) - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): if include_null: raise com.UnsupportedOperationError( "`include_null=True` is not supported by the pyspark backend" ) - return self.agg.array_agg(arg, where=where, order_by=order_by) + if where: + arg = self.if_(where, arg, NULL) + if distinct: + arg = sge.Distinct(expressions=[arg]) + return self.agg.array_agg(arg, order_by=order_by) def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index ad94da56433ff..927853c1d8bef 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -452,15 +452,22 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short]) - def _array_collect(self, *, arg, where, order_by, include_null): + def _array_collect(self, *, arg, where, order_by, include_null, distinct=False): if include_null: raise com.UnsupportedOperationError( "`include_null=True` is not supported by the snowflake backend" ) + if where is not None and distinct: + raise com.UnsupportedOperationError( + "Combining `distinct=True` and `where` is not supported by snowflake" + ) if where is not None: arg = self.if_(where, arg, NULL) + if distinct: + arg = sge.Distinct(expressions=[arg]) + out = self.f.array_agg(arg) if order_by: @@ -468,9 +475,13 @@ def _array_collect(self, *, arg, where, order_by, include_null): return out - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): return self._array_collect( - arg=arg, where=where, order_by=order_by, include_null=include_null + arg=arg, + where=where, + order_by=order_by, + include_null=include_null, + distinct=distinct, ) def visit_First(self, op, *, arg, where, order_by, include_null): diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 56452a1b4dbfb..21b652ba6b91a 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -182,10 +182,12 @@ def visit_ArrayContains(self, op, *, arg, other): NULL, ) - def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) + if distinct: + arg = sge.Distinct(expressions=[arg]) return self.agg.array_agg(arg, where=where, order_by=order_by) def visit_JSONGetItem(self, op, *, arg, index): diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 6e7a292241054..b97e3fb646473 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools from datetime import date from operator import methodcaller @@ -1301,67 +1302,75 @@ def test_group_concat_ordered(alltypes, df, filtered): assert result == expected -@pytest.mark.notimpl( - ["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"], - raises=com.OperationNotDefinedError, -) -@pytest.mark.notimpl( - ["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError -) -@pytest.mark.parametrize("filtered", [True, False]) -def test_collect_ordered(alltypes, df, filtered): - ibis_cond = (_.id % 13 == 0) if filtered else None - pd_cond = (df.id % 13 == 0) if filtered else True - result = ( - alltypes.filter(_.bigint_col == 10) - .id.cast("str") - .collect(where=ibis_cond, order_by=_.id.desc()) - .execute() - ) - expected = list( - df.id[(df.bigint_col == 10) & pd_cond].sort_values(ascending=False).astype(str) - ) - assert result == expected +def gen_test_collect_marks(distinct, filtered, ordered, include_null): + """The marks for this test fail for different combinations of parameters. + Rather than set `strict=False` (which can let bugs sneak through), we split + the mark generation into a function""" + if distinct: + yield pytest.mark.notimpl(["datafusion"], raises=com.UnsupportedOperationError) + if ordered: + yield pytest.mark.notimpl( + ["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError + ) + if include_null: + yield pytest.mark.notimpl( + ["clickhouse", "pyspark", "snowflake"], raises=com.UnsupportedOperationError + ) + + # Handle special cases + if filtered and distinct: + yield pytest.mark.notimpl( + ["bigquery", "snowflake"], + raises=com.UnsupportedOperationError, + reason="Can't combine where and distinct", + ) + elif filtered and include_null: + yield pytest.mark.notimpl( + ["bigquery"], + raises=com.UnsupportedOperationError, + reason="Can't combine where and include_null", + ) + elif include_null: + yield pytest.mark.notimpl( + ["bigquery"], + raises=GoogleBadRequest, + reason="BigQuery can't retrieve arrays with null values", + ) @pytest.mark.notimpl( ["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"], raises=com.OperationNotDefinedError, ) -@pytest.mark.parametrize("filtered", [True, False]) @pytest.mark.parametrize( - "include_null", + "distinct, filtered, ordered, include_null", [ - False, - param( - True, - marks=[ - pytest.mark.notimpl( - ["clickhouse", "pyspark", "snowflake"], - raises=com.UnsupportedOperationError, - reason="`include_null=True` is not supported", - ), - pytest.mark.notimpl( - ["bigquery"], - raises=com.UnsupportedOperationError, - reason="Can't mix `where` and `include_null=True`", - strict=False, - ), - ], - ), + param(*ps, marks=list(gen_test_collect_marks(*ps))) + for ps in itertools.product(*([[True, False]] * 4)) ], ) -def test_collect(alltypes, df, filtered, include_null): - ibis_cond = (_.id % 13 == 0) if filtered else None - pd_cond = (df.id % 13 == 0) if filtered else slice(None) - expr = ( - alltypes.string_col.nullif("3") - .collect(where=ibis_cond, include_null=include_null) - .length() +def test_collect(alltypes, df, distinct, filtered, ordered, include_null): + expr = alltypes.mutate(x=_.string_col.nullif("3")).x.collect( + where=((_.id % 13 == 0) if filtered else None), + include_null=include_null, + distinct=distinct, + order_by=(_.x.desc() if ordered else ()), ) res = expr.execute() - vals = df.string_col if include_null else df.string_col[df.string_col != "3"] - sol = len(vals[pd_cond]) + + x = df.string_col.where(df.string_col != "3", None) + if filtered: + x = x[df.id % 13 == 0] + if not include_null: + x = x.dropna() + if distinct: + x = x.drop_duplicates() + sol = sorted(x, key=lambda x: (x is not None, x), reverse=True) + + if not ordered: + # If unordered, order afterwards so we can compare + res = sorted(res, key=lambda x: (x is not None, x), reverse=True) + assert res == sol diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 6d780e847e213..a01413ac546fe 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -9,7 +9,7 @@ import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz -from ibis.common.annotations import attribute +from ibis.common.annotations import ValidationError, attribute from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Column, Value from ibis.expr.operations.relations import Relation # noqa: TCH001 @@ -376,6 +376,15 @@ class ArrayCollect(Filterable, Reduction): arg: Column order_by: VarTuple[SortKey] = () include_null: bool = False + distinct: bool = False + + def __init__(self, arg, order_by, distinct, **kwargs): + if distinct and order_by and [arg] != [key.expr for key in order_by]: + raise ValidationError( + "`collect` with `order_by` and `distinct=True` and may only " + "order by the collected column" + ) + super().__init__(arg=arg, order_by=order_by, distinct=distinct, **kwargs) @attribute def dtype(self): diff --git a/ibis/expr/tests/test_reductions.py b/ibis/expr/tests/test_reductions.py index 1615a4c8f74e3..adfcfb95008aa 100644 --- a/ibis/expr/tests/test_reductions.py +++ b/ibis/expr/tests/test_reductions.py @@ -6,6 +6,7 @@ import ibis import ibis.expr.operations as ops from ibis import _ +from ibis.common.annotations import ValidationError from ibis.common.deferred import Deferred from ibis.common.exceptions import IbisTypeError @@ -161,3 +162,22 @@ def test_ordered_aggregations_no_order(method): q3 = func(order_by=()) assert q1.equals(q2) assert q1.equals(q3) + + +def test_collect_distinct(): + t = ibis.table({"a": "string", "b": "int", "c": "int"}, name="t") + # Fine + t.a.collect(distinct=True) + t.a.collect(distinct=True, order_by=t.a.desc()) + (t.a + 1).collect(distinct=True, order_by=(t.a + 1).desc()) + + with pytest.raises(ValidationError, match="only order by the collected column"): + t.b.collect(distinct=True, order_by=t.a) + with pytest.raises(ValidationError, match="only order by the collected column"): + t.b.collect( + distinct=True, + order_by=( + t.a, + t.b, + ), + ) diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 76d357484c7ef..9efb1252087ba 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1022,6 +1022,7 @@ def collect( where: ir.BooleanValue | None = None, order_by: Any = None, include_null: bool = False, + distinct: bool = False, ) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. @@ -1039,19 +1040,22 @@ def collect( include_null Whether to include null values when performing this aggregation. Set to `True` to include nulls in the result. + distinct + Whether to collect only distinct elements. Returns ------- ArrayScalar - Collected array + An array of all the collected elements. Examples -------- Basic collect usage >>> import ibis + >>> from ibis import _ >>> ibis.options.interactive = True - >>> t = ibis.memtable({"key": list("aaabb"), "value": [1, 2, 3, 4, 5]}) + >>> t = ibis.memtable({"key": list("aaabb"), "value": [1, 1, 2, 3, 5]}) >>> t ┏━━━━━━━━┳━━━━━━━┓ ┃ key ┃ value ┃ @@ -1059,40 +1063,37 @@ def collect( │ string │ int64 │ ├────────┼───────┤ │ a │ 1 │ + │ a │ 1 │ │ a │ 2 │ - │ a │ 3 │ - │ b │ 4 │ + │ b │ 3 │ │ b │ 5 │ └────────┴───────┘ - >>> t.value.collect() - ┌────────────────┐ - │ [1, 2, ... +3] │ - └────────────────┘ - >>> type(t.value.collect()) - - Collect elements per group + Collect all elements into an array scalar: - >>> t.group_by("key").agg(v=lambda t: t.value.collect()).order_by("key") - ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ key ┃ v ┃ - ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ array │ - ├────────┼──────────────────────┤ - │ a │ [1, 2, ... +1] │ - │ b │ [4, 5] │ - └────────┴──────────────────────┘ + >>> t.value.collect().to_pandas() + [1, 1, 2, 3, 5] + + Collect only unique elements: + + >>> t.value.collect(distinct=True).to_pandas() # doctest: +SKIP + [1, 2, 3, 5] + + Collect elements in a specified order: + + >>> t.value.collect(order_by=_.value.desc()).to_pandas() + [5, 3, 2, 1, 1] - Collect elements per group using a filter + Collect elements per group, filtering out values <= 1: - >>> t.group_by("key").agg(v=lambda t: t.value.collect(where=t.value > 1)).order_by("key") + >>> t.group_by("key").agg(v=t.value.collect(where=_.value > 1)).order_by("key") ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ ┃ key ┃ v ┃ ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ │ string │ array │ ├────────┼──────────────────────┤ - │ a │ [2, 3] │ - │ b │ [4, 5] │ + │ a │ [2] │ + │ b │ [3, 5] │ └────────┴──────────────────────┘ """ return ops.ArrayCollect( @@ -1100,6 +1101,7 @@ def collect( where=self._bind_to_parent_table(where), order_by=self._bind_order_by(order_by), include_null=include_null, + distinct=distinct, ).to_expr() def identical_to(self, other: Value) -> ir.BooleanValue: