From a18cb5d30b25f73cb990b15cd184eecfdd2c0cc6 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 2 Aug 2024 12:53:01 -0500 Subject: [PATCH] feat(api): support `order_by` in order-sensitive aggregates (`collect`/`group_concat`/`first`/`last`) (#9729) Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> --- .../test_group_concat/comma_none/out.sql | 2 +- .../test_group_concat/comma_zero/out.sql | 2 +- .../test_group_concat/minus_none/out.sql | 2 +- .../test_analytic_exprs/first/out.sql | 2 +- .../test_analytic_exprs/last/out.sql | 2 +- ibis/backends/pandas/executor.py | 21 ++- ibis/backends/polars/compiler.py | 50 +++++- ibis/backends/sql/compilers/base.py | 47 ++++-- ibis/backends/sql/compilers/bigquery.py | 33 +++- ibis/backends/sql/compilers/clickhouse.py | 14 +- ibis/backends/sql/compilers/datafusion.py | 19 ++- ibis/backends/sql/compilers/druid.py | 4 +- ibis/backends/sql/compilers/duckdb.py | 10 +- ibis/backends/sql/compilers/exasol.py | 9 + ibis/backends/sql/compilers/flink.py | 8 +- ibis/backends/sql/compilers/mssql.py | 13 +- ibis/backends/sql/compilers/mysql.py | 11 +- ibis/backends/sql/compilers/oracle.py | 15 +- ibis/backends/sql/compilers/postgres.py | 3 +- ibis/backends/sql/compilers/pyspark.py | 19 ++- ibis/backends/sql/compilers/risingwave.py | 20 ++- ibis/backends/sql/compilers/snowflake.py | 37 +++-- ibis/backends/sql/compilers/sqlite.py | 1 + ibis/backends/sql/compilers/trino.py | 23 ++- ibis/backends/sql/dialects.py | 12 +- ibis/backends/tests/test_aggregation.py | 157 ++++++++++++++++-- ibis/backends/tests/test_generic.py | 8 +- ibis/expr/operations/reductions.py | 6 + ibis/expr/tests/test_reductions.py | 38 +++++ ibis/expr/types/generic.py | 77 +++++++-- 30 files changed, 537 insertions(+), 128 deletions(-) diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_none/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_none/out.sql index 1c85b5fbb75a..7884b90dffc4 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_none/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_none/out.sql @@ -3,5 +3,5 @@ SELECT WHEN empty(groupArray("t0"."string_col")) THEN NULL ELSE arrayStringConcat(groupArray("t0"."string_col"), ',') - END AS "GroupConcat(string_col, ',')" + END AS "GroupConcat(string_col, ',', ())" FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_zero/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_zero/out.sql index 7846aad19960..431e69134db3 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_zero/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_zero/out.sql @@ -3,5 +3,5 @@ SELECT WHEN empty(groupArrayIf("t0"."string_col", "t0"."bool_col" = 0)) THEN NULL ELSE arrayStringConcat(groupArrayIf("t0"."string_col", "t0"."bool_col" = 0), ',') - END AS "GroupConcat(string_col, ',', Equals(bool_col, 0))" + END AS "GroupConcat(string_col, ',', (), Equals(bool_col, 0))" FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/minus_none/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/minus_none/out.sql index b68d30a94cec..e47761337197 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/minus_none/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/minus_none/out.sql @@ -3,5 +3,5 @@ SELECT WHEN empty(groupArray("t0"."string_col")) THEN NULL ELSE arrayStringConcat(groupArray("t0"."string_col"), '-') - END AS "GroupConcat(string_col, '-')" + END AS "GroupConcat(string_col, '-', ())" FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/first/out.sql b/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/first/out.sql index 2b5291f39629..1391a6c01c51 100644 --- a/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/first/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/first/out.sql @@ -1,3 +1,3 @@ SELECT - FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col)` + FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col, ())` FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/last/out.sql b/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/last/out.sql index 3807d77b7e8e..904f923e17e7 100644 --- a/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/last/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/last/out.sql @@ -1,3 +1,3 @@ SELECT - LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col)` + LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col, ())` FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index 2e2aa6504e5e..f3c1cdaf4d12 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -30,7 +30,11 @@ plan, ) from ibis.common.dispatch import Dispatched -from ibis.common.exceptions import OperationNotDefinedError, UnboundExpressionError +from ibis.common.exceptions import ( + OperationNotDefinedError, + UnboundExpressionError, + UnsupportedOperationError, +) from ibis.formats.pandas import PandasData, PandasType from ibis.util import any_of, gen_name @@ -253,7 +257,12 @@ def visit( ############################# Reductions ################################## @classmethod - def visit(cls, op: ops.Reduction, arg, where): + def visit(cls, op: ops.Reduction, arg, where, order_by=()): + if order_by: + raise UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) func = cls.kernels.reductions[type(op)] return cls.agg(func, arg, where) @@ -344,7 +353,13 @@ def agg(df): return agg @classmethod - def visit(cls, op: ops.GroupConcat, arg, sep, where): + def visit(cls, op: ops.GroupConcat, arg, sep, where, order_by): + if order_by: + raise UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) + if where is None: def agg(df): diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 991d7a3a032f..6a9fe4a3baf9 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -45,7 +45,7 @@ def _literal_value(op, nan_as_none=False): @singledispatch -def translate(expr, *, ctx): +def translate(expr, **_): raise NotImplementedError(expr) @@ -748,6 +748,11 @@ def execute_first_last(op, **kw): arg = arg.filter(predicate) + if order_by := getattr(op, "order_by", ()): + keys = [translate(k.expr, **kw).filter(predicate) for k in order_by] + descending = [k.descending for k in order_by] + arg = arg.sort_by(keys, descending=descending) + return arg.last() if isinstance(op, ops.Last) else arg.first() @@ -985,14 +990,21 @@ def array_column(op, **kw): @translate.register(ops.ArrayCollect) def array_collect(op, in_group_by=False, **kw): arg = translate(op.arg, **kw) - if (where := op.where) is not None: - arg = arg.filter(translate(where, **kw)) - out = arg.drop_nulls() - if not in_group_by: - # Polars' behavior changes for `implode` within a `group_by` currently. - # See https://github.com/pola-rs/polars/issues/16756 - out = out.implode() - return out + + predicate = arg.is_not_null() + if op.where is not None: + predicate &= translate(op.where, **kw) + + arg = arg.filter(predicate) + + if op.order_by: + keys = [translate(k.expr, **kw).filter(predicate) for k in op.order_by] + descending = [k.descending for k in op.order_by] + arg = arg.sort_by(keys, descending=descending) + + # Polars' behavior changes for `implode` within a `group_by` currently. + # See https://github.com/pola-rs/polars/issues/16756 + return arg if in_group_by else arg.implode() @translate.register(ops.ArrayFlatten) @@ -1390,3 +1402,23 @@ def execute_array_all(op, **kw): arg = translate(op.arg, **kw) no_nulls = arg.list.drop_nulls() return pl.when(no_nulls.list.len() == 0).then(None).otherwise(no_nulls.list.all()) + + +@translate.register(ops.GroupConcat) +def execute_group_concat(op, **kw): + arg = translate(op.arg, **kw) + sep = _literal_value(op.sep) + + predicate = arg.is_not_null() + + if (where := op.where) is not None: + predicate &= translate(where, **kw) + + arg = arg.filter(predicate) + + if order_by := op.order_by: + keys = [translate(k.expr, **kw).filter(predicate) for k in order_by] + descending = [k.descending for k in order_by] + arg = arg.sort_by(keys, descending=descending) + + return pl.when(arg.count() > 0).then(arg.str.join(sep)).otherwise(None) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 78d5777bdc04..e66aeecc8245 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -63,6 +63,9 @@ class AggGen: supports_filter Whether the backend supports a FILTER clause in the aggregate. Defaults to False. + supports_order_by + Whether the backend supports an ORDER BY clause in (relevant) + aggregates. Defaults to False. """ class _Accessor: @@ -79,10 +82,13 @@ def __getattr__(self, name: str) -> Callable: __getitem__ = __getattr__ - __slots__ = ("supports_filter",) + __slots__ = ("supports_filter", "supports_order_by") - def __init__(self, *, supports_filter: bool = False): + def __init__( + self, *, supports_filter: bool = False, supports_order_by: bool = False + ): self.supports_filter = supports_filter + self.supports_order_by = supports_order_by def __get__(self, instance, owner=None): if instance is None: @@ -96,6 +102,7 @@ def aggregate( name: str, *args: Any, where: Any = None, + order_by: tuple = (), ): """Compile the specified aggregate. @@ -109,21 +116,31 @@ def aggregate( Any arguments to pass to the aggregate. where An optional column filter to apply before performing the aggregate. - + order_by + Optional ordering keys to use to order the rows before performing + the aggregate. """ func = compiler.f[name] - if where is None: - return func(*args) - - if self.supports_filter: - return sge.Filter( - this=func(*args), - expression=sge.Where(this=where), + if order_by and not self.supports_order_by: + raise com.UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + f"not supported for the {compiler.dialect} backend" ) - else: + + if where is not None and not self.supports_filter: args = tuple(compiler.if_(where, arg, NULL) for arg in args) - return func(*args) + + if order_by and self.supports_order_by: + *rest, last = args + out = func(*rest, sge.Order(this=last, expressions=order_by)) + else: + out = func(*args) + + if where is not None and self.supports_filter: + out = sge.Filter(this=out, expression=sge.Where(this=where)) + + return out class VarGen: @@ -424,8 +441,10 @@ def make_impl(op, target_name): if issubclass(op, ops.Reduction): - def impl(self, _, *, _name: str = target_name, where, **kw): - return self.agg[_name](*kw.values(), where=where) + def impl( + self, _, *, _name: str = target_name, where, order_by=(), **kw + ): + return self.agg[_name](*kw.values(), where=where, order_by=order_by) else: diff --git a/ibis/backends/sql/compilers/bigquery.py b/ibis/backends/sql/compilers/bigquery.py index 77fd5f6bdb24..013600796c10 100644 --- a/ibis/backends/sql/compilers/bigquery.py +++ b/ibis/backends/sql/compilers/bigquery.py @@ -12,7 +12,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util -from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler +from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType from ibis.backends.sql.rewrites import ( exclude_unsupported_window_frame_from_ops, @@ -28,6 +28,9 @@ class BigQueryCompiler(SQLGlotCompiler): dialect = BigQuery type_mapper = BigQueryType udf_type_mapper = BigQueryUDFType + + agg = AggGen(supports_order_by=True) + rewrites = ( exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_row_number, @@ -172,10 +175,14 @@ def visit_TimestampDelta(self, op, *, left, right, part): "timestamp difference with mixed timezone/timezoneless values is not implemented" ) - def visit_GroupConcat(self, op, *, arg, sep, where): + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): if where is not None: arg = self.if_(where, arg, NULL) - return self.f.string_agg(arg, sep) + + if order_by: + sep = sge.Order(this=sep, expressions=order_by) + + return sge.GroupConcat(this=arg, separator=sep) def visit_FloorDivide(self, op, *, left, right): return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype) @@ -225,10 +232,10 @@ def visit_StringToTimestamp(self, op, *, arg, format_str): return self.f.parse_timestamp(format_str, arg, timezone) return self.f.parse_datetime(format_str, arg) - def visit_ArrayCollect(self, op, *, arg, where): - if where is not None: - arg = self.if_(where, arg, NULL) - return self.f.array_agg(sge.IgnoreNulls(this=arg)) + def visit_ArrayCollect(self, op, *, arg, where, order_by): + return sge.IgnoreNulls( + this=self.agg.array_agg(arg, where=where, order_by=order_by) + ) def _neg_idx_to_pos(self, arg, idx): return self.if_(idx < 0, self.f.array_length(arg) + idx, idx) @@ -474,17 +481,25 @@ def visit_TimestampRange(self, op, *, start, stop, step): self.f.generate_timestamp_array, start, stop, step, op.step.dtype ) - def visit_First(self, op, *, arg, where): + def visit_First(self, op, *, arg, where, order_by): if where is not None: arg = self.if_(where, arg, NULL) + + if order_by: + arg = sge.Order(this=arg, expressions=order_by) + array = self.f.array_agg( sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)), ) return array[self.f.safe_offset(0)] - def visit_Last(self, op, *, arg, where): + def visit_Last(self, op, *, arg, where, order_by): if where is not None: arg = self.if_(where, arg, NULL) + + if order_by: + arg = sge.Order(this=arg, expressions=order_by) + array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg))) return array[self.f.safe_offset(0)] diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index e665e2dd6fac..66cc9a421e58 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -20,7 +20,12 @@ class ClickhouseAggGen(AggGen): - def aggregate(self, compiler, name, *args, where=None): + def aggregate(self, compiler, name, *args, where=None, order_by=()): + if order_by: + raise com.UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) # Clickhouse aggregate functions all have filtering variants with a # `If` suffix (e.g. `SumIf` instead of `Sum`). if where is not None: @@ -433,7 +438,12 @@ def visit_StringSplit(self, op, *, arg, delimiter): delimiter, self.cast(arg, dt.String(nullable=False)) ) - def visit_GroupConcat(self, op, *, arg, sep, where): + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): + if order_by: + raise com.UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) call = self.agg.groupArray(arg, where=where) return self.if_(self.f.empty(call), NULL, self.f.arrayStringConcat(call, sep)) diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index e629d7945f88..e5aa459fcdcb 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -30,7 +30,7 @@ class DataFusionCompiler(SQLGlotCompiler): *SQLGlotCompiler.rewrites, ) - agg = AggGen(supports_filter=True) + agg = AggGen(supports_filter=True, supports_order_by=True) UNSUPPORTED_OPS = ( ops.ArgMax, @@ -425,15 +425,15 @@ def visit_StringConcat(self, op, *, arg): sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg) ) - def visit_First(self, op, *, arg, where): + def visit_First(self, op, *, arg, where, order_by): cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) - return self.agg.first_value(arg, where=where) + return self.agg.first_value(arg, where=where, order_by=order_by) - def visit_Last(self, op, *, arg, where): + def visit_Last(self, op, *, arg, where, order_by): cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) - return self.agg.last_value(arg, where=where) + return self.agg.last_value(arg, where=where, order_by=order_by) def visit_Aggregate(self, op, *, parent, groups, metrics): """Support `GROUP BY` expressions in `SELECT` since DataFusion does not.""" @@ -488,3 +488,12 @@ def visit_StructColumn(self, op, *, names, values): args.append(sge.convert(name)) args.append(value) return self.f.named_struct(*args) + + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): + if order_by: + raise com.UnsupportedOperationError( + "DataFusion does not support order-sensitive group_concat" + ) + return super().visit_GroupConcat( + op, arg=arg, sep=sep, where=where, order_by=order_by + ) diff --git a/ibis/backends/sql/compilers/druid.py b/ibis/backends/sql/compilers/druid.py index fd7bed49dff1..cf876aacc766 100644 --- a/ibis/backends/sql/compilers/druid.py +++ b/ibis/backends/sql/compilers/druid.py @@ -121,8 +121,8 @@ def visit_Pi(self, op): def visit_Sign(self, op, *, arg): return self.if_(arg.eq(0), 0, self.if_(arg > 0, 1, -1)) - def visit_GroupConcat(self, op, *, arg, sep, where): - return self.agg.string_agg(arg, sep, 1 << 20, where=where) + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): + return self.agg.string_agg(arg, sep, 1 << 20, where=where, order_by=order_by) def visit_StartsWith(self, op, *, arg, start): return self.f.left(arg, self.f.length(start)).eq(start) diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index bcd45c82f28b..e47e1bafe8ce 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -34,7 +34,7 @@ class DuckDBCompiler(SQLGlotCompiler): dialect = DuckDB type_mapper = DuckDBType - agg = AggGen(supports_filter=True) + agg = AggGen(supports_filter=True, supports_order_by=True) rewrites = ( exclude_nulls_from_array_collect, @@ -476,15 +476,15 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement): arg, pattern, replacement, "g", dialect=self.dialect ) - def visit_First(self, op, *, arg, where): + def visit_First(self, op, *, arg, where, order_by): cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) - return self.agg.first(arg, where=where) + return self.agg.first(arg, where=where, order_by=order_by) - def visit_Last(self, op, *, arg, where): + def visit_Last(self, op, *, arg, where, order_by): cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) - return self.agg.last(arg, where=where) + return self.agg.last(arg, where=where, order_by=order_by) def visit_Quantile(self, op, *, arg, quantile, where): suffix = "cont" if op.arg.dtype.is_numeric() else "disc" diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py index e5e67c1d5e82..d6d8fbb4e279 100644 --- a/ibis/backends/sql/compilers/exasol.py +++ b/ibis/backends/sql/compilers/exasol.py @@ -128,6 +128,15 @@ def visit_NonNullLiteral(self, op, *, value, dtype): def visit_Date(self, op, *, arg): return self.cast(arg, dt.date) + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): + if where is not None: + arg = self.if_(where, arg, NULL) + + if order_by: + arg = sge.Order(this=arg, expressions=order_by) + + return sge.GroupConcat(this=arg, separator=sep) + def visit_StartsWith(self, op, *, arg, start): return self.f.left(arg, self.f.length(start)).eq(start) diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py index 098123f1a33a..462e496428af 100644 --- a/ibis/backends/sql/compilers/flink.py +++ b/ibis/backends/sql/compilers/flink.py @@ -19,7 +19,13 @@ class FlinkAggGen(AggGen): - def aggregate(self, compiler, name, *args, where=None): + def aggregate(self, compiler, name, *args, where=None, order_by=()): + if order_by: + raise com.UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) + func = compiler.f[name] if where is not None: # Flink does support FILTER, but it's broken for: diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index 4da9b31e2c5f..31f67b52aeb5 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -13,6 +13,7 @@ NULL, STAR, TRUE, + AggGen, SQLGlotCompiler, ) from ibis.backends.sql.datatypes import MSSQLType @@ -52,6 +53,8 @@ def rewrite_rows_range_order_by_window(_, **kwargs): class MSSQLCompiler(SQLGlotCompiler): __slots__ = () + agg = AggGen(supports_order_by=True) + dialect = MSSQL type_mapper = MSSQLType rewrites = ( @@ -185,10 +188,16 @@ def visit_Substring(self, op, *, arg, start, length): length = self.f.length(arg) return self.f.substring(arg, start, length) - def visit_GroupConcat(self, op, *, arg, sep, where): + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): if where is not None: arg = self.if_(where, arg, NULL) - return self.f.group_concat(arg, sep) + + out = self.f.group_concat(arg, sep) + + if order_by: + out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by)) + + return out def visit_CountStar(self, op, *, arg, where): if where is not None: diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py index bbf616d4a2ec..c9278910a891 100644 --- a/ibis/backends/sql/compilers/mysql.py +++ b/ibis/backends/sql/compilers/mysql.py @@ -165,14 +165,19 @@ def visit_CountDistinctStar(self, op, *, arg, where): sge.Distinct(expressions=list(map(func, op.arg.schema.keys()))) ) - def visit_GroupConcat(self, op, *, arg, sep, where): + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): if not isinstance(op.sep, ops.Literal): raise com.UnsupportedOperationError( "Only string literal separators are supported" ) + if where is not None: - arg = self.if_(where, arg) - return self.f.group_concat(arg, sep) + arg = self.if_(where, arg, NULL) + + if order_by: + arg = sge.Order(this=arg, expressions=order_by) + + return sge.GroupConcat(this=arg, separator=sep) def visit_DayOfWeekIndex(self, op, *, arg): return (self.f.dayofweek(arg) + 5) % 7 diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index 2b77d052c96f..a9d1d033627a 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -6,7 +6,7 @@ import ibis.common.exceptions as com import ibis.expr.operations as ops -from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler +from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import OracleType from ibis.backends.sql.dialects import Oracle from ibis.backends.sql.rewrites import ( @@ -23,6 +23,8 @@ class OracleCompiler(SQLGlotCompiler): __slots__ = () + agg = AggGen(supports_order_by=True) + dialect = Oracle type_mapper = OracleType rewrites = ( @@ -447,8 +449,13 @@ def visit_StringConcat(self, op, *, arg): def visit_ExtractIsoYear(self, op, *, arg): return self.cast(self.f.to_char(arg, "IYYY"), op.dtype) - def visit_GroupConcat(self, op, *, arg, where, sep): + def visit_GroupConcat(self, op, *, arg, where, sep, order_by): if where is not None: - arg = self.if_(where, arg) + arg = self.if_(where, arg, NULL) + + out = self.f.listagg(arg, sep) + + if order_by: + out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by)) - return self.f.listagg(arg, sep) + return out diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index 78338c9a9905..7386e33e7ed8 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -29,7 +29,7 @@ class PostgresCompiler(SQLGlotCompiler): rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites) - agg = AggGen(supports_filter=True) + agg = AggGen(supports_filter=True, supports_order_by=True) NAN = sge.Literal.number("'NaN'::double precision") POS_INF = sge.Literal.number("'Inf'::double precision") @@ -43,7 +43,6 @@ class PostgresCompiler(SQLGlotCompiler): SIMPLE_OPS = { ops.Arbitrary: "first", # could use any_value for postgres>=16 - ops.ArrayCollect: "array_agg", ops.ArrayRemove: "array_remove", ops.BitAnd: "bit_and", ops.BitOr: "bit_or", diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index c91d5ab4d26a..1ac1ad6553d9 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -240,15 +240,11 @@ def visit_FirstValue(self, op, *, arg): def visit_LastValue(self, op, *, arg): return sge.IgnoreNulls(this=self.f.last(arg)) - def visit_First(self, op, *, arg, where): - if where is not None: - arg = self.if_(where, arg, NULL) - return sge.IgnoreNulls(this=self.f.first(arg)) + def visit_First(self, op, *, arg, where, order_by): + return sge.IgnoreNulls(this=self.agg.first(arg, where=where, order_by=order_by)) - def visit_Last(self, op, *, arg, where): - if where is not None: - arg = self.if_(where, arg, NULL) - return sge.IgnoreNulls(this=self.f.last(arg)) + def visit_Last(self, op, *, arg, where, order_by): + return sge.IgnoreNulls(this=self.agg.last(arg, where=where, order_by=order_by)) def visit_Arbitrary(self, op, *, arg, where): # For Spark>=3.4 we could use any_value here @@ -259,7 +255,12 @@ def visit_Arbitrary(self, op, *, arg, where): def visit_Median(self, op, *, arg, where): return self.agg.percentile(arg, 0.5, where=where) - def visit_GroupConcat(self, op, *, arg, sep, where): + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): + if order_by: + raise com.UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) if where is not None: arg = self.if_(where, arg, NULL) collected = self.f.collect_list(arg) diff --git a/ibis/backends/sql/compilers/risingwave.py b/ibis/backends/sql/compilers/risingwave.py index 3d626ffbe68e..8d1e86d1ce5f 100644 --- a/ibis/backends/sql/compilers/risingwave.py +++ b/ibis/backends/sql/compilers/risingwave.py @@ -35,6 +35,20 @@ class RisingWaveCompiler(PostgresCompiler): def visit_DateNow(self, op): return self.cast(sge.CurrentTimestamp(), dt.date) + def visit_First(self, op, *, arg, where, order_by): + if not order_by: + raise com.UnsupportedOperationError( + "RisingWave requires an `order_by` be specified in `first`" + ) + return self.agg.first_value(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by): + if not order_by: + raise com.UnsupportedOperationError( + "RisingWave requires an `order_by` be specified in `last`" + ) + return self.agg.last_value(arg, where=where, order_by=order_by) + def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": raise com.UnsupportedOperationError( @@ -44,12 +58,6 @@ def visit_Correlation(self, op, *, left, right, how, where): op, left=left, right=right, how=how, where=where ) - def visit_First(self, op, *, arg, where): - return self.agg.first_value(arg, where=where) - - def visit_Last(self, op, *, arg, where): - return self.agg.last_value(arg, where=where) - def visit_TimestampTruncate(self, op, *, arg, unit): unit_mapping = { "Y": "year", diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index c651d9c9d816..00af8fc01a27 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -10,7 +10,14 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util -from ibis.backends.sql.compilers.base import NULL, STAR, C, FuncGen, SQLGlotCompiler +from ibis.backends.sql.compilers.base import ( + NULL, + STAR, + AggGen, + C, + FuncGen, + SQLGlotCompiler, +) from ibis.backends.sql.datatypes import SnowflakeType from ibis.backends.sql.dialects import Snowflake from ibis.backends.sql.rewrites import ( @@ -32,6 +39,9 @@ class SnowflakeCompiler(SQLGlotCompiler): dialect = Snowflake type_mapper = SnowflakeType no_limit_value = NULL + + agg = AggGen(supports_order_by=True) + rewrites = ( exclude_unsupported_window_frame_from_row_number, exclude_unsupported_window_frame_from_ops, @@ -351,22 +361,23 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short]) - def visit_First(self, op, *, arg, where): - return self.f.get(self.agg.array_agg(arg, where=where), 0) + def visit_First(self, op, *, arg, where, order_by): + return self.f.get(self.agg.array_agg(arg, where=where, order_by=order_by), 0) - def visit_Last(self, op, *, arg, where): - expr = self.agg.array_agg(arg, where=where) + def visit_Last(self, op, *, arg, where, order_by): + expr = self.agg.array_agg(arg, where=where, order_by=order_by) return self.f.get(expr, self.f.array_size(expr) - 1) - def visit_GroupConcat(self, op, *, arg, where, sep): - if where is None: - return self.f.listagg(arg, sep) + def visit_GroupConcat(self, op, *, arg, where, sep, order_by): + if where is not None: + arg = self.if_(where, arg, NULL) - return self.if_( - self.f.count_if(where) > 0, - self.f.listagg(self.if_(where, arg, NULL), sep), - NULL, - ) + out = self.f.listagg(arg, sep) + + if order_by: + out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by)) + + return out def visit_TimestampBucket(self, op, *, arg, interval, offset): if offset is not None: diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index 6067d0a2557a..f8f21219a731 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -20,6 +20,7 @@ class SQLiteCompiler(SQLGlotCompiler): dialect = SQLite type_mapper = SQLiteType + # We could set `supports_order_by=True` for SQLite >= 3.44.0 (2023-11-01). agg = AggGen(supports_filter=True) NAN = NULL diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 6d87c7199fca..8ae7eb2eaa2f 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -34,7 +34,7 @@ class TrinoCompiler(SQLGlotCompiler): dialect = Trino type_mapper = TrinoType - agg = AggGen(supports_filter=True) + agg = AggGen(supports_filter=True, supports_order_by=True) rewrites = ( exclude_nulls_from_array_collect, @@ -85,7 +85,6 @@ class TrinoCompiler(SQLGlotCompiler): ops.ArraySort: "array_sort", ops.ArrayDistinct: "array_distinct", ops.ArrayLength: "cardinality", - ops.ArrayCollect: "array_agg", ops.ArrayIntersect: "array_intersect", ops.BitAnd: "bitwise_and_agg", ops.BitOr: "bitwise_or_agg", @@ -370,15 +369,27 @@ def visit_StringAscii(self, op, *, arg): def visit_ArrayStringJoin(self, op, *, sep, arg): return self.f.array_join(arg, sep) - def visit_First(self, op, *, arg, where): + def visit_First(self, op, *, arg, where, order_by): cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) - return self.f.element_at(self.agg.array_agg(arg, where=where), 1) + return self.f.element_at( + self.agg.array_agg(arg, where=where, order_by=order_by), 1 + ) + + def visit_Last(self, op, *, arg, where, order_by): + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.f.element_at( + self.agg.array_agg(arg, where=where, order_by=order_by), -1 + ) - def visit_Last(self, op, *, arg, where): + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) - return self.f.element_at(self.agg.array_agg(arg, where=where), -1) + array = self.agg.array_agg( + self.cast(arg, dt.string), where=where, order_by=order_by + ) + return self.f.array_join(array, sep) def visit_ArrayZip(self, op, *, arg): max_zip_arguments = 5 diff --git a/ibis/backends/sql/dialects.py b/ibis/backends/sql/dialects.py index daa52eecedf9..a4920a4b5a96 100644 --- a/ibis/backends/sql/dialects.py +++ b/ibis/backends/sql/dialects.py @@ -64,9 +64,18 @@ def _interval(self, e, quote_arg=True): return f"INTERVAL {arg} {e.args['unit']}" +def _group_concat(self, e): + this = self.sql(e, "this") + separator = self.sql(e, "separator") or "','" + return f"GROUP_CONCAT({this} SEPARATOR {separator})" + + class Exasol(Postgres): class Generator(Postgres.Generator): - TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | {sge.Interval: _interval} + TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | { + sge.Interval: _interval, + sge.GroupConcat: _group_concat, + } TYPE_MAPPING = Postgres.Generator.TYPE_MAPPING.copy() | { sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMP WITH LOCAL TIME ZONE", } @@ -349,6 +358,7 @@ def _create_sql(self, expression: sge.Create) -> str: sge.ApproxDistinct: rename_func("approx_count_distinct"), sge.Create: _create_sql, sge.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]), + sge.GroupConcat: rename_func("listagg"), } # TODO: can delete this after bumping sqlglot version > 20.9.0 diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 63f5a531094c..992251c6b90f 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -622,7 +622,11 @@ def test_reduction_ops( ["druid", "impala", "mssql", "mysql", "oracle"], raises=com.OperationNotDefinedError, ) -@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError) +@pytest.mark.notimpl( + ["risingwave"], + raises=com.UnsupportedOperationError, + reason="risingwave requires an `order_by` for these aggregations", +) @pytest.mark.parametrize("method", ["first", "last"]) @pytest.mark.parametrize( "filtered", @@ -663,6 +667,52 @@ def test_first_last(backend, alltypes, method, filtered): assert res == 30 +@pytest.mark.notimpl( + [ + "clickhouse", + "dask", + "exasol", + "flink", + "pandas", + "pyspark", + "sqlite", + ], + raises=com.UnsupportedOperationError, +) +@pytest.mark.notimpl( + ["druid", "impala", "mssql", "mysql", "oracle"], + raises=com.OperationNotDefinedError, +) +@pytest.mark.parametrize("method", ["first", "last"]) +@pytest.mark.parametrize( + "filtered", + [ + param( + False, + marks=[ + pytest.mark.notyet( + ["datafusion"], + raises=Exception, + reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + ) + ], + ), + True, + ], +) +def test_first_last_ordered(backend, alltypes, method, filtered): + t = alltypes.mutate(new=alltypes.int_col.nullif(0).nullif(9)) + where = None + sol = 1 if method == "last" else 8 + if filtered: + where = _.int_col != sol + sol = 2 if method == "last" else 7 + + expr = getattr(t.new, method)(where=where, order_by=t.int_col.desc()) + res = expr.execute() + assert res == sol + + @pytest.mark.notimpl( [ "impala", @@ -1207,6 +1257,11 @@ def test_date_quantile(alltypes): raises=GoogleBadRequest, reason="Argument 2 to STRING_AGG must be a literal or query parameter", ), + pytest.mark.notimpl( + ["polars"], + raises=com.UnsupportedArgumentError, + reason="polars doesn't support expression separators", + ), ], ), ], @@ -1218,23 +1273,15 @@ def test_date_quantile(alltypes): param( lambda t: t.string_col.isin(["1", "7"]), lambda t: t.string_col.isin(["1", "7"]), - marks=[ - pytest.mark.notyet(["trino"], raises=TrinoUserError), - ], id="is_in", ), param( lambda t: t.string_col.notin(["1", "7"]), lambda t: ~t.string_col.isin(["1", "7"]), - marks=[ - pytest.mark.notyet(["trino"], raises=TrinoUserError), - ], id="not_in", ), ], ) -@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl(["exasol"], raises=ExaQueryError) @pytest.mark.notyet(["flink"], raises=Py4JJavaError) def test_group_concat( backend, alltypes, df, ibis_cond, pandas_cond, ibis_sep, pandas_sep @@ -1267,6 +1314,89 @@ def test_group_concat( ) +@pytest.mark.notimpl( + [ + "clickhouse", + "datafusion", + "dask", + "druid", + "flink", + "impala", + "pandas", + "pyspark", + "sqlite", + ], + raises=com.UnsupportedOperationError, +) +@pytest.mark.parametrize("filtered", [False, True]) +def test_group_concat_ordered(alltypes, df, filtered): + ibis_cond = (_.id % 13 == 0) if filtered else None + pd_cond = (df.id % 13 == 0) if filtered else True + expr = ( + alltypes.filter(_.bigint_col == 10) + .id.cast("str") + .group_concat(":", where=ibis_cond, order_by=_.id.desc()) + ) + result = expr.execute() + expected = ":".join( + df.id[(df.bigint_col == 10) & pd_cond].sort_values(ascending=False).astype(str) + ) + assert result == expected + + +@pytest.mark.notimpl( + [ + "druid", + "exasol", + "flink", + "impala", + "mssql", + "mysql", + "oracle", + "sqlite", + ], + raises=com.OperationNotDefinedError, +) +@pytest.mark.notimpl( + [ + "clickhouse", + "dask", + "pandas", + "pyspark", + ], + raises=com.UnsupportedOperationError, +) +@pytest.mark.parametrize( + "filtered", + [ + param( + True, + marks=[ + pytest.mark.notyet( + ["datafusion"], + raises=Exception, + reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + ) + ], + ), + False, + ], +) +def test_collect_ordered(alltypes, df, filtered): + ibis_cond = (_.id % 13 == 0) if filtered else None + pd_cond = (df.id % 13 == 0) if filtered else True + result = ( + alltypes.filter(_.bigint_col == 10) + .id.cast("str") + .collect(where=ibis_cond, order_by=_.id.desc()) + .execute() + ) + expected = list( + df.id[(df.bigint_col == 10) & pd_cond].sort_values(ascending=False).astype(str) + ) + assert result == expected + + @pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError) def test_topk_op(alltypes, df): # TopK expression will order rows by "count" but each backend @@ -1489,7 +1619,6 @@ def test_grouped_case(backend, con): @pytest.mark.notimpl( ["datafusion"], raises=Exception, reason="not supported in datafusion" ) -@pytest.mark.notimpl(["exasol"], raises=ExaQueryError) @pytest.mark.notyet(["flink"], raises=Py4JJavaError) @pytest.mark.notyet(["impala"], raises=ImpalaHiveServer2Error) @pytest.mark.notyet(["clickhouse"], raises=ClickHouseDatabaseError) @@ -1508,13 +1637,17 @@ def test_group_concat_over_window(backend, con): "s": ["a|b|c", "b|a|c", "b|b|b|c|a"], "token": ["a", "b", "c"], "pk": [1, 1, 2], + "id": [1, 2, 3], } ) expected = input_df.assign(test=["a|b|c|b|a|c", "b|a|c", "b|b|b|c|a"]) table = ibis.memtable(input_df) - w = ibis.window(group_by="pk", preceding=0, following=None) - expr = table.mutate(test=table.s.group_concat(sep="|").over(w)).order_by("pk") + expr = table.mutate( + test=table.s.group_concat(sep="|").over( + group_by="pk", order_by="id", rows=(0, None) + ) + ).order_by("id") result = con.execute(expr) backend.assert_frame_equal(result, expected) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 6f760a1dcbbd..8026ea3e9606 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1510,8 +1510,8 @@ def test_pivot_wider(backend): ) @pytest.mark.notimpl( ["risingwave"], - raises=PsycoPg2InternalError, - reason="function last(double precision) does not exist, do you mean left or least", + raises=com.UnsupportedOperationError, + reason="first/last requires an order_by", ) @pytest.mark.notyet( ["datafusion"], @@ -1575,8 +1575,8 @@ def test_distinct_on_keep(backend, on, keep): ) @pytest.mark.notimpl( ["risingwave"], - raises=PsycoPg2InternalError, - reason="function first(double precision) does not exist", + raises=com.UnsupportedOperationError, + reason="first/last requires an order_by", ) @pytest.mark.notyet( ["datafusion"], diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index b3433897dff2..0e093f536b45 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -10,8 +10,10 @@ import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute +from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Column, Value from ibis.expr.operations.relations import Relation # noqa: TCH001 +from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001 @public @@ -78,6 +80,7 @@ class First(Filterable, Reduction): """Retrieve the first element.""" arg: Column[dt.Any] + order_by: VarTuple[SortKey] = () dtype = rlz.dtype_like("arg") @@ -87,6 +90,7 @@ class Last(Filterable, Reduction): """Retrieve the last element.""" arg: Column[dt.Any] + order_by: VarTuple[SortKey] = () dtype = rlz.dtype_like("arg") @@ -344,6 +348,7 @@ class GroupConcat(Filterable, Reduction): arg: Column sep: Value[dt.String] + order_by: VarTuple[SortKey] = () dtype = dt.string @@ -362,6 +367,7 @@ class ArrayCollect(Filterable, Reduction): """Collect values into an array.""" arg: Column + order_by: VarTuple[SortKey] = () @attribute def dtype(self): diff --git a/ibis/expr/tests/test_reductions.py b/ibis/expr/tests/test_reductions.py index 5f27f25c4a7b..701a79161508 100644 --- a/ibis/expr/tests/test_reductions.py +++ b/ibis/expr/tests/test_reductions.py @@ -7,6 +7,7 @@ import ibis.expr.operations as ops from ibis import _ from ibis.common.deferred import Deferred +from ibis.common.exceptions import IbisTypeError @pytest.mark.parametrize( @@ -113,3 +114,40 @@ def test_cov_corr_deferred(func_name): t = ibis.table({"a": "int", "b": "int"}, name="t") func = getattr(t.a, func_name) assert func(_.b).equals(func(t.b)) + + +@pytest.mark.parametrize("method", ["collect", "first", "last", "group_concat"]) +def test_ordered_aggregations(method): + t = ibis.table({"a": "string", "b": "int", "c": "int"}, name="t") + func = getattr(t.a, method) + + q1 = func(order_by="b") + q2 = func(order_by=("b",)) + q3 = func(order_by=_.b) + q4 = func(order_by=t.b) + assert q1.equals(q2) + assert q1.equals(q3) + assert q1.equals(q4) + + q5 = func(order_by=("b", "c")) + q6 = func(order_by=(_.b, _.c)) + assert q5.equals(q6) + + q7 = func(order_by=_.b.desc()) + q8 = func(order_by=t.b.desc()) + assert q7.equals(q8) + + with pytest.raises(IbisTypeError): + func(order_by="oops") + + +@pytest.mark.parametrize("method", ["collect", "first", "last", "group_concat"]) +def test_ordered_aggregations_no_order(method): + t = ibis.table({"a": "string", "b": "int", "c": "int"}, name="t") + func = getattr(t.a, method) + + q1 = func() + q2 = func(order_by=None) + q3 = func(order_by=()) + assert q1.equals(q2) + assert q1.equals(q3) diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index d9b087fbd4f1..473b74fc5d4c 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -15,7 +15,7 @@ from ibis.expr.rewrites import rewrite_window_input from ibis.expr.types.core import Expr, _binop, _FixedTextJupyterMixin, _is_null_literal from ibis.expr.types.pretty import to_rich -from ibis.util import deprecated, warn_deprecated +from ibis.util import deprecated, promote_list, warn_deprecated if TYPE_CHECKING: import pandas as pd @@ -1017,7 +1017,9 @@ def cases( builder = builder.when(case, result) return builder.else_(default).end() - def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: + def collect( + self, where: ir.BooleanValue | None = None, order_by: Any = None + ) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. This function is called `array_agg`, `list_agg`, or `list` in other systems. @@ -1025,7 +1027,12 @@ def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: Parameters ---------- where - Filter to apply before aggregation + An optional filter expression. If provided, only rows where `where` + is `True` will be included in the aggregate. + order_by + An ordering key (or keys) to use to order the rows before + aggregating. If not provided, the order of the items in the result + is undefined and backend specific. Returns ------- @@ -1082,7 +1089,11 @@ def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: │ b │ [4, 5] │ └────────┴──────────────────────┘ """ - return ops.ArrayCollect(self, where=self._bind_to_parent_table(where)).to_expr() + return ops.ArrayCollect( + self, + where=self._bind_to_parent_table(where), + order_by=self._bind_order_by(order_by), + ).to_expr() def identical_to(self, other: Value) -> ir.BooleanValue: """Return whether this expression is identical to other. @@ -1119,15 +1130,21 @@ def group_concat( self, sep: str = ",", where: ir.BooleanValue | None = None, + order_by: Any = None, ) -> ir.StringScalar: """Concatenate values using the indicated separator to produce a string. Parameters ---------- sep - Separator will be used to join strings + The separator to use to join strings. where - Filter expression + An optional filter expression. If provided, only rows where `where` + is `True` will be included in the aggregate. + order_by + An ordering key (or keys) to use to order the rows before + aggregating. If not provided, the order of the items in the result + is undefined and backend specific. Returns ------- @@ -1167,7 +1184,10 @@ def group_concat( └──────────────┘ """ return ops.GroupConcat( - self, sep=sep, where=self._bind_to_parent_table(where) + self, + sep=sep, + where=self._bind_to_parent_table(where), + order_by=self._bind_order_by(order_by), ).to_expr() def __hash__(self) -> int: @@ -1507,6 +1527,11 @@ def as_table(self) -> ir.Table: "base table references to a projection" ) + def _bind_order_by(self, value) -> tuple[ops.SortKey, ...]: + if value is None: + return () + return tuple(self._bind_to_parent_table(v) for v in promote_list(value)) + def _bind_to_parent_table(self, value) -> Value | None: """Bind an expr to the parent table of `self`.""" if value is None: @@ -2083,9 +2108,21 @@ def value_counts(self) -> ir.Table: metric = _.count().name(f"{name}_count") return self.as_table().group_by(name).aggregate(metric) - def first(self, where: ir.BooleanValue | None = None) -> Value: + def first( + self, where: ir.BooleanValue | None = None, order_by: Any = None + ) -> Value: """Return the first value of a column. + Parameters + ---------- + where + An optional filter expression. If provided, only rows where `where` + is `True` will be included in the aggregate. + order_by + An ordering key (or keys) to use to order the rows before + aggregating. If not provided, the meaning of `first` is undefined + and will be backend specific. + Examples -------- >>> import ibis @@ -2111,11 +2148,25 @@ def first(self, where: ir.BooleanValue | None = None) -> Value: │ 'b' │ └─────┘ """ - return ops.First(self, where=self._bind_to_parent_table(where)).to_expr() + return ops.First( + self, + where=self._bind_to_parent_table(where), + order_by=self._bind_order_by(order_by), + ).to_expr() - def last(self, where: ir.BooleanValue | None = None) -> Value: + def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Value: """Return the last value of a column. + Parameters + ---------- + where + An optional filter expression. If provided, only rows where `where` + is `True` will be included in the aggregate. + order_by + An ordering key (or keys) to use to order the rows before + aggregating. If not provided, the meaning of `last` is undefined + and will be backend specific. + Examples -------- >>> import ibis @@ -2141,7 +2192,11 @@ def last(self, where: ir.BooleanValue | None = None) -> Value: │ 'c' │ └─────┘ """ - return ops.Last(self, where=self._bind_to_parent_table(where)).to_expr() + return ops.Last( + self, + where=self._bind_to_parent_table(where), + order_by=self._bind_order_by(order_by), + ).to_expr() def rank(self) -> ir.IntegerColumn: """Compute position of first element within each equal-value group in sorted order.