Skip to content

Commit

Permalink
refactor: always rewrite first -> first_value/last -> last_value
Browse files Browse the repository at this point in the history
…for SQL window functions
  • Loading branch information
jcrist committed Mar 6, 2024
1 parent 513242a commit 01e2ec9
Show file tree
Hide file tree
Showing 16 changed files with 40 additions and 91 deletions.
4 changes: 0 additions & 4 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit
Expand All @@ -33,8 +31,6 @@ class BigQueryCompiler(SQLGlotCompiler):
udf_type_mapper = BigQueryUDFType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_rank,
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -30,8 +28,6 @@ class FlinkCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
*SQLGlotCompiler.rewrites,
)

Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from ibis.backends.sql.dialects import Impala
from ibis.backends.sql.rewrites import (
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -25,8 +23,6 @@ class ImpalaCompiler(SQLGlotCompiler):
type_mapper = ImpalaType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
*SQLGlotCompiler.rewrites,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ SELECT
`t0`.`k`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `lag`,
LEAD(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) - `t0`.`f` AS `fwd_diff`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g`) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g`) AS `last`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `last`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`d` ASC NULLS LAST) AS `lag2`
FROM `alltypes` AS `t0`
4 changes: 0 additions & 4 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
exclude_unsupported_window_frame_from_row_number,
p,
replace,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.deferred import var
Expand Down Expand Up @@ -66,8 +64,6 @@ class MSSQLCompiler(SQLGlotCompiler):
type_mapper = MSSQLType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
rewrite_rows_range_order_by_window,
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.patterns import replace
Expand Down Expand Up @@ -53,8 +51,6 @@ class MySQLCompiler(SQLGlotCompiler):
rewrites = (
rewrite_limit,
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
replace_log2,
replace_log10,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -33,8 +31,6 @@ class OracleCompiler(SQLGlotCompiler):
rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
rewrite_sample_as_filter,
replace_log2,
Expand Down
16 changes: 12 additions & 4 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ibis.backends.sql.compiler import FALSE, NULL, STAR, TRUE, SQLGlotCompiler
from ibis.backends.sql.datatypes import PySparkType
from ibis.backends.sql.dialects import PySpark
from ibis.backends.sql.rewrites import p
from ibis.backends.sql.rewrites import FirstValue, LastValue, p
from ibis.common.patterns import replace
from ibis.config import options
from ibis.util import gen_name
Expand Down Expand Up @@ -222,6 +222,12 @@ def visit_CountDistinctStar(self, op, *, arg, where):
]
return self.f.count(sge.Distinct(expressions=cols))

def visit_FirstValue(self, op, *, arg):
    """Compile a ``FirstValue`` op to the backend's ``first`` function.

    The compiled ``arg`` is passed through together with the ``TRUE``
    literal as the second positional argument.
    """
    # NOTE(review): the second argument is presumably Spark's
    # "ignore nulls" flag -- confirm against the Spark SQL docs.
    first_fn = self.f.first
    return first_fn(arg, TRUE)

def visit_LastValue(self, op, *, arg):
    """Compile a ``LastValue`` op to the backend's ``last`` function.

    The compiled ``arg`` is passed through together with the ``TRUE``
    literal as the second positional argument.
    """
    # NOTE(review): the second argument is presumably Spark's
    # "ignore nulls" flag -- confirm against the Spark SQL docs.
    last_fn = self.f.last
    return last_fn(arg, TRUE)

def visit_First(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
Expand Down Expand Up @@ -406,12 +412,14 @@ def visit_JSONGetItem(self, op, *, arg, index):
return self.f.get_json_object(arg, path)

def visit_Window(self, op, *, func, group_by, order_by, **kwargs):
if isinstance(op.func, ops.Analytic):
# spark disallows specifying boundaries for lead/lag
if isinstance(op.func, ops.Analytic) and not isinstance(
op.func, (FirstValue, LastValue)
):
# spark disallows specifying boundaries for most window functions
if order_by:
order = sge.Order(expressions=order_by)
else:
# pyspark requires an order by clause for lag/lead
# pyspark requires an order by clause for most window functions
order = sge.Order(expressions=[NULL])
return sge.Window(this=func, partition_by=group_by, order=order)
else:
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
replace_log2,
replace_log10,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)


Expand All @@ -39,8 +37,6 @@ class SnowflakeCompiler(SQLGlotCompiler):
rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
replace_log2,
replace_log10,
Expand Down
37 changes: 15 additions & 22 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,23 @@ def sort_to_select(_, **kwargs):

