Skip to content

Commit

Permalink
feat(api): add first/last reduction APIs
Browse files Browse the repository at this point in the history
BREAKING CHANGE: `Column.first()`/`Column.last()` are now reductions by default. Code running these expressions in isolation will no longer be windowed over the entire table. Code using this function in `select`-based APIs should function unchanged.
  • Loading branch information
cpcloud authored and kszucs committed May 4, 2023
1 parent 92d979e commit 8c01980
Show file tree
Hide file tree
Showing 26 changed files with 336 additions and 69 deletions.
16 changes: 9 additions & 7 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,16 +334,18 @@ def _translate_window_boundary(boundary):


def _window_function(t, window):
if isinstance(window.func, ops.CumulativeOp):
func = _cumulative_to_window(t, window.func, window.frame).op()
func = window.func.__window_op__

if isinstance(func, ops.CumulativeOp):
func = _cumulative_to_window(t, func, window.frame).op()
return t.translate(func)

reduction = t.translate(window.func)
reduction = t.translate(func)

# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
if isinstance(window.func, t._require_order_by) and not window.frame.order_by:
order_by = t.translate(window.func.arg) # .args[0])
if isinstance(func, t._require_order_by) and not window.frame.order_by:
order_by = t.translate(func.arg) # .args[0])
else:
order_by = [t.translate(arg) for arg in window.frame.order_by]

Expand All @@ -361,7 +363,7 @@ def _window_function(t, window):
else:
raise NotImplementedError(type(window.frame))

if t._forbids_frame_clause and isinstance(window.func, t._forbids_frame_clause):
if t._forbids_frame_clause and isinstance(func, t._forbids_frame_clause):
# some functions on some backends don't support frame clauses
additional_params = {}
else:
Expand All @@ -373,7 +375,7 @@ def _window_function(t, window):
reduction, partition_by=partition_by, order_by=order_by, **additional_params
)

if isinstance(window.func, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)):
if isinstance(func, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)):
return result - 1
else:
return result
Expand Down
20 changes: 11 additions & 9 deletions ibis/backends/base/sql/registry/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,22 @@ def window(translator, op):
ops.ApproxCountDistinct,
)

if isinstance(op.func, _unsupported_reductions):
func = op.func.__window_op__

if isinstance(func, _unsupported_reductions):
raise com.UnsupportedOperationError(
f'{type(op.func)} is not supported in window functions'
f'{type(func)} is not supported in window functions'
)

if isinstance(op.func, ops.CumulativeOp):
arg = cumulative_to_window(translator, op.func, op.frame)
if isinstance(func, ops.CumulativeOp):
arg = cumulative_to_window(translator, func, op.frame)
return translator.translate(arg)

# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
frame = op.frame
if isinstance(op.func, translator._require_order_by) and not frame.order_by:
frame = frame.copy(order_by=(op.func.arg,))
if isinstance(func, translator._require_order_by) and not frame.order_by:
frame = frame.copy(order_by=(func.arg,))

# Time ranges need to be converted to microseconds.
if isinstance(frame, ops.RangeWindowFrame):
Expand All @@ -153,12 +155,12 @@ def window(translator, op):
'Rows with max lookback is not implemented for SQL-based backends.'
)

window_formatted = format_window_frame(translator, op.func, frame)
window_formatted = format_window_frame(translator, func, frame)

arg_formatted = translator.translate(op.func)
arg_formatted = translator.translate(func.__window_op__)
result = f'{arg_formatted} {window_formatted}'

if isinstance(op.func, ops.RankBase):
if isinstance(func, ops.RankBase):
return f'({result} - 1)'
else:
return result
Expand Down
24 changes: 24 additions & 0 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,28 @@ def _arbitrary(translator, op):
return f"ANY_VALUE({translator.translate(arg)})"


def _first(translator, op):
arg = op.arg
where = op.where

if where is not None:
arg = ops.Where(where, arg, ibis.NA)

arg = translator.translate(arg)
return f"ARRAY_AGG({arg} IGNORE NULLS)[SAFE_OFFSET(0)]"


def _last(translator, op):
arg = op.arg
where = op.where

if where is not None:
arg = ops.Where(where, arg, ibis.NA)

arg = translator.translate(arg)
return f"ARRAY_REVERSE(ARRAY_AGG({arg} IGNORE NULLS))[SAFE_OFFSET(0)]"


def _truncate(kind, units):
def truncator(translator, op):
arg, unit = op.args
Expand Down Expand Up @@ -662,6 +684,8 @@ def _interval_multiply(t, op):
ops.Log: _log,
ops.Log2: _log2,
ops.Arbitrary: _arbitrary,
ops.First: _first,
ops.Last: _last,
# Geospatial Columnar
ops.GeoUnaryUnion: unary("ST_UNION_AGG"),
# Geospatial
Expand Down
7 changes: 5 additions & 2 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,8 @@ def formatter(op, **kw):
ops.ArrayCollect: "groupArray",
ops.Count: "count",
ops.CountDistinct: "uniq",
ops.First: "any",
ops.Last: "anyLast",
# string operations
ops.StringLength: "length",
ops.Lowercase: "lower",
Expand Down Expand Up @@ -1240,10 +1242,11 @@ def _window(op: ops.WindowFunction, **kw: Any):
return translate_val(arg, **kw)

