Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): support order_by in order-sensitive aggregates (collect/group_concat/first/last) #9729

Merged
merged 29 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
13deed0
feat(api): support `order_by` in order-sensitive aggregates (`collect…
jcrist Jul 23, 2024
44da6f4
fix(oracle): map group_concat to listagg
cpcloud Aug 2, 2024
c7b94e4
chore: improve error message for order by
cpcloud Aug 2, 2024
7fe26d9
chore: remove invalid literal check
cpcloud Aug 2, 2024
c5c9902
test: split up expr and execution for easier debugging
cpcloud Aug 2, 2024
367c87d
chore(trino): remove redundant ArrayCollect mapping
cpcloud Aug 2, 2024
1bf89b0
chore(bigquery): remove redundant GroupConcat mapping
cpcloud Aug 2, 2024
37989d5
chore(exasol): order by the last argument
cpcloud Aug 2, 2024
624a8a3
chore(snowflake): clean up agg gen
cpcloud Aug 2, 2024
5fbb3ed
chore(postgres): remove redundant ArrayCollect mapping
cpcloud Aug 2, 2024
8730b53
chore(base): add requires_within_group flag to agg generator
cpcloud Aug 2, 2024
731069d
chore(oracle): use base group_concat impl
cpcloud Aug 2, 2024
98c73c8
chore(mysql): use base group_concat impl
cpcloud Aug 2, 2024
06f77d6
chore(mssql): use base group_concat impl
cpcloud Aug 2, 2024
62510a4
chore(clickhouse): regen snapshots
cpcloud Aug 2, 2024
01c0bb1
chore(oracle): do not use base groupconcat impl
cpcloud Aug 2, 2024
9f57f54
chore(mysql): fix groupconcat impl
cpcloud Aug 2, 2024
c00636f
chore(bigquery): fix groupconcat impl
cpcloud Aug 2, 2024
c72b655
chore(snowflake): fix groupconcat impl
cpcloud Aug 2, 2024
4db0c86
chore: ignore tch lints
cpcloud Aug 2, 2024
327a325
chore(datafusion): xfail order_by test
cpcloud Aug 2, 2024
884c828
chore(polars): groupconcat
cpcloud Aug 2, 2024
9310ee4
chore(impala): regen snapshots
cpcloud Aug 2, 2024
9156ad8
chore(exasol): fix group concat
cpcloud Aug 2, 2024
d6f346c
test(groupconcat): make it work with more backends by ordering the wi…
cpcloud Aug 2, 2024
5755925
chore(polars): fix groupconcat cardinality with sort_by
cpcloud Aug 2, 2024
df658d2
chore(risingwave): fix test failures
cpcloud Aug 2, 2024
92328c1
chore: remove unnecessary `requires_within_group` AggGen property
cpcloud Aug 2, 2024
b9f2b19
chore(risingwave): remove old first/last definitions
cpcloud Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ SELECT
WHEN empty(groupArray("t0"."string_col"))
THEN NULL
ELSE arrayStringConcat(groupArray("t0"."string_col"), ',')
END AS "GroupConcat(string_col, ',')"
END AS "GroupConcat(string_col, ',', ())"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ SELECT
WHEN empty(groupArrayIf("t0"."string_col", "t0"."bool_col" = 0))
THEN NULL
ELSE arrayStringConcat(groupArrayIf("t0"."string_col", "t0"."bool_col" = 0), ',')
END AS "GroupConcat(string_col, ',', Equals(bool_col, 0))"
END AS "GroupConcat(string_col, ',', (), Equals(bool_col, 0))"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ SELECT
WHEN empty(groupArray("t0"."string_col"))
THEN NULL
ELSE arrayStringConcat(groupArray("t0"."string_col"), '-')
END AS "GroupConcat(string_col, '-')"
END AS "GroupConcat(string_col, '-', ())"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col)`
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col, ())`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col)`
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col, ())`
FROM `functional_alltypes` AS `t0`
21 changes: 18 additions & 3 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
plan,
)
from ibis.common.dispatch import Dispatched
from ibis.common.exceptions import OperationNotDefinedError, UnboundExpressionError
from ibis.common.exceptions import (
OperationNotDefinedError,
UnboundExpressionError,
UnsupportedOperationError,
)
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import any_of, gen_name

Expand Down Expand Up @@ -253,7 +257,12 @@ def visit(
############################# Reductions ##################################

@classmethod
def visit(cls, op: ops.Reduction, arg, where):
def visit(cls, op: ops.Reduction, arg, where, order_by=()):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
func = cls.kernels.reductions[type(op)]
return cls.agg(func, arg, where)

Expand Down Expand Up @@ -344,7 +353,13 @@ def agg(df):
return agg

@classmethod
def visit(cls, op: ops.GroupConcat, arg, sep, where):
def visit(cls, op: ops.GroupConcat, arg, sep, where, order_by):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)

if where is None:

def agg(df):
Expand Down
50 changes: 41 additions & 9 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _literal_value(op, nan_as_none=False):


@singledispatch
def translate(expr, *, ctx):
def translate(expr, **_):
raise NotImplementedError(expr)


Expand Down Expand Up @@ -748,6 +748,11 @@ def execute_first_last(op, **kw):

arg = arg.filter(predicate)

if order_by := getattr(op, "order_by", ()):
keys = [translate(k.expr, **kw).filter(predicate) for k in order_by]
descending = [k.descending for k in order_by]
arg = arg.sort_by(keys, descending=descending)

return arg.last() if isinstance(op, ops.Last) else arg.first()


Expand Down Expand Up @@ -985,14 +990,21 @@ def array_column(op, **kw):
@translate.register(ops.ArrayCollect)
def array_collect(op, in_group_by=False, **kw):
arg = translate(op.arg, **kw)
if (where := op.where) is not None:
arg = arg.filter(translate(where, **kw))
out = arg.drop_nulls()
if not in_group_by:
# Polars' behavior changes for `implode` within a `group_by` currently.
# See https://github.com/pola-rs/polars/issues/16756
out = out.implode()
return out

predicate = arg.is_not_null()
if op.where is not None:
predicate &= translate(op.where, **kw)

arg = arg.filter(predicate)

if op.order_by:
keys = [translate(k.expr, **kw).filter(predicate) for k in op.order_by]
descending = [k.descending for k in op.order_by]
arg = arg.sort_by(keys, descending=descending)

# Polars' behavior changes for `implode` within a `group_by` currently.
# See https://github.com/pola-rs/polars/issues/16756
return arg if in_group_by else arg.implode()


@translate.register(ops.ArrayFlatten)
Expand Down Expand Up @@ -1390,3 +1402,23 @@ def execute_array_all(op, **kw):
arg = translate(op.arg, **kw)
no_nulls = arg.list.drop_nulls()
return pl.when(no_nulls.list.len() == 0).then(None).otherwise(no_nulls.list.all())


@translate.register(ops.GroupConcat)
def execute_group_concat(op, **kw):
    """Translate a ``GroupConcat`` operation into a Polars expression.

    Null values of the argument are dropped, the optional ``where`` filter
    is applied, the surviving values are optionally sorted by the
    ``order_by`` keys, and the result is joined with the literal separator.
    The expression yields ``None`` when no values survive the filtering.
    """
    column = translate(op.arg, **kw)
    separator = _literal_value(op.sep)

    # Rows to keep: non-null values that also satisfy `where`, if given.
    keep = column.is_not_null()
    if op.where is not None:
        keep = keep & translate(op.where, **kw)

    kept = column.filter(keep)

    sort_keys = op.order_by
    if sort_keys:
        by = []
        directions = []
        for key in sort_keys:
            # Apply the same row filter to every sort key as to the values.
            by.append(translate(key.expr, **kw).filter(keep))
            directions.append(key.descending)
        kept = kept.sort_by(by, descending=directions)

    joined = kept.str.join(separator)
    return pl.when(kept.count() > 0).then(joined).otherwise(None)
47 changes: 33 additions & 14 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class AggGen:
supports_filter
Whether the backend supports a FILTER clause in the aggregate.
Defaults to False.
supports_order_by
Whether the backend supports an ORDER BY clause in (relevant)
aggregates. Defaults to False.
"""

class _Accessor:
Expand All @@ -79,10 +82,13 @@ def __getattr__(self, name: str) -> Callable:

__getitem__ = __getattr__

__slots__ = ("supports_filter",)
__slots__ = ("supports_filter", "supports_order_by")

def __init__(self, *, supports_filter: bool = False):
def __init__(
self, *, supports_filter: bool = False, supports_order_by: bool = False
):
self.supports_filter = supports_filter
self.supports_order_by = supports_order_by

def __get__(self, instance, owner=None):
if instance is None:
Expand All @@ -96,6 +102,7 @@ def aggregate(
name: str,
*args: Any,
where: Any = None,
order_by: tuple = (),
):
"""Compile the specified aggregate.

Expand All @@ -109,21 +116,31 @@ def aggregate(
Any arguments to pass to the aggregate.
where
An optional column filter to apply before performing the aggregate.

order_by
Optional ordering keys to use to order the rows before performing
the aggregate.
"""
func = compiler.f[name]

if where is None:
return func(*args)

if self.supports_filter:
return sge.Filter(
this=func(*args),
expression=sge.Where(this=where),
if order_by and not self.supports_order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
f"not supported for the {compiler.dialect} backend"
)
else:

if where is not None and not self.supports_filter:
args = tuple(compiler.if_(where, arg, NULL) for arg in args)
return func(*args)

if order_by and self.supports_order_by:
*rest, last = args
out = func(*rest, sge.Order(this=last, expressions=order_by))
else:
out = func(*args)

if where is not None and self.supports_filter:
out = sge.Filter(this=out, expression=sge.Where(this=where))

return out


class VarGen:
Expand Down Expand Up @@ -424,8 +441,10 @@ def make_impl(op, target_name):

if issubclass(op, ops.Reduction):

def impl(self, _, *, _name: str = target_name, where, **kw):
return self.agg[_name](*kw.values(), where=where)
def impl(
self, _, *, _name: str = target_name, where, order_by=(), **kw
):
return self.agg[_name](*kw.values(), where=where, order_by=order_by)

else:

Expand Down
33 changes: 24 additions & 9 deletions ibis/backends/sql/compilers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
Expand All @@ -28,6 +28,9 @@
dialect = BigQuery
type_mapper = BigQueryType
udf_type_mapper = BigQueryUDFType

agg = AggGen(supports_order_by=True)

rewrites = (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
Expand Down Expand Up @@ -172,10 +175,14 @@
"timestamp difference with mixed timezone/timezoneless values is not implemented"
)

def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.string_agg(arg, sep)

if order_by:
sep = sge.Order(this=sep, expressions=order_by)

Check warning on line 183 in ibis/backends/sql/compilers/bigquery.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/bigquery.py#L183

Added line #L183 was not covered by tests

return sge.GroupConcat(this=arg, separator=sep)

Check warning on line 185 in ibis/backends/sql/compilers/bigquery.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/bigquery.py#L185

Added line #L185 was not covered by tests

def visit_FloorDivide(self, op, *, left, right):
return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype)
Expand Down Expand Up @@ -225,10 +232,10 @@
return self.f.parse_timestamp(format_str, arg, timezone)
return self.f.parse_datetime(format_str, arg)

def visit_ArrayCollect(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.array_agg(sge.IgnoreNulls(this=arg))
def visit_ArrayCollect(self, op, *, arg, where, order_by):
return sge.IgnoreNulls(

Check warning on line 236 in ibis/backends/sql/compilers/bigquery.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/bigquery.py#L236

Added line #L236 was not covered by tests
this=self.agg.array_agg(arg, where=where, order_by=order_by)
)

def _neg_idx_to_pos(self, arg, idx):
return self.if_(idx < 0, self.f.array_length(arg) + idx, idx)
Expand Down Expand Up @@ -474,17 +481,25 @@
self.f.generate_timestamp_array, start, stop, step, op.step.dtype
)

def visit_First(self, op, *, arg, where):
def visit_First(self, op, *, arg, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

Check warning on line 489 in ibis/backends/sql/compilers/bigquery.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/bigquery.py#L489

Added line #L489 was not covered by tests

array = self.f.array_agg(
sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)),
)
return array[self.f.safe_offset(0)]

def visit_Last(self, op, *, arg, where):
def visit_Last(self, op, *, arg, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

Check warning on line 501 in ibis/backends/sql/compilers/bigquery.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/bigquery.py#L501

Added line #L501 was not covered by tests

array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg)))
return array[self.f.safe_offset(0)]

Expand Down
14 changes: 12 additions & 2 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@


class ClickhouseAggGen(AggGen):
def aggregate(self, compiler, name, *args, where=None):
def aggregate(self, compiler, name, *args, where=None, order_by=()):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
# Clickhouse aggregate functions all have filtering variants with a
# `If` suffix (e.g. `SumIf` instead of `Sum`).
if where is not None:
Expand Down Expand Up @@ -433,7 +438,12 @@ def visit_StringSplit(self, op, *, arg, delimiter):
delimiter, self.cast(arg, dt.String(nullable=False))
)

def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
call = self.agg.groupArray(arg, where=where)
return self.if_(self.f.empty(call), NULL, self.f.arrayStringConcat(call, sep))

Expand Down
19 changes: 14 additions & 5 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DataFusionCompiler(SQLGlotCompiler):
*SQLGlotCompiler.rewrites,
)

agg = AggGen(supports_filter=True)
agg = AggGen(supports_filter=True, supports_order_by=True)

UNSUPPORTED_OPS = (
ops.ArgMax,
Expand Down Expand Up @@ -425,15 +425,15 @@ def visit_StringConcat(self, op, *, arg):
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

def visit_First(self, op, *, arg, where):
def visit_First(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first_value(arg, where=where)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where):
def visit_Last(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Aggregate(self, op, *, parent, groups, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
Expand Down Expand Up @@ -488,3 +488,12 @@ def visit_StructColumn(self, op, *, names, values):
args.append(sge.convert(name))
args.append(value)
return self.f.named_struct(*args)

def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
    """Compile ``GroupConcat``, rejecting any explicit ordering.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty;
    otherwise defers to the inherited ``visit_GroupConcat``.
    """
    if not order_by:
        # No ordering requested: the inherited implementation applies as is.
        return super().visit_GroupConcat(
            op, arg=arg, sep=sep, where=where, order_by=order_by
        )

    raise com.UnsupportedOperationError(
        "DataFusion does not support order-sensitive group_concat"
    )
Loading