@replace(p.WindowFunction)
def window_function_to_window(_, **kwargs):
"""Convert a WindowFunction node to a Window node."""
"""Convert a WindowFunction node to a Window node.
Also rewrites first -> first_value, last -> last_value.
"""
func = _.func
if isinstance(func, (ops.First, ops.Last)):
if func.where is not None:
raise com.UnsupportedOperationError(

Check warning on line 132 in ibis/backends/sql/rewrites.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/rewrites.py#L132

Added line #L132 was not covered by tests
f"`{type(func).__name__.lower()}` with `where` is unsupported "
"in a window function"
)
cls = FirstValue if isinstance(func, ops.First) else LastValue
func = cls(func.arg)

return Window(
how=_.frame.how,
func=_.func,
func=func,
start=_.frame.start,
end=_.frame.end,
group_by=_.frame.group_by,
Expand Down Expand Up @@ -292,26 +305,6 @@ def rewrite_sample_as_filter(_, **kwargs):
return ops.Filter(_.parent, (ops.LessEqual(ops.RandomScalar(), _.fraction),))


@replace(p.WindowFunction(p.First(x, where=y)))
def rewrite_first_to_first_value(_, x, y, **kwargs):
    """Swap a windowed ``First`` for ``FirstValue``.

    ``x`` is the argument captured from the matched ``First`` node and ``y``
    is its ``where`` filter.

    Raises
    ------
    com.UnsupportedOperationError
        If the matched ``First`` carries a ``where`` filter.
    """
    if y is None:
        # Unfiltered: keep the window function and replace only its func.
        return _.copy(func=FirstValue(x))

    raise com.UnsupportedOperationError(
        "`first` with `where` is unsupported in a window function"
    )


@replace(p.WindowFunction(p.Last(x, where=y)))
def rewrite_last_to_last_value(_, x, y, **kwargs):
    """Swap a windowed ``Last`` for ``LastValue``.

    ``x`` is the argument captured from the matched ``Last`` node and ``y``
    is its ``where`` filter.

    Raises
    ------
    com.UnsupportedOperationError
        If the matched ``Last`` carries a ``where`` filter.
    """
    if y is None:
        # Unfiltered: keep the window function and replace only its func.
        return _.copy(func=LastValue(x))

    raise com.UnsupportedOperationError(
        "`last` with `where` is unsupported in a window function"
    )


@replace(p.WindowFunction(frame=y @ p.WindowFrame(order_by=())))
def rewrite_empty_order_by_window(_, y, **kwargs):
    """Give a window function with an empty ``order_by`` a NULL ordering.

    ``y`` is the matched window frame; it is copied with
    ``order_by=(ops.NULL,)`` and the window function is rebuilt around
    that copy.
    """
    null_ordered_frame = y.copy(order_by=(ops.NULL,))
    return _.copy(frame=null_ordered_frame)
Expand Down
13 changes: 2 additions & 11 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
from ibis.backends.sql.datatypes import SQLiteType
from ibis.backends.sql.dialects import SQLite
from ibis.backends.sql.rewrites import (
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.backends.sql.rewrites import rewrite_sample_as_filter
from ibis.common.temporal import DateUnit, IntervalUnit


Expand All @@ -24,12 +20,7 @@ class SQLiteCompiler(SQLGlotCompiler):

dialect = SQLite
type_mapper = SQLiteType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
*SQLGlotCompiler.rewrites,
)
rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites)

NAN = NULL
POS_INF = sge.Literal.number("1e999")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
any("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
anyLast("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
first_value("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
last_value("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
FIRST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FIRST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
FIRST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FIRST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
8 changes: 1 addition & 7 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,11 @@ def calc_zscore(s):
lambda t, win: t.float_col.first().over(win),
lambda t: t.float_col.transform("first"),
id="first",
marks=[
pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError),
],
),
param(
lambda t, win: t.float_col.last().over(win),
lambda t: t.float_col.transform("last"),
id="last",
marks=[
pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError),
],
),
param(
lambda t, win: t.double_col.nth(3).over(win),
Expand Down Expand Up @@ -1073,7 +1067,7 @@ def test_mutate_window_filter(backend, alltypes):
backend.assert_frame_equal(res, sol, check_dtype=False)


@pytest.mark.notimpl(["polars", "exasol"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
def test_first_last(backend):
t = backend.win
w = ibis.window(group_by=t.g, order_by=[t.x, t.y], preceding=1, following=0)
Expand Down
13 changes: 2 additions & 11 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,15 @@
)
from ibis.backends.sql.datatypes import TrinoType
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)
from ibis.backends.sql.rewrites import exclude_unsupported_window_frame_from_ops


class TrinoCompiler(SQLGlotCompiler):
__slots__ = ()

dialect = Trino
type_mapper = TrinoType
rewrites = (
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
*SQLGlotCompiler.rewrites,
)
rewrites = (exclude_unsupported_window_frame_from_ops, *SQLGlotCompiler.rewrites)
quoted = True

NAN = sg.func("nan")
Expand Down

0 comments on commit 01e2ec9

Please sign in to comment.