window_formatted = format_window_frame(op, op.frame, **kw)
func_formatted = translate_val(op.func, **kw)
func = op.func.__window_op__
func_formatted = translate_val(func, **kw)
result = f'{func_formatted} {window_formatted}'

if isinstance(op.func, ops.RankBase):
if isinstance(func, ops.RankBase):
return f"({result} - 1)"

return result
Expand Down
11 changes: 11 additions & 0 deletions ibis/backends/dask/execution/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import toolz
from multipledispatch.variadic import Variadic

import ibis.common.exceptions as exc
import ibis.expr.operations as ops
from ibis.backends.dask.dispatch import execute_node
from ibis.backends.dask.execution.util import make_selected_obj
Expand Down Expand Up @@ -107,6 +108,16 @@ def execute_reduction_series_mask(op, data, mask, aggcontext=None, **kwargs):
return aggcontext.agg(operand, type(op).__name__.lower())


@execute_node.register(
    (ops.First, ops.Last), ddgb.SeriesGroupBy, (ddgb.SeriesGroupBy, type(None))
)
@execute_node.register((ops.First, ops.Last), dd.Series, (dd.Series, type(None)))
def execute_first_last_dask(op, data, mask, aggcontext=None, **kwargs):
    """Signal that ``first``/``last`` reductions are unsupported on dask."""
    message = "Dask does not support first or last aggregations"
    raise exc.OperationNotDefinedError(message)


@execute_node.register(
(ops.CountDistinct, ops.ApproxCountDistinct),
ddgb.SeriesGroupBy,
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,8 @@ def _map_merge(t, op):
ops.MapMerge: _map_merge,
ops.Hash: unary(sa.func.hash),
ops.Median: reduction(sa.func.median),
ops.First: reduction(sa.func.first),
ops.Last: reduction(sa.func.last),
}
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
first_value(`double_col`)
first_value(`double_col`) OVER (ORDER BY `id` ASC)
Original file line number Diff line number Diff line change
@@ -1 +1 @@
last_value(`double_col`)
last_value(`double_col`) OVER (ORDER BY `id` ASC)
5 changes: 2 additions & 3 deletions ibis/backends/impala/tests/test_analytic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ def table(mockcon):
pytest.param(
lambda t: t.string_col.lead(default=0), id="lead_explicit_default"
),
pytest.param(lambda t: t.double_col.first(), id="first"),
pytest.param(lambda t: t.double_col.last(), id="last"),
# (t.double_col.nth(4), 'first_value(lag(double_col, 4 - 1))')
pytest.param(lambda t: t.double_col.first().over(order_by="id"), id="first"),
pytest.param(lambda t: t.double_col.last().over(order_by="id"), id="last"),
pytest.param(lambda t: t.double_col.ntile(3), id="ntile"),
pytest.param(lambda t: t.double_col.percent_rank(), id="percent_rank"),
],
Expand Down
36 changes: 36 additions & 0 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,16 @@ def execute_reduction_series_groupby(op, data, mask, aggcontext=None, **kwargs):
return aggcontext.agg(data, type(op).__name__.lower())


@execute_node.register(ops.First, SeriesGroupBy, type(None))
def execute_first_series_groupby(op, data, mask, aggcontext=None, **kwargs):
    """Reduce each group to its first element (unfiltered case)."""

    def take_first(group):
        # Series-like groups expose positional access via ``iat``; anything
        # without ``iat`` is indexed directly — presumably already
        # position-indexable (TODO confirm against aggcontext callers).
        return getattr(group, "iat", group)[0]

    return aggcontext.agg(data, take_first)


@execute_node.register(ops.Last, SeriesGroupBy, type(None))
def execute_last_series_groupby(op, data, mask, aggcontext=None, **kwargs):
    """Reduce each group to its last element (unfiltered case)."""

    def take_last(group):
        # Series-like groups expose positional access via ``iat``; anything
        # without ``iat`` is indexed directly — presumably already
        # position-indexable (TODO confirm against aggcontext callers).
        return getattr(group, "iat", group)[-1]

    return aggcontext.agg(data, take_last)


variance_ddof = {'pop': 0, 'sample': 1}


Expand Down Expand Up @@ -697,6 +707,20 @@ def execute_reduction_series_gb_mask(op, data, mask, aggcontext=None, **kwargs):
)


