Skip to content

Commit

Permalink
refactor(ir): replace Cumulative operations by adding where, `gro…
Browse files Browse the repository at this point in the history
…up_by` and `order_by` kwargs to cumulative APIs

BREAKING CHANGE: **Dask and Pandas only**; cumulative operations that relied on implicit ordering from prior operations such as calls to `table.order_by` may no longer work, pass `order_by=...` into the appropriate cumulative method to achieve the same behavior.
  • Loading branch information
cpcloud committed Sep 27, 2023
1 parent 98a6ae0 commit 26ffc68
Show file tree
Hide file tree
Showing 24 changed files with 80 additions and 318 deletions.
33 changes: 0 additions & 33 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,29 +311,6 @@ def _endswith(t, op):
return t.translate(op.arg).endswith(t.translate(op.end))


_cumulative_to_reduction = {
ops.CumulativeSum: ops.Sum,
ops.CumulativeMin: ops.Min,
ops.CumulativeMax: ops.Max,
ops.CumulativeMean: ops.Mean,
ops.CumulativeAny: ops.Any,
ops.CumulativeAll: ops.All,
}


def _cumulative_to_window(translator, op, frame):
klass = _cumulative_to_reduction[type(op)]
new_op = klass(*op.args)
new_expr = new_op.to_expr().name(op.name)
new_frame = frame.copy(start=None, end=0)

if type(new_op) in translator._rewrites:
new_expr = translator._rewrites[type(new_op)](new_expr)

# TODO(kszucs): rewrite to receive and return an ops.Node
return an.windowize_function(new_expr, frame=new_frame)


def _translate_window_boundary(boundary):
if boundary is None:
return None
Expand All @@ -350,10 +327,6 @@ def _translate_window_boundary(boundary):
def _window_function(t, window):
func = window.func.__window_op__

if isinstance(func, ops.CumulativeOp):
func = _cumulative_to_window(t, func, window.frame).op()
return t.translate(func)

reduction = t.translate(func)

# Some analytic functions need to have the expression of interest in
Expand Down Expand Up @@ -717,12 +690,6 @@ class array_filter(FunctionElement):
ops.CumeDist: unary(lambda _: sa.func.cume_dist()),
ops.NthValue: _nth_value,
ops.WindowFunction: _window_function,
ops.CumulativeMax: unary(sa.func.max),
ops.CumulativeMin: unary(sa.func.min),
ops.CumulativeSum: unary(sa.func.sum),
ops.CumulativeMean: unary(sa.func.avg),
ops.CumulativeAny: unary(sa.func.bool_or),
ops.CumulativeAll: unary(sa.func.bool_and),
}

geospatial_functions = {
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/base/sql/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
unary,
)
from ibis.backends.base.sql.registry.window import (
cumulative_to_window,
format_window_frame,
time_range_to_range_window,
)
Expand All @@ -30,7 +29,6 @@
"type_to_sql_string",
"reduction",
"unary",
"cumulative_to_window",
"format_window_frame",
"time_range_to_range_window",
)
31 changes: 0 additions & 31 deletions ibis/backends/base/sql/registry/window.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops

Expand All @@ -16,32 +15,6 @@
}


_cumulative_to_reduction = {
ops.CumulativeSum: ops.Sum,
ops.CumulativeMin: ops.Min,
ops.CumulativeMax: ops.Max,
ops.CumulativeMean: ops.Mean,
ops.CumulativeAny: ops.Any,
ops.CumulativeAll: ops.All,
}


def cumulative_to_window(translator, func, frame):
klass = _cumulative_to_reduction[type(func)]
func = klass(*func.args)

try:
rule = translator._rewrites[type(func)]
except KeyError:
pass
else:
func = rule(func)

frame = frame.copy(start=None, end=0)
expr = an.windowize_function(func.to_expr(), frame)
return expr.op()


def interval_boundary_to_integer(boundary):
if boundary is None:
return None
Expand Down Expand Up @@ -129,10 +102,6 @@ def window(translator, op):
f"{type(func)} is not supported in window functions"
)

if isinstance(func, ops.CumulativeOp):
arg = cumulative_to_window(translator, func, op.frame)
return translator.translate(arg)

# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
frame = op.frame
Expand Down
8 changes: 0 additions & 8 deletions ibis/backends/clickhouse/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,6 @@ def fn(node, _, **kwargs):
lambda op, ctx: ops.Literal(value=params[op], dtype=ctx[x])
)

