Skip to content

Commit

Permalink
feat(sqlalchemy): support expressions in window bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Nov 21, 2023
1 parent da2a699 commit 5dbb3b1
Showing 1 changed file with 82 additions and 9 deletions.
91 changes: 82 additions & 9 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.elements import RANGE_CURRENT, RANGE_UNBOUNDED
from sqlalchemy.sql.functions import FunctionElement, GenericFunction

import ibis.common.exceptions as com
Expand Down Expand Up @@ -308,17 +309,89 @@ def _endswith(t, op):
return t.translate(op.arg).endswith(t.translate(op.end))


def _translate_window_boundary(boundary):
def _reinterpret_range_bound(bound):
if bound is None:
return RANGE_UNBOUNDED

try:
lower = int(bound)
except ValueError as err:
sa.util.raise_(
sa.exc.ArgumentError(
"Integer, None or expression expected for range value"
),
replace_context=err,
)
except TypeError:
return bound
else:
return RANGE_CURRENT if lower == 0 else lower


def _interpret_range(self, range_):
if not isinstance(range_, tuple) or len(range_) != 2:
raise sa.exc.ArgumentError("2-tuple expected for range/rows")

lower = _reinterpret_range_bound(range_[0])
upper = _reinterpret_range_bound(range_[1])
return lower, upper


# monkeypatch to allow expressions in range and rows bounds
sa.sql.elements.Over._interpret_range = _interpret_range


def _translate_window_boundary(t, boundary):
if boundary is None:
return None

if isinstance(boundary.value, ops.Literal):
if boundary.preceding:
return -boundary.value.value
else:
return boundary.value.value
value = t.translate(boundary.value)
return value if boundary.preceding else value


def _compile_bounds(compiler, range_, kind, **kw):
left_, right_ = range_

if left_ is RANGE_UNBOUNDED:
left = "UNBOUNDED PRECEDING"
elif left_ is RANGE_CURRENT:
left = "CURRENT ROW"
else:
left = f"{compiler.process(left_, **kw)} PRECEDING"

if right_ is RANGE_UNBOUNDED:
right = "UNBOUNDED FOLLOWING"
elif right_ is RANGE_CURRENT:
right = "CURRENT ROW"
else:
right = f"{compiler.process(right_, **kw)} FOLLOWING"

return f"{kind} BETWEEN {left} AND {right}"


@compiles(sa.sql.elements.Over)
def compile_over(over, compiler, **kw):
text = compiler.process(over.element, **kw)
if over.range_:
range_ = _compile_bounds(compiler, over.range_, kind="RANGE", **kw)
elif over.rows:
range_ = _compile_bounds(compiler, over.rows, kind="ROWS", **kw)
else:
range_ = None

args = [
f"{word} BY {compiler.process(clause, **kw)}"
for word, clause in (
("PARTITION", over.partition_by),
("ORDER", over.order_by),
)
if clause is not None and len(clause)
]

if range_ is not None:
args.append(range_)

raise com.TranslationError("Window boundaries must be literal values")
return f"{text} OVER ({' '.join(args)})"


def _window_function(t, window):
Expand Down Expand Up @@ -351,8 +424,8 @@ def _window_function(t, window):
# some functions on some backends don't support frame clauses
additional_params = {}
else:
start = _translate_window_boundary(window.frame.start)
end = _translate_window_boundary(window.frame.end)
start = _translate_window_boundary(t, window.frame.start)
end = _translate_window_boundary(t, window.frame.end)
additional_params = {how: (start, end)}

result = sa.over(
Expand Down

0 comments on commit 5dbb3b1

Please sign in to comment.