Skip to content

Commit

Permalink
fix(ir): ensure that windowization directly wraps the reduction/analy…
Browse files Browse the repository at this point in the history
…tic function
  • Loading branch information
kszucs authored and cpcloud committed Oct 13, 2023
1 parent e31e8fd commit 772df36
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 43 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def calc_zscore(s):
id="cumnotany",
marks=[
pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError),
pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError),
pytest.mark.notimpl(["dask"], raises=NotImplementedError),
pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError),
],
),
Expand Down Expand Up @@ -239,7 +239,7 @@ def calc_zscore(s):
id="cumnotall",
marks=[
pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError),
pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError),
pytest.mark.notimpl(["dask"], raises=NotImplementedError),
pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError),
],
),
Expand Down
28 changes: 0 additions & 28 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,38 +323,10 @@ def pushdown_aggregation_filters(op, predicates):
return ops.Selection(op, [], predicates)


# TODO(kszucs): use ibis.expr.analysis.substitute instead
def propagate_down_window(func: ops.Value, frame: ops.WindowFrame):
import ibis.expr.operations as ops

clean_args = []
for arg in func.args:
if isinstance(arg, ops.Value) and not isinstance(func, ops.WindowFunction):
arg = propagate_down_window(arg, frame)
if isinstance(arg, ops.Analytic):
arg = ops.WindowFunction(arg, frame)
clean_args.append(arg)

return type(func)(*clean_args)


def windowize_function(expr, default_frame):
func = var("func")
frame = var("frame")

wrap_analytic = (p.Analytic | p.Reduction) >> c.WindowFunction(_, default_frame)
merge_frames = p.WindowFunction(func, frame) >> c.WindowFunction(
func,
frame.copy(
order_by=frame.order_by + default_frame.order_by,
group_by=frame.group_by + default_frame.group_by,
),
)

node = expr.op()
node = node.replace(merge_frames, filter=p.Value)
node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction)

return node.to_expr()


Expand Down
12 changes: 4 additions & 8 deletions ibis/expr/operations/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import ibis.expr.rules as rlz
from ibis.common.patterns import CoercionError
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.analytic import Analytic
from ibis.expr.operations.analytic import Analytic # noqa: TCH001
from ibis.expr.operations.core import Column, Value
from ibis.expr.operations.generic import Literal
from ibis.expr.operations.numeric import Negate
from ibis.expr.operations.reductions import Reduction
from ibis.expr.operations.reductions import Reduction # noqa: TCH001
from ibis.expr.operations.relations import Relation # noqa: TCH001
from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001

Expand Down Expand Up @@ -122,19 +122,15 @@ class RangeWindowFrame(WindowFrame):

@public
class WindowFunction(Value):
func: Value
func: Analytic | Reduction
frame: WindowFrame

dtype = rlz.dtype_like("func")
shape = ds.columnar

def __init__(self, func, frame):
from ibis.expr.analysis import propagate_down_window, shares_all_roots
from ibis.expr.analysis import shares_all_roots

if not func.find((Reduction, Analytic)):
raise com.IbisTypeError("Window function expression must be analytic")

func = propagate_down_window(func, frame)
if not shares_all_roots(func, frame):
raise com.RelationError(
"Window function expressions doesn't fully originate from the "
Expand Down
13 changes: 8 additions & 5 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,20 +735,23 @@ def over(

def bind(table):
frame = window.bind(table)
return ops.WindowFunction(self, frame).to_expr()
expr = an.windowize_function(self, frame)
if expr.equals(self):
raise com.IbisTypeError(
"No reduction or analytic function found to construct a window expression"
)
return expr

op = self.op()
if isinstance(op, ops.Alias):
return op.arg.to_expr().over(window).name(op.name)
elif isinstance(op, ops.WindowFunction):
if isinstance(op, ops.WindowFunction):
return op.func.to_expr().over(window)
elif isinstance(window, bl.WindowBuilder):
if table := an.find_first_base_table(self.op()):
return bind(table)
else:
return Deferred(Call(bind, _))
else:
return ops.WindowFunction(self, window).to_expr()
raise com.IbisTypeError("Unexpected window type: {window!r}")

def isnull(self) -> ir.BooleanValue:
"""Return whether this expression is NULL.
Expand Down
16 changes: 16 additions & 0 deletions ibis/tests/expr/test_window_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,19 @@ def metric(x):
expected = annual_delay[annual_delay, expr]

assert enriched.equals(expected)


def test_windowization_wraps_reduction_inside_a_nested_value_expression(alltypes):
t = alltypes
win = ibis.window(
following=0,
group_by=[t.g],
order_by=[t.a],
)
expr = (t.f == 0).notany().over(win)
assert expr.op() == ops.Not(
ops.WindowFunction(
func=ops.Any(t.f == 0),
frame=ops.RowsWindowFrame(table=t, end=0, group_by=[t.g], order_by=[t.a]),
)
)

0 comments on commit 772df36

Please sign in to comment.