From 8c01980ed5a30912ea35fa47903347ebf44c8b12 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Thu, 20 Apr 2023 08:05:08 -0400 Subject: [PATCH] feat(api): add first/last reduction APIs BREAKING CHANGE: `Column.first()`/`Column.last()` are now reductions by default. Code running these expressions in isolation will no longer be windowed over the entire table. Code using this function in `select`-based APIs should function unchanged. --- ibis/backends/base/sql/alchemy/registry.py | 16 +++--- ibis/backends/base/sql/registry/window.py | 20 ++++---- ibis/backends/bigquery/registry.py | 24 +++++++++ ibis/backends/clickhouse/compiler/values.py | 7 ++- ibis/backends/dask/execution/reductions.py | 11 ++++ ibis/backends/duckdb/registry.py | 2 + .../test_analytic_exprs/first/out.sql | 2 +- .../test_analytic_exprs/last/out.sql | 2 +- .../impala/tests/test_analytic_functions.py | 5 +- ibis/backends/pandas/execution/generic.py | 36 +++++++++++++ .../pandas/tests/execution/test_window.py | 8 +-- ibis/backends/polars/compiler.py | 2 + ibis/backends/postgres/registry.py | 2 + .../backends/postgres/tests/test_functions.py | 21 ++++---- ibis/backends/pyspark/compiler.py | 29 ++++++++--- ibis/backends/snowflake/registry.py | 24 ++++++--- ibis/backends/sqlite/registry.py | 6 +++ .../test_count_on_order_by/out.sql | 30 +++++++++++ ibis/backends/sqlite/tests/test_functions.py | 7 ++- ibis/backends/tests/test_aggregation.py | 26 ++++++++-- ibis/backends/tests/test_generic.py | 2 +- ibis/backends/trino/registry.py | 7 +++ ibis/expr/analysis.py | 21 +++++++- ibis/expr/operations/analytic.py | 4 ++ ibis/expr/operations/reductions.py | 41 +++++++++++++++ ibis/expr/types/generic.py | 50 ++++++++++++++++--- 26 files changed, 336 insertions(+), 69 deletions(-) create mode 100644 ibis/backends/sqlite/tests/snapshots/test_functions/test_count_on_order_by/out.sql diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index a171fdc1123d..6a6e27aa2424 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -334,16 +334,18 @@ def _translate_window_boundary(boundary): def _window_function(t, window): - if isinstance(window.func, ops.CumulativeOp): - func = _cumulative_to_window(t, window.func, window.frame).op() + func = window.func.__window_op__ + + if isinstance(func, ops.CumulativeOp): + func = _cumulative_to_window(t, func, window.frame).op() return t.translate(func) - reduction = t.translate(window.func) + reduction = t.translate(func) # Some analytic functions need to have the expression of interest in # the ORDER BY part of the window clause - if isinstance(window.func, t._require_order_by) and not window.frame.order_by: - order_by = t.translate(window.func.arg) # .args[0]) + if isinstance(func, t._require_order_by) and not window.frame.order_by: + order_by = t.translate(func.arg) # .args[0]) else: order_by = [t.translate(arg) for arg in window.frame.order_by] @@ -361,7 +363,7 @@ def _window_function(t, window): else: raise NotImplementedError(type(window.frame)) - if t._forbids_frame_clause and isinstance(window.func, t._forbids_frame_clause): + if t._forbids_frame_clause and isinstance(func, t._forbids_frame_clause): # some functions on some backends don't support frame clauses additional_params = {} else: @@ -373,7 +375,7 @@ def _window_function(t, window): reduction, partition_by=partition_by, order_by=order_by, **additional_params ) - if isinstance(window.func, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)): + if isinstance(func, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)): return result - 1 else: return result diff --git a/ibis/backends/base/sql/registry/window.py b/ibis/backends/base/sql/registry/window.py index 2ed7593ae6e7..37d54fe0f31a 100644 --- a/ibis/backends/base/sql/registry/window.py +++ b/ibis/backends/base/sql/registry/window.py @@ -128,20 +128,22 @@ def window(translator, op): ops.ApproxCountDistinct, ) - if isinstance(op.func, _unsupported_reductions): + func = op.func.__window_op__ + + if isinstance(func, _unsupported_reductions): raise com.UnsupportedOperationError( - f'{type(op.func)} is not supported in window functions' + f'{type(func)} is not supported in window functions' ) - if isinstance(op.func, ops.CumulativeOp): - arg = cumulative_to_window(translator, op.func, op.frame) + if isinstance(func, ops.CumulativeOp): + arg = cumulative_to_window(translator, func, op.frame) return translator.translate(arg) # Some analytic functions need to have the expression of interest in # the ORDER BY part of the window clause frame = op.frame - if isinstance(op.func, translator._require_order_by) and not frame.order_by: - frame = frame.copy(order_by=(op.func.arg,)) + if isinstance(func, translator._require_order_by) and not frame.order_by: + frame = frame.copy(order_by=(func.arg,)) # Time ranges need to be converted to microseconds. if isinstance(frame, ops.RangeWindowFrame): @@ -153,12 +155,12 @@ def window(translator, op): 'Rows with max lookback is not implemented for SQL-based backends.' ) - window_formatted = format_window_frame(translator, op.func, frame) + window_formatted = format_window_frame(translator, func, frame) - arg_formatted = translator.translate(op.func) + arg_formatted = translator.translate(func.__window_op__) result = f'{arg_formatted} {window_formatted}' - if isinstance(op.func, ops.RankBase): + if isinstance(func, ops.RankBase): return f'({result} - 1)' else: return result diff --git a/ibis/backends/bigquery/registry.py b/ibis/backends/bigquery/registry.py index 6aa8c8792839..e56f1d64f5e8 100644 --- a/ibis/backends/bigquery/registry.py +++ b/ibis/backends/bigquery/registry.py @@ -298,6 +298,28 @@ def _arbitrary(translator, op): return f"ANY_VALUE({translator.translate(arg)})" +def _first(translator, op): + arg = op.arg + where = op.where + + if where is not None: + arg = ops.Where(where, arg, ibis.NA) + + arg = translator.translate(arg) + return f"ARRAY_AGG({arg} IGNORE NULLS)[SAFE_OFFSET(0)]" + + +def _last(translator, op): + arg = op.arg + where = op.where + + if where is not None: + arg = ops.Where(where, arg, ibis.NA) + + arg = translator.translate(arg) + return f"ARRAY_REVERSE(ARRAY_AGG({arg} IGNORE NULLS))[SAFE_OFFSET(0)]" + + def _truncate(kind, units): def truncator(translator, op): arg, unit = op.args @@ -662,6 +684,8 @@ def _interval_multiply(t, op): ops.Log: _log, ops.Log2: _log2, ops.Arbitrary: _arbitrary, + ops.First: _first, + ops.Last: _last, # Geospatial Columnar ops.GeoUnaryUnion: unary("ST_UNION_AGG"), # Geospatial diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py index 27b8072648e6..d23b024da286 100644 --- a/ibis/backends/clickhouse/compiler/values.py +++ b/ibis/backends/clickhouse/compiler/values.py @@ -1031,6 +1031,8 @@ def formatter(op, **kw): ops.ArrayCollect: "groupArray", ops.Count: "count", ops.CountDistinct: "uniq", + ops.First: "any", + ops.Last: "anyLast", # string operations ops.StringLength: "length", ops.Lowercase: "lower", @@ -1240,10 +1242,11 @@ def _window(op: ops.WindowFunction, **kw: Any): return translate_val(arg, **kw) window_formatted = format_window_frame(op, op.frame, **kw) - func_formatted = translate_val(op.func, **kw) + func = op.func.__window_op__ + func_formatted = translate_val(func, **kw) result = f'{func_formatted} {window_formatted}' - if isinstance(op.func, ops.RankBase): + if isinstance(func, ops.RankBase): return f"({result} - 1)" return result diff --git a/ibis/backends/dask/execution/reductions.py b/ibis/backends/dask/execution/reductions.py index ce586666c009..f8901fa9d703 100644 --- a/ibis/backends/dask/execution/reductions.py +++ b/ibis/backends/dask/execution/reductions.py @@ -23,6 +23,7 @@ import toolz from multipledispatch.variadic import Variadic +import ibis.common.exceptions as exc import ibis.expr.operations as ops from ibis.backends.dask.dispatch import execute_node from ibis.backends.dask.execution.util import make_selected_obj @@ -107,6 +108,16 @@ def execute_reduction_series_mask(op, data, mask, aggcontext=None, **kwargs): return aggcontext.agg(operand, type(op).__name__.lower()) +@execute_node.register( + (ops.First, ops.Last), ddgb.SeriesGroupBy, (ddgb.SeriesGroupBy, type(None)) +) +@execute_node.register((ops.First, ops.Last), dd.Series, (dd.Series, type(None))) +def execute_first_last_dask(op, data, mask, aggcontext=None, **kwargs): + raise exc.OperationNotDefinedError( + "Dask does not support first or last aggregations" + ) + + @execute_node.register( (ops.CountDistinct, ops.ApproxCountDistinct), ddgb.SeriesGroupBy, diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py index 77dd98758e0f..98a056cac3f1 100644 --- a/ibis/backends/duckdb/registry.py +++ b/ibis/backends/duckdb/registry.py @@ -419,6 +419,8 @@ def _map_merge(t, op): ops.MapMerge: _map_merge, ops.Hash: unary(sa.func.hash), ops.Median: reduction(sa.func.median), + ops.First: reduction(sa.func.first), + ops.Last: reduction(sa.func.last), } ) diff --git a/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/first/out.sql b/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/first/out.sql index ce2aeb1d3813..f0e93516a778 100644 --- a/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/first/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/first/out.sql @@ -1 +1 @@ -first_value(`double_col`) \ No newline at end of file +first_value(`double_col`) OVER (ORDER BY `id` ASC) \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/last/out.sql b/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/last/out.sql index 4497cce69bc6..58c72df6a70b 100644 --- a/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/last/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_analytic_functions/test_analytic_exprs/last/out.sql @@ -1 +1 @@ -last_value(`double_col`) \ No newline at end of file +last_value(`double_col`) OVER (ORDER BY `id` ASC) \ No newline at end of file diff --git a/ibis/backends/impala/tests/test_analytic_functions.py b/ibis/backends/impala/tests/test_analytic_functions.py index 431cda33c1af..d22a66e0999e 100644 --- a/ibis/backends/impala/tests/test_analytic_functions.py +++ b/ibis/backends/impala/tests/test_analytic_functions.py @@ -19,9 +19,8 @@ def table(mockcon): pytest.param( lambda t: t.string_col.lead(default=0), id="lead_explicit_default" ), - pytest.param(lambda t: t.double_col.first(), id="first"), - pytest.param(lambda t: t.double_col.last(), id="last"), - # (t.double_col.nth(4), 'first_value(lag(double_col, 4 - 1))') + pytest.param(lambda t: t.double_col.first().over(order_by="id"), id="first"), + pytest.param(lambda t: t.double_col.last().over(order_by="id"), id="last"), pytest.param(lambda t: t.double_col.ntile(3), id="ntile"), pytest.param(lambda t: t.double_col.percent_rank(), id="percent_rank"), ], diff --git a/ibis/backends/pandas/execution/generic.py b/ibis/backends/pandas/execution/generic.py index 31b268312523..de092881e8bd 100644 --- a/ibis/backends/pandas/execution/generic.py +++ b/ibis/backends/pandas/execution/generic.py @@ -635,6 +635,16 @@ def execute_reduction_series_groupby(op, data, mask, aggcontext=None, **kwargs): return aggcontext.agg(data, type(op).__name__.lower()) +@execute_node.register(ops.First, SeriesGroupBy, type(None)) +def execute_first_series_groupby(op, data, mask, aggcontext=None, **kwargs): + return aggcontext.agg(data, lambda x: getattr(x, "iat", x)[0]) + + +@execute_node.register(ops.Last, SeriesGroupBy, type(None)) +def execute_last_series_groupby(op, data, mask, aggcontext=None, **kwargs): + return aggcontext.agg(data, lambda x: getattr(x, "iat", x)[-1]) + + variance_ddof = {'pop': 0, 'sample': 1} @@ -697,6 +707,20 @@ def execute_reduction_series_gb_mask(op, data, mask, aggcontext=None, **kwargs): ) +@execute_node.register(ops.First, SeriesGroupBy, SeriesGroupBy) +def execute_first_series_gb_mask(op, data, mask, aggcontext=None, **kwargs): + return aggcontext.agg( + data, functools.partial(_filtered_reduction, mask.obj, lambda x: x.iloc[0]) + ) + + +@execute_node.register(ops.Last, SeriesGroupBy, SeriesGroupBy) +def execute_last_series_gb_mask(op, data, mask, aggcontext=None, **kwargs): + return aggcontext.agg( + data, functools.partial(_filtered_reduction, mask.obj, lambda x: x.iloc[-1]) + ) + + @execute_node.register( (ops.CountDistinct, ops.ApproxCountDistinct), SeriesGroupBy, @@ -745,6 +769,18 @@ def execute_reduction_series_mask(op, data, mask, aggcontext=None, **kwargs): return aggcontext.agg(operand, type(op).__name__.lower()) +@execute_node.register(ops.First, pd.Series, (pd.Series, type(None))) +def execute_first_series_mask(op, data, mask, aggcontext=None, **kwargs): + operand = data[mask] if mask is not None else data + return aggcontext.agg(operand, lambda x: x.iloc[0]) + + +@execute_node.register(ops.Last, pd.Series, (pd.Series, type(None))) +def execute_last_series_mask(op, data, mask, aggcontext=None, **kwargs): + operand = data[mask] if mask is not None else data + return aggcontext.agg(operand, lambda x: x.iloc[-1]) + + @execute_node.register( (ops.CountDistinct, ops.ApproxCountDistinct), pd.Series, diff --git a/ibis/backends/pandas/tests/execution/test_window.py b/ibis/backends/pandas/tests/execution/test_window.py index c2b13cd8cb31..bfd926197a7c 100644 --- a/ibis/backends/pandas/tests/execution/test_window.py +++ b/ibis/backends/pandas/tests/execution/test_window.py @@ -107,15 +107,15 @@ def test_lag_delta(t, df, range_offset, default, range_window): def test_first(t, df): expr = t.dup_strings.first() result = expr.execute() - expected = df.dup_strings.iloc[np.repeat(0, len(df))].reset_index(drop=True) - tm.assert_series_equal(result, expected) + expected = df.dup_strings.iat[0] + assert result == expected def test_last(t, df): expr = t.dup_strings.last() result = expr.execute() - expected = df.dup_strings.iloc[np.repeat(-1, len(df))].reset_index(drop=True) - tm.assert_series_equal(result, expected) + expected = df.dup_strings.iat[-1] + assert result == expected def test_group_by_mutate_analytic(t, df): diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 7dada4c3a81e..88a1873f7701 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -643,6 +643,8 @@ def struct_column(op): ops.Variance: 'var', ops.CountDistinct: 'n_unique', ops.Median: 'median', + ops.First: 'first', + ops.Last: 'last', } for reduction in _reductions.keys(): diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index a29b9cdff0b1..f7d035bdd279 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -711,5 +711,7 @@ def _unnest(t, op): ops.Arbitrary: _arbitrary, ops.StructColumn: _struct_column, ops.StructField: _struct_field, + ops.First: reduction(sa.func.public._ibis_first), + ops.Last: reduction(sa.func.public._ibis_last), } ) diff --git a/ibis/backends/postgres/tests/test_functions.py b/ibis/backends/postgres/tests/test_functions.py index f0043bf4fc08..46776ad0307e 100644 --- a/ibis/backends/postgres/tests/test_functions.py +++ b/ibis/backends/postgres/tests/test_functions.py @@ -834,7 +834,9 @@ def test_cumulative_partitioned_ordered_window(alltypes, func, df): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize(('func', 'shift_amount'), [('lead', -1), ('lag', 1)]) +@pytest.mark.parametrize( + ('func', 'shift_amount'), [('lead', -1), ('lag', 1)], ids=["lead", "lag"] +) def test_analytic_shift_functions(alltypes, df, func, shift_amount): method = getattr(alltypes.double_col, func) expr = method(1) @@ -843,18 +845,17 @@ def test_analytic_shift_functions(alltypes, df, func, shift_amount): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize(('func', 'expected_index'), [('first', -1), ('last', 0)]) +@pytest.mark.parametrize( + ('func', 'expected_index'), [('first', -1), ('last', 0)], ids=["first", "last"] +) def test_first_last_value(alltypes, df, func, expected_index): col = alltypes.order_by(ibis.desc(alltypes.string_col)).double_col method = getattr(col, func) - expr = method() - result = expr.execute().rename('double_col') - expected = pd.Series( - df.double_col.iloc[expected_index], - index=pd.RangeIndex(len(df)), - name='double_col', - ) - tm.assert_series_equal(result, expected) + # test that we traverse into expression trees + expr = (1 + method()) - 1 + result = expr.execute() + expected = df.double_col.iloc[expected_index] + assert result == expected def test_null_column(alltypes): diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index ca19f358e51d..1ba76e6f78c5 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -549,7 +549,7 @@ def compile_notany(t, op, *args, aggcontext=None, **kwargs): if aggcontext is None: def fn(col): - return ~(F.max(col)) + return ~F.max(col) return compile_aggregator(t, op, *args, fn=fn, aggcontext=aggcontext, **kwargs) else: @@ -567,7 +567,7 @@ def compile_notall(t, op, *, aggcontext=None, **kwargs): if aggcontext is None: def fn(col): - return ~(F.min(col)) + return ~F.min(col) return compile_aggregator(t, op, fn=fn, aggcontext=aggcontext, **kwargs) else: @@ -711,6 +711,18 @@ def compile_arbitrary(t, op, **kwargs): return compile_aggregator(t, op, fn=fn, **kwargs) +@compiles(ops.First) +def compile_first(t, op, **kwargs): + fn = functools.partial(F.first, ignorenulls=True) + return compile_aggregator(t, op, fn=fn, **kwargs) + + +@compiles(ops.Last) +def compile_last(t, op, **kwargs): + fn = functools.partial(F.last, ignorenulls=True) + return compile_aggregator(t, op, fn=fn, **kwargs) + + @compiles(ops.Coalesce) def compile_coalesce(t, op, **kwargs): kwargs["raw"] = False # override to force column literals @@ -1198,7 +1210,8 @@ def compile_window_function(t, op, **kwargs): # If the operand is a shift op (e.g. lead, lag), Spark will set the window # bounds. Only set window bounds here if not a shift operation. - if not isinstance(op.func, ops.ShiftBase): + func = op.func.__window_op__ + if not isinstance(func, ops.ShiftBase): if op.frame.start is None: win_start = Window.unboundedPreceding else: @@ -1213,17 +1226,19 @@ def compile_window_function(t, op, **kwargs): else: pyspark_window = pyspark_window.rowsBetween(win_start, win_end) - func = op.func - if isinstance(func, (ops.NotAll, ops.NotAny)): + orig_func = func + if isinstance(orig_func, (ops.NotAll, ops.NotAny)): # For NotAll and NotAny, negation must be applied after .over(window) # Here we rewrite node to be its negation, and negate it back after # translation and window operation func = func.negate() + else: + func = orig_func result = t.translate(func, **kwargs, aggcontext=aggcontext).over(pyspark_window) - if isinstance(op.func, (ops.NotAll, ops.NotAny)): + if isinstance(orig_func, (ops.NotAll, ops.NotAny)): return ~result - elif isinstance(op.func, ops.RankBase): + elif isinstance(func, ops.RankBase): # result must be cast to long type for Rank / RowNumber return result.astype('long') - 1 else: diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py index 4bdd910e6c9d..610e33612d37 100644 --- a/ibis/backends/snowflake/registry.py +++ b/ibis/backends/snowflake/registry.py @@ -128,15 +128,17 @@ def _nth_value(t, op): def _arbitrary(t, op): - if op.how != "first": - raise com.UnsupportedOperationError( - "Snowflake only supports the `first` option for `.arbitrary()`" + if (how := op.how) == "first": + return t._reduction(lambda x: sa.func.get(sa.func.array_agg(x), 0), op) + elif how == "last": + return t._reduction( + lambda x: sa.func.get( + sa.func.array_agg(x), sa.func.array_size(sa.func.array_agg(x)) - 1 + ), + op, ) - - # we can't use any_value here because it respects nulls - # - # yes it's slower, but it's also consistent with every other backend - return t._reduction(sa.func.min, op) + else: + raise com.UnsupportedOperationError("how must be 'first' or 'last'") @compiles(Cast, "snowflake") @@ -367,6 +369,12 @@ def _group_concat(t, op): ), ops.NthValue: _nth_value, ops.Arbitrary: _arbitrary, + ops.First: reduction(lambda x: sa.func.get(sa.func.array_agg(x), 0)), + ops.Last: reduction( + lambda x: sa.func.get( + sa.func.array_agg(x), sa.func.array_size(sa.func.array_agg(x)) - 1 + ) + ), ops.StructColumn: lambda t, op: sa.func.object_construct_keep_null( *itertools.chain.from_iterable(zip(op.names, map(t.translate, op.values))) ), diff --git a/ibis/backends/sqlite/registry.py b/ibis/backends/sqlite/registry.py index a3e599bfa2b8..937416d84a6a 100644 --- a/ibis/backends/sqlite/registry.py +++ b/ibis/backends/sqlite/registry.py @@ -418,5 +418,11 @@ def _day_of_the_week_name(arg): lambda: 0.5 + sa.func.random() / sa.cast(-1 << 64, sa.REAL), 0 ), ops.Arbitrary: _arbitrary, + ops.First: lambda t, op: t.translate( + ops.Arbitrary(op.arg, where=op.where, how="first") + ), + ops.Last: lambda t, op: t.translate( + ops.Arbitrary(op.arg, where=op.where, how="last") + ), } ) diff --git a/ibis/backends/sqlite/tests/snapshots/test_functions/test_count_on_order_by/out.sql b/ibis/backends/sqlite/tests/snapshots/test_functions/test_count_on_order_by/out.sql new file mode 100644 index 000000000000..5961ebc98efc --- /dev/null +++ b/ibis/backends/sqlite/tests/snapshots/test_functions/test_count_on_order_by/out.sql @@ -0,0 +1,30 @@ +SELECT + COUNT(*) AS count +FROM ( + SELECT + t1."playerID" AS "playerID", + t1."yearID" AS "yearID", + t1.stint AS stint, + t1."teamID" AS "teamID", + t1."lgID" AS "lgID", + t1."G" AS "G", + t1."AB" AS "AB", + t1."R" AS "R", + t1."H" AS "H", + t1."X2B" AS "X2B", + t1."X3B" AS "X3B", + t1."HR" AS "HR", + t1."RBI" AS "RBI", + t1."SB" AS "SB", + t1."CS" AS "CS", + t1."BB" AS "BB", + t1."SO" AS "SO", + t1."IBB" AS "IBB", + t1."HBP" AS "HBP", + t1."SH" AS "SH", + t1."SF" AS "SF", + t1."GIDP" AS "GIDP" + FROM batting AS t1 + ORDER BY + t1."playerID" DESC +) AS t0 \ No newline at end of file diff --git a/ibis/backends/sqlite/tests/test_functions.py b/ibis/backends/sqlite/tests/test_functions.py index c967d4112736..3d2a7630ade9 100644 --- a/ibis/backends/sqlite/tests/test_functions.py +++ b/ibis/backends/sqlite/tests/test_functions.py @@ -706,11 +706,10 @@ def test_scalar_parameter(alltypes): tm.assert_series_equal(result, expected) -def test_count_on_order_by(con): +def test_count_on_order_by(con, snapshot): t = con.table("batting") sort_key = ibis.desc(t.playerID) sorted_table = t.order_by(sort_key) expr = sorted_table.count() - result = str(expr.compile().compile(compile_kwargs={'literal_binds': True})) - expected = "SELECT count(*) AS count \nFROM batting AS t0" - assert result == expected + result = str(ibis.to_sql(expr, dialect="sqlite")) + snapshot.assert_match(result, "out.sql") diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 5bf2bc0cdb33..8a872ea8ebef 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -591,7 +591,7 @@ def mean_and_std(v): raises=com.OperationNotDefinedError, ), pytest.mark.notimpl( - ["bigquery", "snowflake", "trino"], + ["bigquery", "trino"], raises=com.UnsupportedOperationError, reason="backend only supports the `first` option for `.arbitrary()", ), @@ -629,6 +629,24 @@ def mean_and_std(v): ), ], ), + param( + lambda t, where: t.double_col.first(where=where), + lambda t, where: t.double_col[where].iloc[0], + id='first', + marks=pytest.mark.notimpl( + ["dask", "datafusion", "druid", "impala", "mssql", "mysql"], + raises=com.OperationNotDefinedError, + ), + ), + param( + lambda t, where: t.double_col.last(where=where), + lambda t, where: t.double_col[where].iloc[-1], + id='last', + marks=pytest.mark.notimpl( + ["dask", "datafusion", "druid", "impala", "mssql", "mysql"], + raises=com.OperationNotDefinedError, + ), + ), param( lambda t, where: t.bigint_col.bit_and(where=where), lambda t, where: np.bitwise_and.reduce(t.bigint_col[where].values), @@ -1171,7 +1189,7 @@ def test_group_concat( ], ) @mark.notimpl( - ["pandas", "dask"], + ["dask"], raises=NotImplementedError, reason="sorting on aggregations not yet implemented", ) @@ -1201,7 +1219,7 @@ def test_topk_op(alltypes, df, result_fn, expected_fn): ) ], ) -@mark.notimpl(["datafusion", "pandas"], raises=com.OperationNotDefinedError) +@mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @mark.broken( ["bigquery"], raises=GoogleBadRequest, @@ -1216,7 +1234,7 @@ def test_topk_op(alltypes, df, result_fn, expected_fn): ), ) @mark.notimpl( - ["pandas", "dask"], + ["dask"], raises=NotImplementedError, reason="sorting on aggregations not yet implemented", ) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index ef89775183c4..9ffeb1aca8e9 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1053,7 +1053,7 @@ def test_pivot_wider(backend): param( "last", marks=pytest.mark.notimpl( - ["bigquery", "snowflake", "trino"], + ["bigquery", "trino"], raises=com.UnsupportedOperationError, reason="backend doesn't support how='last'", ), diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py index 0dfe56b8db64..4fc7831b4a42 100644 --- a/ibis/backends/trino/registry.py +++ b/ibis/backends/trino/registry.py @@ -1,6 +1,7 @@ from __future__ import annotations from functools import partial +from typing import Literal import sqlalchemy as sa from sqlalchemy.ext.compiler import compiles @@ -227,6 +228,10 @@ def _array_filter(t, op): ) +def _first_last(t, op, *, offset: Literal[-1, 1]): + return sa.func.element_at(t._reduction(sa.func.array_agg, op), offset) + + operation_registry.update( { # conditional expressions @@ -401,6 +406,8 @@ def _array_filter(t, op): ), ops.StartsWith: fixed_arity(sa.func.starts_with, 2), ops.Argument: lambda _, op: sa.literal_column(op.name), + ops.First: partial(_first_last, offset=1), + ops.Last: partial(_first_last, offset=-1), } ) diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 55aeb87dc2f5..43bc70ce79c7 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -450,6 +450,15 @@ def _walk(op, frame): return _windowize(expr.op(), frame).to_expr() +def contains_first_or_last_agg(exprs): + def fn(node: ops.Node) -> tuple[bool, bool | None]: + if not isinstance(node, ops.Value): + return g.halt, None + return g.proceed, isinstance(node, (ops.First, ops.Last)) + + return any(g.traverse(fn, exprs)) + + def simplify_aggregation(agg): def _pushdown(nodes): subbed = [] @@ -464,7 +473,15 @@ def _pushdown(nodes): return valid, subbed - if isinstance(agg.table, ops.Selection) and not agg.table.selections: + table = agg.table + if ( + isinstance(table, ops.Selection) + and not table.selections + # more aggressive than necessary, a better solution would be to check + # whether the selections have any order sensitive aggregates that + # *depend on* the sort_keys + and not (table.sort_keys or contains_first_or_last_agg(table.selections)) + ): metrics_valid, lowered_metrics = _pushdown(agg.metrics) by_valid, lowered_by = _pushdown(agg.by) having_valid, lowered_having = _pushdown(agg.having) @@ -472,7 +489,7 @@ def _pushdown(nodes): if metrics_valid and by_valid and having_valid: valid_lowered_sort_keys = frozenset(lowered_metrics).union(lowered_by) return ops.Aggregation( - agg.table.table, + table.table, lowered_metrics, by=lowered_by, having=lowered_having, diff --git a/ibis/expr/operations/analytic.py b/ibis/expr/operations/analytic.py index e0ca9e78ef58..1839dfe25917 100644 --- a/ibis/expr/operations/analytic.py +++ b/ibis/expr/operations/analytic.py @@ -12,6 +12,10 @@ class Analytic(Value): output_shape = rlz.Shape.COLUMNAR + @property + def __window_op__(self): + return self + @public class ShiftBase(Analytic): diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 01c9a1f77137..b97aae291859 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -2,6 +2,7 @@ from public import public +import ibis.common.exceptions as exc import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute @@ -13,6 +14,10 @@ class Reduction(Value): output_shape = rlz.Shape.SCALAR + @property + def __window_op__(self): + return self + class Filterable(Value): where = rlz.optional(rlz.boolean) @@ -40,6 +45,42 @@ class Arbitrary(Filterable, Reduction): output_dtype = rlz.dtype_like('arg') +@public +class First(Filterable, Reduction): + """Retrieve the first element.""" + + arg = rlz.column(rlz.any) + output_dtype = rlz.dtype_like("arg") + + @property + def __window_op__(self): + import ibis.expr.operations as ops + + if self.where is not None: + raise exc.OperationNotDefinedError( + "FirstValue cannot be filtered in a window context" + ) + return ops.FirstValue(arg=self.arg) + + +@public +class Last(Filterable, Reduction): + """Retrieve the last element.""" + + arg = rlz.column(rlz.any) + output_dtype = rlz.dtype_like("arg") + + @property + def __window_op__(self): + import ibis.expr.operations as ops + + if self.where is not None: + raise exc.OperationNotDefinedError( + "LastValue cannot be filtered in a window context" + ) + return ops.LastValue(arg=self.arg) + + @public class BitAnd(Filterable, Reduction): """Aggregate bitwise AND operation. diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index c1d99eaf712a..aafeb1ce2afb 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1287,19 +1287,57 @@ def value_counts(self) -> ir.Table: .agg(**{f"{name}_count": lambda t: t.count()}) ) - def first(self) -> Column: + def first(self, where: ir.BooleanValue | None = None) -> Value: """Return the first value of a column. - Equivalent to SQL's `FIRST_VALUE` window function. + Examples + -------- + >>> import ibis + >>> ibis.options.interactive = True + >>> t = ibis.memtable({"chars": ["a", "b", "c", "d"]}) + >>> t + ┏━━━━━━━━┓ + ┃ chars ┃ + ┡━━━━━━━━┩ + │ string │ + ├────────┤ + │ a │ + │ b │ + │ c │ + │ d │ + └────────┘ + >>> t.chars.first() + 'a' + >>> t.chars.first(where=t.chars != 'a') + 'b' """ - return ops.FirstValue(self).to_expr() + return ops.First(self, where=where).to_expr() - def last(self) -> Column: + def last(self, where: ir.BooleanValue | None = None) -> Value: """Return the last value of a column. - Equivalent to SQL's `LAST_VALUE` window function. + Examples + -------- + >>> import ibis + >>> ibis.options.interactive = True + >>> t = ibis.memtable({"chars": ["a", "b", "c", "d"]}) + >>> t + ┏━━━━━━━━┓ + ┃ chars ┃ + ┡━━━━━━━━┩ + │ string │ + ├────────┤ + │ a │ + │ b │ + │ c │ + │ d │ + └────────┘ + >>> t.chars.last() + 'd' + >>> t.chars.last(where=t.chars != 'd') + 'c' """ - return ops.LastValue(self).to_expr() + return ops.Last(self, where=where).to_expr() def rank(self) -> ir.IntegerColumn: """Compute position of first element within each equal-value group in sorted order.