diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index 307172b4970f..ff13c63ee2a1 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -7,6 +7,7 @@ import sqlalchemy as sa from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.elements import RANGE_CURRENT, RANGE_UNBOUNDED from sqlalchemy.sql.functions import FunctionElement, GenericFunction import ibis.common.exceptions as com @@ -308,17 +309,89 @@ def _endswith(t, op): return t.translate(op.arg).endswith(t.translate(op.end)) -def _translate_window_boundary(boundary): +def _reinterpret_range_bound(bound): + if bound is None: + return RANGE_UNBOUNDED + + try: + lower = int(bound) + except ValueError as err: + sa.util.raise_( + sa.exc.ArgumentError( + "Integer, None or expression expected for range value" + ), + replace_context=err, + ) + except TypeError: + return bound + else: + return RANGE_CURRENT if lower == 0 else lower + + +def _interpret_range(self, range_): + if not isinstance(range_, tuple) or len(range_) != 2: + raise sa.exc.ArgumentError("2-tuple expected for range/rows") + + lower = _reinterpret_range_bound(range_[0]) + upper = _reinterpret_range_bound(range_[1]) + return lower, upper + + +# monkeypatch to allow expressions in range and rows bounds +sa.sql.elements.Over._interpret_range = _interpret_range + + +def _translate_window_boundary(t, boundary): if boundary is None: return None - if isinstance(boundary.value, ops.Literal): - if boundary.preceding: - return -boundary.value.value - else: - return boundary.value.value + value = t.translate(boundary.value) + return value if boundary.preceding else value + + +def _compile_bounds(compiler, range_, kind, **kw): + left_, right_ = range_ + + if left_ is RANGE_UNBOUNDED: + left = "UNBOUNDED PRECEDING" + elif left_ is RANGE_CURRENT: + left = "CURRENT ROW" + else: + left = f"{compiler.process(left_, **kw)} PRECEDING" + + if right_ is RANGE_UNBOUNDED: + right = "UNBOUNDED FOLLOWING" + elif right_ is RANGE_CURRENT: + right = "CURRENT ROW" + else: + right = f"{compiler.process(right_, **kw)} FOLLOWING" + + return f"{kind} BETWEEN {left} AND {right}" + + +@compiles(sa.sql.elements.Over) +def compile_over(over, compiler, **kw): + text = compiler.process(over.element, **kw) + if over.range_: + range_ = _compile_bounds(compiler, over.range_, kind="RANGE", **kw) + elif over.rows: + range_ = _compile_bounds(compiler, over.rows, kind="ROWS", **kw) + else: + range_ = None + + args = [ + f"{word} BY {compiler.process(clause, **kw)}" + for word, clause in ( + ("PARTITION", over.partition_by), + ("ORDER", over.order_by), + ) + if clause is not None and len(clause) + ] + + if range_ is not None: + args.append(range_) - raise com.TranslationError("Window boundaries must be literal values") + return f"{text} OVER ({' '.join(args)})" def _window_function(t, window): @@ -351,8 +424,8 @@ def _window_function(t, window): # some functions on some backends don't support frame clauses additional_params = {} else: - start = _translate_window_boundary(window.frame.start) - end = _translate_window_boundary(window.frame.end) + start = _translate_window_boundary(t, window.frame.start) + end = _translate_window_boundary(t, window.frame.end) additional_params = {how: (start, end)} result = sa.over(