# rewrite cumulative functions to window functions, so that we don't have
# to think about handling them in the compiler, we need only compile window
# functions
replace_cumulative_ops = p.WindowFunction(
x @ p.Cumulative, y
) >> a.cumulative_to_window(x, y)

# replace the right side of InColumn into a scalar subquery for sql
# backends
replace_in_column_with_table_array_view = p.InColumn(..., y) >> _.copy(
Expand Down Expand Up @@ -126,7 +119,6 @@ def fn(node, _, **kwargs):

op = op.replace(
replace_literals
| replace_cumulative_ops
| replace_in_column_with_table_array_view
| replace_empty_in_values_with_false
| replace_notexists_subquery_with_not_exists
Expand Down
23 changes: 0 additions & 23 deletions ibis/backends/dask/execution/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import operator
import re
from typing import TYPE_CHECKING, Any, Callable, NoReturn

import dask.dataframe as dd
Expand Down Expand Up @@ -125,7 +124,6 @@ def get_aggcontext_window(
output_type=output_type,
)
elif frame.start is not None:
assert not isinstance(operand, ops.CumulativeOp)
if isinstance(frame, ops.RowsWindowFrame):
max_lookback = frame.max_lookback
else:
Expand Down Expand Up @@ -361,27 +359,6 @@ def execute_window_op(
return result


@execute_node.register(
(ops.CumulativeSum, ops.CumulativeMax, ops.CumulativeMin),
(dd.Series, ddgb.SeriesGroupBy),
)
def execute_series_cumulative_sum_min_max(op, data, **kwargs):
typename = type(op).__name__
method_name = (
re.match(r"^Cumulative([A-Za-z_][A-Za-z0-9_]*)$", typename).group(1).lower()
)
method = getattr(data, f"cum{method_name}")
return method()


@execute_node.register(ops.CumulativeMean, (dd.Series, ddgb.SeriesGroupBy))
def execute_series_cumulative_mean(op, data, **kwargs):
# TODO: Doesn't handle the case where we've grouped/sorted by. Handling
# this here would probably require a refactor.
# Dask equivalent of Pandas DataFrame.rolling
return data.rolling(window=len(data), min_periods=1).mean()


@execute_node.register(
(ops.Lead, ops.Lag),
(dd.Series, ddgb.SeriesGroupBy),
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/dask/tests/execution/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,18 @@ def test_batting_quantile(players, players_df):

@pytest.mark.parametrize("op", ["sum", "min", "max", "mean"])
def test_batting_specific_cumulative(batting, batting_df, op, sort_kind):
ibis_method = methodcaller(f"cum{op}")
expr = ibis_method(batting.order_by([batting.yearID]).G)
ibis_method = methodcaller(f"cum{op}", order_by=batting.yearID)
expr = ibis_method(batting.G)
result = expr.execute().astype("float64")

pandas_method = methodcaller(op)
expected = pandas_method(
batting_df[["G", "yearID"]]
.sort_values("yearID", kind=sort_kind)
.G.rolling(len(batting_df), min_periods=1)
).reset_index(drop=True)
expected = expected.compute()
tm.assert_series_equal(result, expected.rename(f"Cumulative{op.capitalize()}(G)"))
)
expected = expected.compute().sort_index().reset_index(drop=True)
tm.assert_series_equal(result, expected.rename(f"{op.capitalize()}(G)"))


def test_batting_cumulative(batting, batting_df, sort_kind):
Expand Down Expand Up @@ -271,7 +271,7 @@ def test_batting_cumulative_partitioned(batting, batting_df, sort_kind):
order_by = "yearID"

t = batting
expr = t.G.sum().over(ibis.cumulative_window(order_by=order_by, group_by=group_by))
expr = t.G.cumsum(order_by=order_by, group_by=group_by)
expr = t.mutate(cumulative=expr)
result = expr.execute()

Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,6 @@ def _try_cast(t, op):

_invalid_operations = {
# ibis.expr.operations.analytic
ops.CumulativeAll,
ops.CumulativeAny,
ops.CumulativeOp,
ops.NTile,
# ibis.expr.operations.strings
ops.Translate,
Expand Down
11 changes: 1 addition & 10 deletions ibis/backends/flink/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sql.registry import (
fixed_arity,
helpers,
unary,
window,
)
from ibis.backends.base.sql.registry import fixed_arity, helpers, unary
from ibis.backends.base.sql.registry import (
operation_registry as base_operation_registry,
)
Expand Down Expand Up @@ -181,10 +176,6 @@ def _window(translator: ExprTranslator, op: ops.Node) -> str:
f"{type(func)} is not supported in window functions"
)

if isinstance(func, ops.CumulativeOp):
arg = window.cumulative_to_window(translator, func, op.frame)
return translator.translate(arg)

if isinstance(frame, ops.RowsWindowFrame):
if frame.max_lookback is not None:
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT max(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
SELECT max(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT avg(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
SELECT avg(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT min(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
SELECT min(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT sum(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
SELECT sum(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
FROM `alltypes` t0
6 changes: 0 additions & 6 deletions ibis/backends/mysql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,6 @@ def compiles_mysql_trim(element, compiler, **kw):
)

_invalid_operations = {
ops.CumulativeAll,
ops.CumulativeAny,
ops.CumulativeMax,
ops.CumulativeMean,
ops.CumulativeMin,
ops.CumulativeSum,
ops.NTile,
}

Expand Down
47 changes: 0 additions & 47 deletions ibis/backends/pandas/execution/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import operator
import re
from typing import TYPE_CHECKING, Any, Callable, NoReturn

import numpy as np
Expand Down Expand Up @@ -175,7 +174,6 @@ def get_aggcontext_window(
output_type=output_type,
)
elif frame.start is not None:
assert not isinstance(operand, ops.CumulativeOp)
if isinstance(frame, ops.RowsWindowFrame):
max_lookback = frame.max_lookback
else:
Expand Down Expand Up @@ -402,51 +400,6 @@ def execute_window_op(
return result


@execute_node.register(
(ops.CumulativeSum, ops.CumulativeMax, ops.CumulativeMin),
(pd.Series, SeriesGroupBy),
)
def execute_series_cumulative_sum_min_max(op, data, **kwargs):
typename = type(op).__name__
method_name = (
re.match(r"^Cumulative([A-Za-z_][A-Za-z0-9_]*)$", typename).group(1).lower()
)
method = getattr(data, f"cum{method_name}")
return method()


@execute_node.register(ops.CumulativeMean, (pd.Series, SeriesGroupBy))
def execute_series_cumulative_mean(op, data, **kwargs):
# TODO: Doesn't handle the case where we've grouped/sorted by. Handling
# this here would probably require a refactor.
return data.expanding().mean()


@execute_node.register(ops.CumulativeOp, (pd.Series, SeriesGroupBy))
def execute_series_cumulative_op(op, data, aggcontext=None, **kwargs):
assert aggcontext is not None, f"aggcontext is none in {type(op)} operation"
typename = type(op).__name__
match = re.match(r"^Cumulative([A-Za-z_][A-Za-z0-9_]*)$", typename)
if match is None:
raise ValueError(f"Unknown operation {typename}")

try:
(operation_name,) = match.groups()
except ValueError:
raise ValueError(f"More than one operation name found in {typename} class")

dtype = op.to_expr().type().to_pandas()
assert isinstance(aggcontext, agg_ctx.Cumulative), f"Got {type()}"
result = aggcontext.agg(data, operation_name.lower())

# all expanding window operations are required to be int64 or float64, so
# we need to cast back to preserve the type of the operation
try:
return result.astype(dtype)
except TypeError:
return result


def post_lead_lag(result, default):
if not pd.isnull(default):
return result.fillna(default)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/pandas/tests/execution/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def test_batting_approx_median(players, players_df):

@pytest.mark.parametrize("op", ["sum", "mean", "min", "max"])
def test_batting_specific_cumulative(batting, batting_df, op, sort_kind):
ibis_method = methodcaller(f"cum{op}")
expr = ibis_method(batting.order_by([batting.yearID]).G)
ibis_method = methodcaller(f"cum{op}", order_by=batting.yearID)
expr = ibis_method(batting.G)
result = expr.execute().astype("float64")

pandas_method = methodcaller(op)
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,6 @@ def _array_filter(t, op):
lambda arg: sa.func.trim(sa.func.to_char(arg, "Day")), 1
),
ops.TimeFromHMS: fixed_arity(sa.func.make_time, 3),
ops.CumulativeAll: unary(sa.func.bool_and),
ops.CumulativeAny: unary(sa.func.bool_or),
# array operations
ops.ArrayLength: unary(sa.func.cardinality),
ops.ArrayCollect: reduction(sa.func.array_agg),
Expand Down
Loading

0 comments on commit 26ffc68

Please sign in to comment.