Skip to content

Commit

Permalink
feat(api): support order_by in order-sensitive aggregates (`collect…
Browse files Browse the repository at this point in the history
…`/`group_concat`/`first`/`last`) (#9729)

Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
  • Loading branch information
jcrist and cpcloud authored Aug 2, 2024
1 parent 7d38f09 commit a18cb5d
Show file tree
Hide file tree
Showing 30 changed files with 537 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ SELECT
WHEN empty(groupArray("t0"."string_col"))
THEN NULL
ELSE arrayStringConcat(groupArray("t0"."string_col"), ',')
END AS "GroupConcat(string_col, ',')"
END AS "GroupConcat(string_col, ',', ())"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ SELECT
WHEN empty(groupArrayIf("t0"."string_col", "t0"."bool_col" = 0))
THEN NULL
ELSE arrayStringConcat(groupArrayIf("t0"."string_col", "t0"."bool_col" = 0), ',')
END AS "GroupConcat(string_col, ',', Equals(bool_col, 0))"
END AS "GroupConcat(string_col, ',', (), Equals(bool_col, 0))"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ SELECT
WHEN empty(groupArray("t0"."string_col"))
THEN NULL
ELSE arrayStringConcat(groupArray("t0"."string_col"), '-')
END AS "GroupConcat(string_col, '-')"
END AS "GroupConcat(string_col, '-', ())"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col)`
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col, ())`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col)`
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col, ())`
FROM `functional_alltypes` AS `t0`
21 changes: 18 additions & 3 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
plan,
)
from ibis.common.dispatch import Dispatched
from ibis.common.exceptions import OperationNotDefinedError, UnboundExpressionError
from ibis.common.exceptions import (
OperationNotDefinedError,
UnboundExpressionError,
UnsupportedOperationError,
)
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import any_of, gen_name

Expand Down Expand Up @@ -253,7 +257,12 @@ def visit(
############################# Reductions ##################################

@classmethod
def visit(cls, op: ops.Reduction, arg, where):
def visit(cls, op: ops.Reduction, arg, where, order_by=()):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
func = cls.kernels.reductions[type(op)]
return cls.agg(func, arg, where)

Expand Down Expand Up @@ -344,7 +353,13 @@ def agg(df):
return agg

@classmethod
def visit(cls, op: ops.GroupConcat, arg, sep, where):
def visit(cls, op: ops.GroupConcat, arg, sep, where, order_by):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)

if where is None:

def agg(df):
Expand Down
50 changes: 41 additions & 9 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _literal_value(op, nan_as_none=False):


@singledispatch
def translate(expr, *, ctx):
def translate(expr, **_):
raise NotImplementedError(expr)


Expand Down Expand Up @@ -748,6 +748,11 @@ def execute_first_last(op, **kw):

arg = arg.filter(predicate)

if order_by := getattr(op, "order_by", ()):
keys = [translate(k.expr, **kw).filter(predicate) for k in order_by]
descending = [k.descending for k in order_by]
arg = arg.sort_by(keys, descending=descending)

return arg.last() if isinstance(op, ops.Last) else arg.first()


Expand Down Expand Up @@ -985,14 +990,21 @@ def array_column(op, **kw):
@translate.register(ops.ArrayCollect)
def array_collect(op, in_group_by=False, **kw):
arg = translate(op.arg, **kw)
if (where := op.where) is not None:
arg = arg.filter(translate(where, **kw))
out = arg.drop_nulls()
if not in_group_by:
# Polars' behavior changes for `implode` within a `group_by` currently.
# See https://github.com/pola-rs/polars/issues/16756
out = out.implode()
return out

predicate = arg.is_not_null()
if op.where is not None:
predicate &= translate(op.where, **kw)

arg = arg.filter(predicate)

if op.order_by:
keys = [translate(k.expr, **kw).filter(predicate) for k in op.order_by]
descending = [k.descending for k in op.order_by]
arg = arg.sort_by(keys, descending=descending)

# Polars' behavior changes for `implode` within a `group_by` currently.
# See https://github.com/pola-rs/polars/issues/16756
return arg if in_group_by else arg.implode()


@translate.register(ops.ArrayFlatten)
Expand Down Expand Up @@ -1390,3 +1402,23 @@ def execute_array_all(op, **kw):
arg = translate(op.arg, **kw)
no_nulls = arg.list.drop_nulls()
return pl.when(no_nulls.list.len() == 0).then(None).otherwise(no_nulls.list.all())


@translate.register(ops.GroupConcat)
def execute_group_concat(op, **kw):
arg = translate(op.arg, **kw)
sep = _literal_value(op.sep)

predicate = arg.is_not_null()

if (where := op.where) is not None:
predicate &= translate(where, **kw)

arg = arg.filter(predicate)

if order_by := op.order_by:
keys = [translate(k.expr, **kw).filter(predicate) for k in order_by]
descending = [k.descending for k in order_by]
arg = arg.sort_by(keys, descending=descending)

