Skip to content

Commit

Permalink
refactor(analysis): always merge frames during windowization
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Dec 17, 2023
1 parent 7ea9229 commit 2c02b3b
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 13 deletions.
7 changes: 7 additions & 0 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,13 @@ def test_range_expression_bounds(backend):
assert len(result) == con.execute(t.count())


@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(
["clickhouse"],
reason="clickhouse doesn't implement percent_rank",
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(["dask"], raises=NotImplementedError)
def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df):
# GH #7631
t = alltypes
Expand Down
14 changes: 10 additions & 4 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,23 @@ def wrap_analytic(_, default_frame):

@replace(p.WindowFunction)
def merge_windows(_, default_frame):
if _.frame.start is not None and default_frame.start is not None:
raise ValueError("Unable to merge windows with conflicting start")
if _.frame.end is not None and default_frame.end is not None:
raise ValueError("Unable to merge windows with conflicting end")

start = _.frame.start or default_frame.start
end = _.frame.end or default_frame.end
group_by = tuple(toolz.unique(_.frame.group_by + default_frame.group_by))
order_by = tuple(toolz.unique(_.frame.order_by + default_frame.order_by))
frame = _.frame.copy(group_by=group_by, order_by=order_by)
frame = _.frame.copy(start=start, end=end, group_by=group_by, order_by=order_by)
return ops.WindowFunction(_.func, frame)


def windowize_function(expr, default_frame, merge_frames=False):
def windowize_function(expr, default_frame):
ctx = {"default_frame": default_frame}
node = expr.op()
if merge_frames:
node = node.replace(merge_windows, filter=p.Value, context=ctx)
node = node.replace(merge_windows, filter=p.Value, context=ctx)
node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction, context=ctx)
return node.to_expr()

Expand Down
3 changes: 2 additions & 1 deletion ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,8 @@ def to_torch(

def unbind(self) -> ir.Table:
"""Return an expression built on `UnboundTable` instead of backend-specific objects."""
from ibis.expr.analysis import p, c, _
from ibis.expr.analysis import p, c
from ibis.common.deferred import _

rule = p.DatabaseTable >> c.UnboundTable(name=_.name, schema=_.schema)
return self.op().replace(rule).to_expr()
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def over(

def bind(table):
frame = window.bind(table)
expr = an.windowize_function(self, frame, merge_frames=True)
expr = an.windowize_function(self, frame)
if expr.equals(self):
raise com.IbisTypeError(
"No reduction or analytic function found to construct a window expression"
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ def _selectables(self, *exprs, **kwexprs):
order_by=bind_expr(self.table, self._order_by),
)
return [
an.windowize_function(e2, default_frame, merge_frames=True)
an.windowize_function(e2, default_frame)
for expr in exprs
for e1 in util.promote_list(expr)
for e2 in util.promote_list(table._ensure_expr(e1))
] + [
an.windowize_function(e, default_frame, merge_frames=True).name(k)
an.windowize_function(e, default_frame).name(k)
for k, expr in kwexprs.items()
for e in util.promote_list(table._ensure_expr(expr))
]
Expand Down
3 changes: 2 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4338,7 +4338,8 @@ def _resolve_predicates(
table: Table, predicates
) -> tuple[list[ir.BooleanValue], list[tuple[ir.BooleanValue, ir.Table]]]:
import ibis.expr.types as ir
from ibis.expr.analysis import _, flatten_predicate, p
from ibis.common.deferred import _
from ibis.expr.analysis import flatten_predicate, p

# TODO(kszucs): clean this up, too much flattening and resolving happens here
predicates = [
Expand Down
17 changes: 13 additions & 4 deletions ibis/tests/expr/test_window_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pytest

import ibis
import ibis.expr.operations as ops

Expand Down Expand Up @@ -48,15 +50,22 @@ def test_value_over_api(alltypes):
w1 = ibis.window(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
w2 = ibis.window(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])

expr = t.f.cumsum().over(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
expected = t.f.cumsum().over(w1)
expr = t.f.sum().over(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
expected = t.f.sum().over(w1)
assert expr.equals(expected)

expr = t.f.cumsum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])
expected = t.f.cumsum().over(w2)
expr = t.f.sum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])
expected = t.f.sum().over(w2)
assert expr.equals(expected)


def test_conflicting_window_boundaries(alltypes):
t = alltypes

with pytest.raises(ValueError):
t.f.cumsum().over(rows=(0, 1))


def test_rank_followed_by_over_call_merge_frames(alltypes):
t = alltypes
expr1 = t.f.percent_rank().over(ibis.window(group_by=t.f.notnull()))
Expand Down

0 comments on commit 2c02b3b

Please sign in to comment.