Skip to content

Commit

Permalink
feat(api): add distinct option to collect
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Sep 13, 2024
1 parent d4e8c11 commit 005f1f9
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 98 deletions.
26 changes: 17 additions & 9 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,16 +479,24 @@ def visit_StringToTimestamp(self, op, *, arg, format_str):
return self.f.parse_timestamp(format_str, arg, timezone)
return self.f.parse_datetime(format_str, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if where is not None and include_null:
raise com.UnsupportedOperationError(
"Combining `include_null=True` and `where` is not supported "
"by bigquery"
)
out = self.agg.array_agg(arg, where=where, order_by=order_by)
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if where is not None:
if include_null:
raise com.UnsupportedOperationError(
"Combining `include_null=True` and `where` is not supported by bigquery"
)
if distinct:
raise com.UnsupportedOperationError(
"Combining `distinct=True` and `where` is not supported by bigquery"
)
arg = compiler.if_(where, arg, NULL)
if distinct:
arg = sge.Distinct(expressions=[arg])
if order_by:
arg = sge.Order(this=arg, expressions=order_by)
if not include_null:
out = sge.IgnoreNulls(this=out)
return out
arg = sge.IgnoreNulls(this=arg)
return self.f.array_agg(arg)

def _neg_idx_to_pos(self, arg, idx):
return self.if_(idx < 0, self.f.array_length(arg) + idx, idx)
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,12 +611,13 @@ def visit_ArrayUnion(self, op, *, left, right):
def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
return self.f.arrayZip(*arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the clickhouse backend"
)
return self.agg.groupArray(arg, where=where, order_by=order_by)
func = self.agg.groupUniqArray if distinct else self.agg.groupArray
return func(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, include_null):
if include_null:
Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,11 @@ def visit_ArrayRepeat(self, op, *, arg, times):
def visit_ArrayPosition(self, op, *, arg, other):
return self.f.coalesce(self.f.array_position(arg, other), 0)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if distinct:
raise com.UnsupportedOperationError(
"`collect` with `distinct=True` is not supported"
)
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ def visit_ArrayPosition(self, op, *, arg, other):
self.f.coalesce(self.f.list_indexof(arg, other), 0),
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
if distinct:
arg = sge.Distinct(expressions=[arg])
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_ArrayIndex(self, op, *, arg, index):
Expand Down
12 changes: 8 additions & 4 deletions ibis/backends/sql/compilers/flink.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,24 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
def visit_StructColumn(self, op, *, names, values):
return self.cast(sge.Struct(expressions=list(values)), op.dtype)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
# the only way to get filtering *and* respecting nulls is to use
# `FILTER` syntax, but it's broken in various ways for other aggregates
out = self.f.array_agg(arg)
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
out = self.f.array_agg(arg)
if where is not None:
out = sge.Filter(this=out, expression=sge.Where(this=where))
if distinct:
# TODO: Flink supposedly supports `ARRAY_AGG(DISTINCT ...)`, but it
# doesn't work with filtering (either `include_null=False` or
# additional filtering). Their `array_distinct` function does maintain
# ordering though, so we can use it here.
out = self.f.array_distinct(out)
return out

def visit_Strip(self, op, *, arg):
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,12 @@ def visit_ArrayIntersect(self, op, *, left, right):
)
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
if distinct:
arg = sge.Distinct(expressions=[arg])
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, include_null):
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,12 +432,16 @@ def visit_ArrayContains(self, op, *, arg, other):
def visit_ArrayStringJoin(self, op, *, arg, sep):
return self.f.concat_ws(sep, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the pyspark backend"
)
return self.agg.array_agg(arg, where=where, order_by=order_by)
if where:
arg = self.if_(where, arg, NULL)
if distinct:
arg = sge.Distinct(expressions=[arg])
return self.agg.array_agg(arg, order_by=order_by)

def visit_StringFind(self, op, *, arg, substr, start, end):
if end is not None:
Expand Down
17 changes: 14 additions & 3 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,25 +452,36 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9}
return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short])

def _array_collect(self, *, arg, where, order_by, include_null):
def _array_collect(self, *, arg, where, order_by, include_null, distinct=False):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the snowflake backend"
)
if where is not None and distinct:
raise com.UnsupportedOperationError(
"Combining `distinct=True` and `where` is not supported by snowflake"
)

if where is not None:
arg = self.if_(where, arg, NULL)

if distinct:
arg = sge.Distinct(expressions=[arg])

out = self.f.array_agg(arg)

if order_by:
out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))

return out

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
return self._array_collect(
arg=arg, where=where, order_by=order_by, include_null=include_null
arg=arg,
where=where,
order_by=order_by,
include_null=include_null,
distinct=distinct,
)

def visit_First(self, op, *, arg, where, order_by, include_null):
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,12 @@ def visit_ArrayContains(self, op, *, arg, other):
NULL,
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
if distinct:
arg = sge.Distinct(expressions=[arg])
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_JSONGetItem(self, op, *, arg, index):
Expand Down
107 changes: 58 additions & 49 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from datetime import date
from operator import methodcaller

Expand Down Expand Up @@ -1301,67 +1302,75 @@ def test_group_concat_ordered(alltypes, df, filtered):
assert result == expected


