Skip to content

Commit

Permalink
refactor(ir/api): introduce window frame operation and revamp the win…
Browse files Browse the repository at this point in the history
…dow API
  • Loading branch information
kszucs committed Feb 4, 2023
1 parent 3cb7682 commit 2bc5e5e
Show file tree
Hide file tree
Showing 75 changed files with 1,741 additions and 1,801 deletions.
88 changes: 46 additions & 42 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
import ibis.expr.window as W
from ibis.backends.base.sql.alchemy.database import AlchemyTable


Expand Down Expand Up @@ -300,69 +299,73 @@ def _endswith(t, op):
}


def _cumulative_to_window(translator, op, window):
win = W.cumulative_window()
win = win.group_by(window._group_by).order_by(window._order_by)

def _cumulative_to_window(translator, op, frame):
klass = _cumulative_to_reduction[type(op)]
new_op = klass(*op.args)
new_expr = new_op.to_expr().name(op.name)
new_frame = frame.copy(start=None, end=0)

if type(new_op) in translator._rewrites:
new_expr = translator._rewrites[type(new_op)](new_expr)

# TODO(kszucs): rewrite to receive and return an ops.Node
return an.windowize_function(new_expr, win)
return an.windowize_function(new_expr, frame=new_frame)


def _window(t, op):
arg, window = op.args
reduction = t.translate(arg)
def _translate_window_boundary(boundary):
if boundary is None:
return None

window_op = arg
if isinstance(boundary.value, ops.Literal):
if boundary.preceding:
return -boundary.value.value
else:
return boundary.value.value

if isinstance(window_op, ops.CumulativeOp):
arg = _cumulative_to_window(t, arg, window).op()
return t.translate(arg)
raise com.TranslationError("Window boundaries must be literal values")

if window.max_lookback is not None:
raise NotImplementedError(
'Rows with max lookback is not implemented '
'for SQLAlchemy-based backends.'
)

# Checks for invalid user input e.g. passing in tuple for preceding and
# non-None value for following are caught and raised in expr/window.py
# if we're here, then the input is valid, we just need to interpret it
# correctly
if isinstance(window.preceding, tuple):
start, end = (-1 * x if x is not None else None for x in window.preceding)
elif isinstance(window.following, tuple):
start, end = window.following
else:
start = -window.preceding if window.preceding is not None else window.preceding
end = window.following
def _window_function(t, window):
if isinstance(window.func, ops.CumulativeOp):
func = _cumulative_to_window(t, window.func, window.frame).op()
return t.translate(func)

reduction = t.translate(window.func)

# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
if isinstance(window_op, t._require_order_by) and not window._order_by:
order_by = t.translate(window_op.args[0])
if isinstance(window.func, t._require_order_by) and not window.frame.order_by:
order_by = t.translate(window.func.arg) # .args[0])
else:
order_by = [t.translate(arg) for arg in window._order_by]
order_by = [t.translate(arg) for arg in window.frame.order_by]

partition_by = [t.translate(arg) for arg in window._group_by]
partition_by = [t.translate(arg) for arg in window.frame.group_by]

if isinstance(window.frame, ops.RowsWindowFrame):
if window.frame.max_lookback is not None:
raise NotImplementedError(
'Rows with max lookback is not implemented for SQLAlchemy-based '
'backends.'
)
how = 'rows'
elif isinstance(window.frame, ops.RangeWindowFrame):
how = 'range_'
else:
raise NotImplementedError(type(window.frame))

if t._forbids_frame_clause and isinstance(window.func, t._forbids_frame_clause):
# some functions on some backends don't support frame clauses
additional_params = {}
else:
start = _translate_window_boundary(window.frame.start)
end = _translate_window_boundary(window.frame.end)
additional_params = {how: (start, end)}

how = {'range': 'range_'}.get(window.how, window.how)
additional_params = (
{}
if t._forbids_frame_clause and isinstance(window_op, t._forbids_frame_clause)
else {how: (start, end)}
)
result = reduction.over(
partition_by=partition_by, order_by=order_by, **additional_params
)

if isinstance(window_op, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)):
if isinstance(window.func, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)):
return result - 1
else:
return result
Expand Down Expand Up @@ -622,12 +625,13 @@ def translator(t, op: ops.Node):
ops.PercentRank: unary(lambda _: sa.func.percent_rank()),
ops.CumeDist: unary(lambda _: sa.func.cume_dist()),
ops.NthValue: _nth_value,
ops.Window: _window,
ops.CumulativeOp: _window,
ops.WindowFunction: _window_function,
ops.CumulativeMax: unary(sa.func.max),
ops.CumulativeMin: unary(sa.func.min),
ops.CumulativeSum: unary(sa.func.sum),
ops.CumulativeMean: unary(sa.func.avg),
ops.CumulativeAny: unary(sa.func.bool_or),
ops.CumulativeAll: unary(sa.func.bool_and),
}

geospatial_functions = {
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/base/sql/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from ibis.backends.base.sql.registry.window import (
cumulative_to_window,
format_window,
format_window_frame,
time_range_to_range_window,
)

Expand All @@ -31,6 +31,6 @@
'reduction',
'unary',
'cumulative_to_window',
'format_window',
'format_window_frame',
'time_range_to_range_window',
)
Loading

0 comments on commit 2bc5e5e

Please sign in to comment.