Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ir): merge window frames for bound analytic window functions with a subsequent over call #7790

Merged
merged 2 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT max(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
SELECT max(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT avg(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
SELECT avg(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT min(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
SELECT min(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT sum(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
SELECT sum(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` t0
37 changes: 35 additions & 2 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def calc_zscore(s):
),
pytest.mark.broken(
["flink"],
raises=Py4JJavaError,
reason="CalciteContextException: Argument to function 'NTILE' must be a literal",
raises=com.UnsupportedOperationError,
reason="Windows in Flink can only be ordered by a single time column",
),
],
),
Expand Down Expand Up @@ -1248,3 +1248,36 @@ def test_range_expression_bounds(backend):

assert not result.empty
assert len(result) == con.execute(t.count())


@pytest.mark.notimpl(["polars", "dask"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(
    ["clickhouse"],
    reason="clickhouse doesn't implement percent_rank",
    raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
    ["pandas"], reason="missing column during execution", raises=KeyError
)
@pytest.mark.broken(
    ["mssql"], reason="lack of support for booleans", raises=sa.exc.OperationalError
)
@pytest.mark.broken(
    ["pyspark"], reason="pyspark requires CURRENT ROW", raises=AnalysisException
)
def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df):
    """Check ``percent_rank().over(...)`` against a pandas-computed reference.

    The expression ranks ``int_col`` and is then given a window that only
    specifies ``group_by``; the result is compared with a per-group
    percent rank computed directly on the ``df`` fixture.
    """
    # GH #7631
    windowed = ibis.window(group_by=alltypes.int_col.notnull())
    expr = alltypes.int_col.percent_rank().over(windowed)
    result = expr.execute()

    def percent_rank(group):
        # (min-rank - 1) / (group size - 1) within one group
        return group.int_col.rank(method="min").sub(1).div(len(group) - 1)

    per_group = (
        df.sort_values("int_col")
        .groupby(df["int_col"].notnull())
        .apply(percent_rank)
    )
    expected = (
        per_group.T.reset_index(drop=True).iloc[:, 0].rename(expr.get_name())
    )

    backend.assert_series_equal(result, expected)
51 changes: 32 additions & 19 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import ibis.expr.operations.relations as rels
import ibis.expr.types as ir
from ibis import util
from ibis.common.deferred import _, deferred, var
from ibis.common.exceptions import IbisTypeError, IntegrityError
from ibis.common.patterns import Eq, In, pattern
from ibis.common.deferred import deferred, var
from ibis.common.exceptions import ExpressionError, IbisTypeError, IntegrityError
from ibis.common.patterns import Eq, In, pattern, replace
from ibis.util import Namespace

if TYPE_CHECKING:
Expand Down Expand Up @@ -163,25 +163,38 @@ def pushdown_selection_filters(parent, predicates):
return parent.copy(predicates=parent.predicates + tuple(simplified))


def windowize_function(expr, default_frame, merge_frames=False):
func, frame = var("func"), var("frame")
@replace(p.Analytic | p.Reduction)
def wrap_analytic(_, default_frame):
    """Wrap the matched analytic or reduction node in a window function.

    ``_`` is the matched node and ``default_frame`` is the frame the new
    ``ops.WindowFunction`` is built over.
    """
    windowed = ops.WindowFunction(_, default_frame)
    return windowed

wrap_analytic = (p.Analytic | p.Reduction) >> c.WindowFunction(_, default_frame)
merge_windows = p.WindowFunction(func, frame) >> c.WindowFunction(
func,
frame.copy(
order_by=frame.order_by + default_frame.order_by,
group_by=frame.group_by + default_frame.group_by,
),
)

node = expr.op()
if merge_frames:
# it only happens in ibis.expr.groupby.GroupedTable, but the projector
# changes the windowization order to put everything here
node = node.replace(merge_windows, filter=p.Value)
node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction)
@replace(p.WindowFunction)
def merge_windows(_, default_frame):
    """Merge the frame of the matched window function with ``default_frame``.

    Boundaries: a ``start`` (or ``end``) set on only one of the two frames
    is taken as-is; if both frames set it, the values must be equal.
    Group-by keys of both frames are concatenated and de-duplicated with
    ``toolz.unique``. Order-by keys are de-duplicated per sort expression;
    when an expression occurs more than once, the ``ascending`` flag of
    its last occurrence is used.

    Raises
    ------
    ExpressionError
        If both frames define a ``start`` boundary, or both define an
        ``end`` boundary, and the two values differ.
    """
    current = _.frame

    start_conflict = (
        current.start
        and default_frame.start
        and current.start != default_frame.start
    )
    if start_conflict:
        raise ExpressionError(
            "Unable to merge windows with conflicting `start` boundary"
        )

    end_conflict = (
        current.end and default_frame.end and current.end != default_frame.end
    )
    if end_conflict:
        raise ExpressionError("Unable to merge windows with conflicting `end` boundary")

    merged_start = current.start or default_frame.start
    merged_end = current.end or default_frame.end
    merged_groups = tuple(toolz.unique(current.group_by + default_frame.group_by))

    # Keyed by sort expression: a repeated expression keeps its first
    # position but takes the direction of the later sort key.
    direction_by_expr = {
        key.expr: key.ascending
        for key in current.order_by + default_frame.order_by
    }
    merged_order = tuple(
        ops.SortKey(sort_expr, ascending)
        for sort_expr, ascending in direction_by_expr.items()
    )

    merged_frame = current.copy(
        start=merged_start,
        end=merged_end,
        group_by=merged_groups,
        order_by=merged_order,
    )
    return ops.WindowFunction(_.func, merged_frame)


def windowize_function(expr, default_frame):
    """Rewrite ``expr`` so its window operations use ``default_frame``.

    Two replacement passes run over ``expr.op()``, both receiving
    ``default_frame`` through the replacement context:

    1. ``merge_windows`` with ``filter=p.Value``.
    2. ``wrap_analytic`` with ``filter=p.Value & ~p.WindowFunction``.

    The rewritten node is converted back to an expression and returned.
    """
    context = {"default_frame": default_frame}
    merged = expr.op().replace(merge_windows, filter=p.Value, context=context)
    wrapped = merged.replace(
        wrap_analytic, filter=p.Value & ~p.WindowFunction, context=context
    )
    return wrapped.to_expr()


Expand Down
3 changes: 2 additions & 1 deletion ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,8 @@ def to_torch(

def unbind(self) -> ir.Table:
"""Return an expression built on `UnboundTable` instead of backend-specific objects."""
from ibis.expr.analysis import p, c, _
from ibis.expr.analysis import p, c
from ibis.common.deferred import _

rule = p.DatabaseTable >> c.UnboundTable(name=_.name, schema=_.schema)
return self.op().replace(rule).to_expr()
Expand Down
4 changes: 1 addition & 3 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,9 +766,7 @@ def bind(table):
return expr

op = self.op()
if isinstance(op, ops.WindowFunction):
return op.func.to_expr().over(window)
elif isinstance(window, bl.WindowBuilder):
if isinstance(window, bl.WindowBuilder):
if table := an.find_first_base_table(self.op()):
return bind(table)
else:
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ def _selectables(self, *exprs, **kwexprs):
order_by=bind_expr(self.table, self._order_by),
)
return [
an.windowize_function(e2, default_frame, merge_frames=True)
an.windowize_function(e2, default_frame)
for expr in exprs
for e1 in util.promote_list(expr)
for e2 in util.promote_list(table._ensure_expr(e1))
] + [
an.windowize_function(e, default_frame, merge_frames=True).name(k)
an.windowize_function(e, default_frame).name(k)
for k, expr in kwexprs.items()
for e in util.promote_list(table._ensure_expr(expr))
]
Expand Down
3 changes: 2 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4338,7 +4338,8 @@ def _resolve_predicates(
table: Table, predicates
) -> tuple[list[ir.BooleanValue], list[tuple[ir.BooleanValue, ir.Table]]]:
import ibis.expr.types as ir
from ibis.expr.analysis import _, flatten_predicate, p
from ibis.common.deferred import _
from ibis.expr.analysis import flatten_predicate, p

# TODO(kszucs): clean this up, too much flattening and resolving happens here
predicates = [
Expand Down
25 changes: 21 additions & 4 deletions ibis/tests/expr/test_window_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import pytest

import ibis
import ibis.expr.operations as ops
from ibis.common.exceptions import ExpressionError


def test_mutate_with_analytic_functions(alltypes):
Expand Down Expand Up @@ -48,10 +51,24 @@ def test_value_over_api(alltypes):
w1 = ibis.window(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
w2 = ibis.window(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])

expr = t.f.cumsum().over(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have been misinterpreting this scenario: the `cumsum()` call already constructs a `cumulative_window()` frame, yet the test also provided explicit `rows` window boundaries. I changed the behaviour to raise an error because the window frame boundaries conflict.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that makes sense to me.

expected = t.f.cumsum().over(w1)
expr = t.f.sum().over(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
expected = t.f.sum().over(w1)
assert expr.equals(expected)

expr = t.f.cumsum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])
expected = t.f.cumsum().over(w2)
expr = t.f.sum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])
expected = t.f.sum().over(w2)
assert expr.equals(expected)


def test_conflicting_window_boundaries(alltypes):
    """Explicit ``rows`` bounds applied on top of ``cumsum()`` must fail to merge."""
    expected_message = "Unable to merge windows"

    with pytest.raises(ExpressionError, match=expected_message):
        alltypes.f.cumsum().over(rows=(0, 1))


def test_rank_followed_by_over_call_merge_frames(alltypes):
t = alltypes
expr1 = t.f.percent_rank().over(ibis.window(group_by=t.f.notnull()))
expr2 = ibis.percent_rank().over(group_by=t.f.notnull(), order_by=t.f).resolve(t)
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
assert expr1.equals(expr2)