@pytest.mark.notimpl(
["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError
)
@pytest.mark.parametrize("filtered", [True, False])
def test_collect_ordered(alltypes, df, filtered):
ibis_cond = (_.id % 13 == 0) if filtered else None
pd_cond = (df.id % 13 == 0) if filtered else True
result = (
alltypes.filter(_.bigint_col == 10)
.id.cast("str")
.collect(where=ibis_cond, order_by=_.id.desc())
.execute()
)
expected = list(
df.id[(df.bigint_col == 10) & pd_cond].sort_values(ascending=False).astype(str)
)
assert result == expected
def gen_test_collect_marks(distinct, filtered, ordered, include_null):
"""The marks for this test fail for different combinations of parameters.
Rather than set `strict=False` (which can let bugs sneak through), we split
the mark generation into a function"""
if distinct:
yield pytest.mark.notimpl(["datafusion"], raises=com.UnsupportedOperationError)
if ordered:
yield pytest.mark.notimpl(
["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError
)
if include_null:
yield pytest.mark.notimpl(
["clickhouse", "pyspark", "snowflake"], raises=com.UnsupportedOperationError
)

# Handle special cases
if filtered and distinct:
yield pytest.mark.notimpl(
["bigquery", "snowflake"],
raises=com.UnsupportedOperationError,
reason="Can't combine where and distinct",
)
elif filtered and include_null:
yield pytest.mark.notimpl(
["bigquery"],
raises=com.UnsupportedOperationError,
reason="Can't combine where and include_null",
)
elif include_null:
yield pytest.mark.notimpl(
["bigquery"],
raises=GoogleBadRequest,
reason="BigQuery can't retrieve arrays with null values",
)


@pytest.mark.notimpl(
["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.parametrize("filtered", [True, False])
@pytest.mark.parametrize(
"include_null",
"distinct, filtered, ordered, include_null",
[
False,
param(
True,
marks=[
pytest.mark.notimpl(
["clickhouse", "pyspark", "snowflake"],
raises=com.UnsupportedOperationError,
reason="`include_null=True` is not supported",
),
pytest.mark.notimpl(
["bigquery"],
raises=com.UnsupportedOperationError,
reason="Can't mix `where` and `include_null=True`",
strict=False,
),
],
),
param(*ps, marks=list(gen_test_collect_marks(*ps)))
for ps in itertools.product(*([[True, False]] * 4))
],
)
def test_collect(alltypes, df, filtered, include_null):
ibis_cond = (_.id % 13 == 0) if filtered else None
pd_cond = (df.id % 13 == 0) if filtered else slice(None)
expr = (
alltypes.string_col.nullif("3")
.collect(where=ibis_cond, include_null=include_null)
.length()
def test_collect(alltypes, df, distinct, filtered, ordered, include_null):
expr = alltypes.mutate(x=_.string_col.nullif("3")).x.collect(
where=((_.id % 13 == 0) if filtered else None),
include_null=include_null,
distinct=distinct,
order_by=(_.x.desc() if ordered else ()),
)
res = expr.execute()
vals = df.string_col if include_null else df.string_col[df.string_col != "3"]
sol = len(vals[pd_cond])

x = df.string_col.where(df.string_col != "3", None)
if filtered:
x = x[df.id % 13 == 0]
if not include_null:
x = x.dropna()
if distinct:
x = x.drop_duplicates()
sol = sorted(x, key=lambda x: (x is not None, x), reverse=True)

if not ordered:
# If unordered, order afterwards so we can compare
res = sorted(res, key=lambda x: (x is not None, x), reverse=True)

assert res == sol


Expand Down
11 changes: 10 additions & 1 deletion ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.annotations import ValidationError, attribute
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Column, Value
from ibis.expr.operations.relations import Relation # noqa: TCH001
Expand Down Expand Up @@ -376,6 +376,15 @@ class ArrayCollect(Filterable, Reduction):
arg: Column
order_by: VarTuple[SortKey] = ()
include_null: bool = False
distinct: bool = False

def __init__(self, arg, order_by, distinct, **kwargs):
if distinct and order_by and [arg] != [key.expr for key in order_by]:
raise ValidationError(
"`collect` with `order_by` and `distinct=True` and may only "
"order by the collected column"
)
super().__init__(arg=arg, order_by=order_by, distinct=distinct, **kwargs)

@attribute
def dtype(self):
Expand Down
20 changes: 20 additions & 0 deletions ibis/expr/tests/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ibis
import ibis.expr.operations as ops
from ibis import _
from ibis.common.annotations import ValidationError
from ibis.common.deferred import Deferred
from ibis.common.exceptions import IbisTypeError

Expand Down Expand Up @@ -161,3 +162,22 @@ def test_ordered_aggregations_no_order(method):
q3 = func(order_by=())
assert q1.equals(q2)
assert q1.equals(q3)


def test_collect_distinct():
t = ibis.table({"a": "string", "b": "int", "c": "int"}, name="t")
# Fine
t.a.collect(distinct=True)
t.a.collect(distinct=True, order_by=t.a.desc())
(t.a + 1).collect(distinct=True, order_by=(t.a + 1).desc())

with pytest.raises(ValidationError, match="only order by the collected column"):
t.b.collect(distinct=True, order_by=t.a)
with pytest.raises(ValidationError, match="only order by the collected column"):
t.b.collect(
distinct=True,
order_by=(
t.a,
t.b,
),
)
Loading

0 comments on commit 005f1f9

Please sign in to comment.