Skip to content

Commit

Permalink
feat(backend): make ArrayCollect filterable
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and jcrist committed Nov 10, 2022
1 parent 90b9bc8 commit 1e1a5cf
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 18 deletions.
13 changes: 12 additions & 1 deletion ibis/backends/clickhouse/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ def formatter(translator, op):
return formatter


def _array_collect(translator, op):
    """Render an ``ArrayCollect`` reduction as a ClickHouse aggregate call.

    Produces ``groupArray(<arg>)``. When ``op.where`` is not ``None`` the
    translated predicate is appended as a second argument and the function
    name gains the ``If`` suffix (``groupArrayIf``).
    """
    translated = [translator.translate(op.arg)]
    predicate = op.where

    if predicate is None:
        name = "groupArray"
    else:
        translated.append(translator.translate(predicate))
        name = "groupArrayIf"

    return "{}({})".format(name, ", ".join(translated))


def _count_star(translator, op):
    """Translate a zero-argument count.

    A count with no argument is ``count(*)``; when ``op.where`` is not
    ``None`` it is forwarded to ``_aggregate`` so the filtered form
    (``countIf``) is produced.
    """
    predicate = op.where
    return _aggregate(translator, "count", where=predicate)
Expand Down Expand Up @@ -717,7 +728,7 @@ def _sort_key(translator, op):
ops.Min: _agg('min'),
ops.ArgMin: _agg('argMin'),
ops.ArgMax: _agg('argMax'),
ops.ArrayCollect: _agg('groupArray'),
ops.ArrayCollect: _array_collect,
ops.StandardDev: _agg_variance_like('stddev'),
ops.Variance: _agg_variance_like('var'),
ops.Covariance: _agg_variance_like('covar'),
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/dask/execution/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def execute_array_column(op, cols, **kwargs):


# TODO - aggregations - #2553
@execute_node.register(ops.ArrayCollect, dd.Series)
def execute_array_collect(op, data, aggcontext=None, **kwargs):
@execute_node.register(ops.ArrayCollect, dd.Series, type(None))
def execute_array_collect(op, data, where, aggcontext=None, **kwargs):
    """Run ``collect_list`` over a dask series through the aggregation context.

    Registered only for a ``where`` of ``NoneType``; the argument takes part
    in dispatch and is not otherwise used in this body.
    """
    collected = aggcontext.agg(data, collect_list)
    return collected


@execute_node.register(ops.ArrayCollect, ddgb.SeriesGroupBy)
def execute_array_collect_grouped_series(op, data, aggcontext=None, **kwargs):
@execute_node.register(ops.ArrayCollect, ddgb.SeriesGroupBy, type(None))
def execute_array_collect_grouped_series(op, data, where, **kwargs):
    """Aggregate each group of a dask grouped series with ``collect_list``.

    Registered only for a ``where`` of ``NoneType``; the argument takes part
    in dispatch and is not otherwise used in this body.
    """
    per_group = data.agg(collect_list)
    return per_group
18 changes: 15 additions & 3 deletions ibis/backends/pandas/execution/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@ def execute_array_repeat_scalar(op, data, n, **kwargs):
return np.tile(data, max(n, 0))


@execute_node.register(ops.ArrayCollect, (pd.Series, SeriesGroupBy))
def execute_array_collect(op, data, aggcontext=None, **kwargs):
return aggcontext.agg(data, np.array)
@execute_node.register(ops.ArrayCollect, pd.Series, (type(None), pd.Series))
def execute_array_collect(op, data, where, aggcontext=None, **kwargs):
    """Collect a pandas series into an array, optionally filtered by ``where``.

    When ``where`` is not ``None`` the series is first restricted with
    ``data.loc[where]``; the result is aggregated with ``np.array`` through
    ``aggcontext.agg``.
    """
    if where is None:
        selected = data
    else:
        selected = data.loc[where]
    return aggcontext.agg(selected, np.array)


@execute_node.register(ops.ArrayCollect, SeriesGroupBy, (type(None), pd.Series))
def execute_array_collect_groupby(op, data, where, aggcontext=None, **kwargs):
    """Collect each group of a grouped series into an array.

    Without a filter the incoming group-by object is aggregated as is. With
    a ``where`` series, the underlying object (``data.obj``) is restricted
    with ``.loc[where]`` and grouped again before aggregating with
    ``np.array``.
    """
    if where is None:
        return aggcontext.agg(data, np.array)

    # NOTE(review): this relies on ``data.grouping.grouper`` being available
    # on the incoming group-by object -- confirm against the pandas version
    # in use.
    regrouped = data.obj.loc[where].groupby(data.grouping.grouper)
    return aggcontext.agg(regrouped, np.array)
