From e12ce8dc19d05e44242bc0871481bd9ff1cf86b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sun, 17 Dec 2023 22:44:58 +0100 Subject: [PATCH] fix(ir): merge window frames for bound analytic window functions with a subsequent over call --- ibis/backends/tests/test_window.py | 18 ++++++++++++ ibis/expr/analysis.py | 36 ++++++++++++------------ ibis/expr/types/generic.py | 6 ++-- ibis/tests/expr/test_window_functions.py | 7 +++++ 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index b10b80527777..f9278807e5d4 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -1248,3 +1248,21 @@ def test_range_expression_bounds(backend): assert not result.empty assert len(result) == con.execute(t.count()) + + +def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): + # GH #7631 + t = alltypes + expr = t.int_col.percent_rank().over(ibis.window(group_by=t.int_col.notnull())) + result = expr.execute() + + expected = ( + df.sort_values("int_col") + .groupby(df["int_col"].notnull()) + .apply(lambda df: (df.int_col.rank(method="min").sub(1).div(len(df) - 1))) + .T.reset_index(drop=True) + .iloc[:, 0] + .rename(expr.get_name()) + ) + + backend.assert_series_equal(result, expected) diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 1bdc120a9ef4..17b9beaa5506 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -10,9 +10,9 @@ import ibis.expr.operations.relations as rels import ibis.expr.types as ir from ibis import util -from ibis.common.deferred import _, deferred, var +from ibis.common.deferred import deferred, var from ibis.common.exceptions import IbisTypeError, IntegrityError -from ibis.common.patterns import Eq, In, pattern +from ibis.common.patterns import Eq, In, pattern, replace from ibis.util import Namespace if TYPE_CHECKING: @@ -163,25 +163,25 @@ def pushdown_selection_filters(parent, predicates): return parent.copy(predicates=parent.predicates + tuple(simplified)) -def windowize_function(expr, default_frame, merge_frames=False): - func, frame = var("func"), var("frame") - - wrap_analytic = (p.Analytic | p.Reduction) >> c.WindowFunction(_, default_frame) - merge_windows = p.WindowFunction(func, frame) >> c.WindowFunction( - func, - frame.copy( - order_by=frame.order_by + default_frame.order_by, - group_by=frame.group_by + default_frame.group_by, - ), - ) +@replace(p.Analytic | p.Reduction) +def wrap_analytic(_, default_frame): + return ops.WindowFunction(_, default_frame) + + +@replace(p.WindowFunction) +def merge_windows(_, default_frame): + group_by = tuple(toolz.unique(_.frame.group_by + default_frame.group_by)) + order_by = tuple(toolz.unique(_.frame.order_by + default_frame.order_by)) + frame = _.frame.copy(group_by=group_by, order_by=order_by) + return ops.WindowFunction(_.func, frame) + +def windowize_function(expr, default_frame, merge_frames=False): + ctx = {"default_frame": default_frame} node = expr.op() if merge_frames: - # it only happens in ibis.expr.groupby.GroupedTable, but the projector - # changes the windowization order to put everything here - node = node.replace(merge_windows, filter=p.Value) - node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction) - + node = node.replace(merge_windows, filter=p.Value, context=ctx) + node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction, context=ctx) return node.to_expr() diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index f41ccc686216..f0d573f323f8 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -758,7 +758,7 @@ def over( def bind(table): frame = window.bind(table) - expr = an.windowize_function(self, frame) + expr = an.windowize_function(self, frame, merge_frames=True) if expr.equals(self): raise com.IbisTypeError( "No reduction or analytic function found to construct a window expression" @@ -766,9 +766,7 @@ def bind(table): return expr op = self.op() - if isinstance(op, ops.WindowFunction): - return op.func.to_expr().over(window) - elif isinstance(window, bl.WindowBuilder): + if isinstance(window, bl.WindowBuilder): if table := an.find_first_base_table(self.op()): return bind(table) else: diff --git a/ibis/tests/expr/test_window_functions.py b/ibis/tests/expr/test_window_functions.py index e03359c13c77..8ee6c5617ed3 100644 --- a/ibis/tests/expr/test_window_functions.py +++ b/ibis/tests/expr/test_window_functions.py @@ -55,3 +55,10 @@ def test_value_over_api(alltypes): expr = t.f.cumsum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f]) expected = t.f.cumsum().over(w2) assert expr.equals(expected) + + +def test_rank_followed_by_over_call_merge_frames(alltypes): + t = alltypes + expr1 = t.f.percent_rank().over(ibis.window(group_by=t.f.notnull())) + expr2 = ibis.percent_rank().over(group_by=t.f.notnull(), order_by=t.f).resolve(t) + assert expr1.equals(expr2)