@execute_node.register(ops.First, SeriesGroupBy, SeriesGroupBy)
def execute_first_series_gb_mask(op, data, mask, aggcontext=None, **kwargs):
    """First element per group, restricted to rows where ``mask`` holds."""
    # _filtered_reduction applies the mask before the positional reduction.
    reducer = functools.partial(
        _filtered_reduction, mask.obj, lambda group: group.iloc[0]
    )
    return aggcontext.agg(data, reducer)


@execute_node.register(ops.Last, SeriesGroupBy, SeriesGroupBy)
def execute_last_series_gb_mask(op, data, mask, aggcontext=None, **kwargs):
    """Last element per group, restricted to rows where ``mask`` holds."""
    # _filtered_reduction applies the mask before the positional reduction.
    reducer = functools.partial(
        _filtered_reduction, mask.obj, lambda group: group.iloc[-1]
    )
    return aggcontext.agg(data, reducer)


@execute_node.register(
(ops.CountDistinct, ops.ApproxCountDistinct),
SeriesGroupBy,
Expand Down Expand Up @@ -745,6 +769,18 @@ def execute_reduction_series_mask(op, data, mask, aggcontext=None, **kwargs):
return aggcontext.agg(operand, type(op).__name__.lower())


@execute_node.register(ops.First, pd.Series, (pd.Series, type(None)))
def execute_first_series_mask(op, data, mask, aggcontext=None, **kwargs):
    """First element of a series, optionally filtered by a boolean mask."""
    if mask is None:
        selected = data
    else:
        selected = data[mask]
    return aggcontext.agg(selected, lambda s: s.iloc[0])


@execute_node.register(ops.Last, pd.Series, (pd.Series, type(None)))
def execute_last_series_mask(op, data, mask, aggcontext=None, **kwargs):
    """Last element of a series, optionally filtered by a boolean mask."""
    if mask is None:
        selected = data
    else:
        selected = data[mask]
    return aggcontext.agg(selected, lambda s: s.iloc[-1])


@execute_node.register(
(ops.CountDistinct, ops.ApproxCountDistinct),
pd.Series,
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/pandas/tests/execution/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ def test_lag_delta(t, df, range_offset, default, range_window):
def test_first(t, df):
expr = t.dup_strings.first()
result = expr.execute()
expected = df.dup_strings.iloc[np.repeat(0, len(df))].reset_index(drop=True)
tm.assert_series_equal(result, expected)
expected = df.dup_strings.iat[0]
assert result == expected


def test_last(t, df):
expr = t.dup_strings.last()
result = expr.execute()
expected = df.dup_strings.iloc[np.repeat(-1, len(df))].reset_index(drop=True)
tm.assert_series_equal(result, expected)
expected = df.dup_strings.iat[-1]
assert result == expected


def test_group_by_mutate_analytic(t, df):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,8 @@ def struct_column(op):
ops.Variance: 'var',
ops.CountDistinct: 'n_unique',
ops.Median: 'median',
ops.First: 'first',
ops.Last: 'last',
}

for reduction in _reductions.keys():
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,5 +711,7 @@ def _unnest(t, op):
ops.Arbitrary: _arbitrary,
ops.StructColumn: _struct_column,
ops.StructField: _struct_field,
ops.First: reduction(sa.func.public._ibis_first),
ops.Last: reduction(sa.func.public._ibis_last),
}
)
21 changes: 11 additions & 10 deletions ibis/backends/postgres/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,9 @@ def test_cumulative_partitioned_ordered_window(alltypes, func, df):
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(('func', 'shift_amount'), [('lead', -1), ('lag', 1)])
@pytest.mark.parametrize(
('func', 'shift_amount'), [('lead', -1), ('lag', 1)], ids=["lead", "lag"]
)
def test_analytic_shift_functions(alltypes, df, func, shift_amount):
method = getattr(alltypes.double_col, func)
expr = method(1)
Expand All @@ -843,18 +845,17 @@ def test_analytic_shift_functions(alltypes, df, func, shift_amount):
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(('func', 'expected_index'), [('first', -1), ('last', 0)])
@pytest.mark.parametrize(
('func', 'expected_index'), [('first', -1), ('last', 0)], ids=["first", "last"]
)
def test_first_last_value(alltypes, df, func, expected_index):
col = alltypes.order_by(ibis.desc(alltypes.string_col)).double_col
method = getattr(col, func)
expr = method()
result = expr.execute().rename('double_col')
expected = pd.Series(
df.double_col.iloc[expected_index],
index=pd.RangeIndex(len(df)),
name='double_col',
)
tm.assert_series_equal(result, expected)
# test that we traverse into expression trees
expr = (1 + method()) - 1
result = expr.execute()
expected = df.double_col.iloc[expected_index]
assert result == expected


def test_null_column(alltypes):
Expand Down
Loading

0 comments on commit 8c01980

Please sign in to comment.