2 changes: 2 additions & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ def array_column(op):
@translate.register(ops.ArrayCollect)
def array_collect(op):
    """Translate ``ArrayCollect`` into a polars list aggregation.

    The translated argument is filtered by the translated ``op.where``
    predicate when one is present, then gathered with ``.list()``.
    """
    collected = translate(op.arg)
    predicate = op.where
    if predicate is not None:
        collected = collected.filter(translate(predicate))
    return collected.list()


Expand Down
18 changes: 11 additions & 7 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ def _covar(t, op):

def _mode(t, op):
arg = op.arg
if op.where is not None:
arg = op.where.to_expr().ifelse(arg, None).op()
if (where := op.where) is not None:
arg = ops.Where(where, arg, None)
return sa.func.mode().within_group(t.translate(arg))


Expand All @@ -506,16 +506,20 @@ def variance_compiler(t, op):
result = func(t.translate(x), t.translate(y))

if (where := op.where) is not None:
if t._has_reduction_filter_syntax:
result = result.filter(t.translate(where))
else:
result = sa.case((t.translate(where), result), else_=sa.null())
return result.filter(t.translate(where))

return result

return variance_compiler


def _array_collect(t, op):
    """Compile ``ArrayCollect`` to ``array_agg``, filtered when requested.

    Builds ``sa.func.array_agg`` over the translated argument and, when
    ``op.where`` is not ``None``, applies ``.filter(...)`` with the
    translated predicate.
    """
    aggregated = sa.func.array_agg(t.translate(op.arg))
    predicate = op.where
    if predicate is None:
        return aggregated
    return aggregated.filter(t.translate(predicate))


operation_registry.update(
{
ops.Literal: _literal,
Expand Down Expand Up @@ -585,7 +589,7 @@ def variance_compiler(t, op):
ops.CumulativeAny: unary(sa.func.bool_or),
# array operations
ops.ArrayLength: unary(_cardinality),
ops.ArrayCollect: unary(sa.func.array_agg),
ops.ArrayCollect: _array_collect,
ops.ArrayColumn: _array_column,
ops.ArraySlice: _array_slice,
ops.ArrayIndex: fixed_arity(
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1587,6 +1587,8 @@ def compile_array_repeat(t, op, **kwargs):
@compiles(ops.ArrayCollect)
def compile_array_collect(t, op, **kwargs):
    """Compile ``ArrayCollect`` to ``F.collect_list``.

    When ``op.where`` is not ``None`` the source column is wrapped in
    ``F.when(<predicate>, <column>)`` with no ``otherwise`` branch before
    being collected.
    """
    column = t.translate(op.arg, **kwargs)
    predicate = op.where
    if predicate is not None:
        condition = t.translate(predicate, **kwargs)
        # Presumably relies on collect_list skipping the nulls that the
        # branch-less ``when`` yields for non-matching rows -- confirm.
        column = F.when(condition, column)
    return F.collect_list(column)


Expand Down
10 changes: 10 additions & 0 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,16 @@ def mean_and_std(v):
lambda t, where: len(t[where]),
id='count_star',
),
param(
lambda t, where: t.string_col.collect(where=where),
lambda t, where: list(t.string_col[where]),
id="collect",
marks=[
pytest.mark.notimpl(
["impala", "datafusion", "snowflake", "dask", "polars"]
)
],
),
],
)
@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class CountDistinct(Filterable, Reduction):


@public
class ArrayCollect(Reduction):
class ArrayCollect(Filterable, Reduction):
arg = rlz.column(rlz.any)

@attribute.default
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ def cases(
builder = builder.when(case, result)
return builder.else_(default).end()

def collect(self) -> ir.ArrayValue:
def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayValue:
"""Return an array of the elements of this expression."""
return ops.ArrayCollect(self).to_expr()
return ops.ArrayCollect(self, where=where).to_expr()

def identical_to(self, other: Value) -> ir.BooleanValue:
"""Return whether this expression is identical to other.
Expand Down

0 comments on commit 1e1a5cf

Please sign in to comment.