return pl.when(arg.count() > 0).then(arg.str.join(sep)).otherwise(None)
47 changes: 33 additions & 14 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class AggGen:
supports_filter
Whether the backend supports a FILTER clause in the aggregate.
Defaults to False.
supports_order_by
Whether the backend supports an ORDER BY clause in (relevant)
aggregates. Defaults to False.
"""

class _Accessor:
Expand All @@ -79,10 +82,13 @@ def __getattr__(self, name: str) -> Callable:

__getitem__ = __getattr__

__slots__ = ("supports_filter",)
__slots__ = ("supports_filter", "supports_order_by")

def __init__(self, *, supports_filter: bool = False):
def __init__(
self, *, supports_filter: bool = False, supports_order_by: bool = False
):
self.supports_filter = supports_filter
self.supports_order_by = supports_order_by

def __get__(self, instance, owner=None):
if instance is None:
Expand All @@ -96,6 +102,7 @@ def aggregate(
name: str,
*args: Any,
where: Any = None,
order_by: tuple = (),
):
"""Compile the specified aggregate.
Expand All @@ -109,21 +116,31 @@ def aggregate(
Any arguments to pass to the aggregate.
where
An optional column filter to apply before performing the aggregate.
order_by
Optional ordering keys to use to order the rows before performing
the aggregate.
"""
func = compiler.f[name]

if where is None:
return func(*args)

if self.supports_filter:
return sge.Filter(
this=func(*args),
expression=sge.Where(this=where),
if order_by and not self.supports_order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
f"not supported for the {compiler.dialect} backend"
)
else:

if where is not None and not self.supports_filter:
args = tuple(compiler.if_(where, arg, NULL) for arg in args)
return func(*args)

if order_by and self.supports_order_by:
*rest, last = args
out = func(*rest, sge.Order(this=last, expressions=order_by))
else:
out = func(*args)

if where is not None and self.supports_filter:
out = sge.Filter(this=out, expression=sge.Where(this=where))

return out


class VarGen:
Expand Down Expand Up @@ -424,8 +441,10 @@ def make_impl(op, target_name):

if issubclass(op, ops.Reduction):

def impl(self, _, *, _name: str = target_name, where, **kw):
return self.agg[_name](*kw.values(), where=where)
def impl(
self, _, *, _name: str = target_name, where, order_by=(), **kw
):
return self.agg[_name](*kw.values(), where=where, order_by=order_by)

else:

Expand Down
33 changes: 24 additions & 9 deletions ibis/backends/sql/compilers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
Expand All @@ -28,6 +28,9 @@ class BigQueryCompiler(SQLGlotCompiler):
dialect = BigQuery
type_mapper = BigQueryType
udf_type_mapper = BigQueryUDFType

agg = AggGen(supports_order_by=True)

rewrites = (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
Expand Down Expand Up @@ -172,10 +175,14 @@ def visit_TimestampDelta(self, op, *, left, right, part):
"timestamp difference with mixed timezone/timezoneless values is not implemented"
)

def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.string_agg(arg, sep)

if order_by:
sep = sge.Order(this=sep, expressions=order_by)

return sge.GroupConcat(this=arg, separator=sep)

def visit_FloorDivide(self, op, *, left, right):
return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype)
Expand Down Expand Up @@ -225,10 +232,10 @@ def visit_StringToTimestamp(self, op, *, arg, format_str):
return self.f.parse_timestamp(format_str, arg, timezone)
return self.f.parse_datetime(format_str, arg)

def visit_ArrayCollect(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.array_agg(sge.IgnoreNulls(this=arg))
def visit_ArrayCollect(self, op, *, arg, where, order_by):
return sge.IgnoreNulls(
this=self.agg.array_agg(arg, where=where, order_by=order_by)
)

def _neg_idx_to_pos(self, arg, idx):
return self.if_(idx < 0, self.f.array_length(arg) + idx, idx)
Expand Down Expand Up @@ -474,17 +481,25 @@ def visit_TimestampRange(self, op, *, start, stop, step):
self.f.generate_timestamp_array, start, stop, step, op.step.dtype
)

def visit_First(self, op, *, arg, where):
def visit_First(self, op, *, arg, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

array = self.f.array_agg(
sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)),
)
return array[self.f.safe_offset(0)]

def visit_Last(self, op, *, arg, where):
def visit_Last(self, op, *, arg, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg)))
return array[self.f.safe_offset(0)]

Expand Down
14 changes: 12 additions & 2 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@


class ClickhouseAggGen(AggGen):
def aggregate(self, compiler, name, *args, where=None):
def aggregate(self, compiler, name, *args, where=None, order_by=()):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
# Clickhouse aggregate functions all have filtering variants with a
# `If` suffix (e.g. `SumIf` instead of `Sum`).
if where is not None:
Expand Down Expand Up @@ -433,7 +438,12 @@ def visit_StringSplit(self, op, *, arg, delimiter):
delimiter, self.cast(arg, dt.String(nullable=False))
)

def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
call = self.agg.groupArray(arg, where=where)
return self.if_(self.f.empty(call), NULL, self.f.arrayStringConcat(call, sep))

Expand Down
19 changes: 14 additions & 5 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DataFusionCompiler(SQLGlotCompiler):
*SQLGlotCompiler.rewrites,
)

agg = AggGen(supports_filter=True)
agg = AggGen(supports_filter=True, supports_order_by=True)

UNSUPPORTED_OPS = (
ops.ArgMax,
Expand Down Expand Up @@ -425,15 +425,15 @@ def visit_StringConcat(self, op, *, arg):
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

def visit_First(self, op, *, arg, where):
def visit_First(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first_value(arg, where=where)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where):
def visit_Last(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Aggregate(self, op, *, parent, groups, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
Expand Down Expand Up @@ -488,3 +488,12 @@ def visit_StructColumn(self, op, *, names, values):
args.append(sge.convert(name))
args.append(value)
return self.f.named_struct(*args)

def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if order_by:
raise com.UnsupportedOperationError(
"DataFusion does not support order-sensitive group_concat"
)
return super().visit_GroupConcat(
op, arg=arg, sep=sep, where=where, order_by=order_by
)
Loading

0 comments on commit a18cb5d

Please sign in to comment.