Skip to content

Commit

Permalink
refactor(sql): simplify FirstValue/LastValue usage (ibis-project#…
Browse files Browse the repository at this point in the history
…8568)

During recent refactors the `FirstValue`/`LastValue` have been pretty
much removed from the codebase - at this point they only exist since
some backends require spelling `first`/`last` as
`first_value`/`last_value` when used in a window function context (other
backends alias `first`/`last` to these for nicer SQL UX).

This PR:

- Moves `FirstValue`/`LastValue` to `ibis/backends/sql/rewrites.py` to
treat them as SQL-specific IR.
- Refactors the `First -> FirstValue`/`Last -> LastValue` window
function rewrites to always apply for all SQL backends, making our
compilation more uniform here.
  • Loading branch information
jcrist committed Mar 6, 2024
1 parent 8919478 commit 6ed2e39
Show file tree
Hide file tree
Showing 19 changed files with 70 additions and 122 deletions.
4 changes: 0 additions & 4 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit
Expand All @@ -33,8 +31,6 @@ class BigQueryCompiler(SQLGlotCompiler):
udf_type_mapper = BigQueryUDFType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_rank,
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -30,8 +28,6 @@ class FlinkCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
*SQLGlotCompiler.rewrites,
)

Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from ibis.backends.sql.dialects import Impala
from ibis.backends.sql.rewrites import (
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -25,8 +23,6 @@ class ImpalaCompiler(SQLGlotCompiler):
type_mapper = ImpalaType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
*SQLGlotCompiler.rewrites,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ SELECT
`t0`.`k`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `lag`,
LEAD(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) - `t0`.`f` AS `fwd_diff`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g`) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g`) AS `last`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `last`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`d` ASC NULLS LAST) AS `lag2`
FROM `alltypes` AS `t0`
4 changes: 0 additions & 4 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
exclude_unsupported_window_frame_from_row_number,
p,
replace,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.deferred import var
Expand Down Expand Up @@ -66,8 +64,6 @@ class MSSQLCompiler(SQLGlotCompiler):
type_mapper = MSSQLType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
rewrite_rows_range_order_by_window,
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.patterns import replace
Expand Down Expand Up @@ -53,8 +51,6 @@ class MySQLCompiler(SQLGlotCompiler):
rewrites = (
rewrite_limit,
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from ibis.backends.sql.datatypes import OracleType
from ibis.backends.sql.dialects import Oracle
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
replace_log2,
replace_log10,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -31,8 +31,6 @@ class OracleCompiler(SQLGlotCompiler):
rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
rewrite_sample_as_filter,
replace_log2,
Expand Down Expand Up @@ -409,8 +407,8 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
ops.Correlation, # "corr",
ops.Count, # "count",
ops.Covariance, # "covar_pop", "covar_samp",
ops.FirstValue, # "first_value",
ops.LastValue, # "last_value",
FirstValue, # "first_value",
LastValue, # "last_value",
ops.Max, # "max",
ops.Min, # "min",
ops.NthValue, # "nth_value",
Expand Down
9 changes: 0 additions & 9 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,6 @@ def agg(df, order_keys):

return agg

@classmethod
def visit(cls, op: ops.FirstValue | ops.LastValue, arg):
i = 0 if isinstance(op, ops.FirstValue) else -1

def agg(df, order_keys):
return df[arg.name].iat[i]

return agg

@classmethod
def visit(
cls, op: ops.AnalyticVectorizedUDF, func, func_args, input_type, return_type
Expand Down
16 changes: 12 additions & 4 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ibis.backends.sql.compiler import FALSE, NULL, STAR, TRUE, SQLGlotCompiler
from ibis.backends.sql.datatypes import PySparkType
from ibis.backends.sql.dialects import PySpark
from ibis.backends.sql.rewrites import p
from ibis.backends.sql.rewrites import FirstValue, LastValue, p
from ibis.common.patterns import replace
from ibis.config import options
from ibis.util import gen_name
Expand Down Expand Up @@ -222,6 +222,12 @@ def visit_CountDistinctStar(self, op, *, arg, where):
]
return self.f.count(sge.Distinct(expressions=cols))

def visit_FirstValue(self, op, *, arg):
return self.f.first(arg, TRUE)

def visit_LastValue(self, op, *, arg):
return self.f.last(arg, TRUE)

def visit_First(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
Expand Down Expand Up @@ -406,12 +412,14 @@ def visit_JSONGetItem(self, op, *, arg, index):
return self.f.get_json_object(arg, path)

def visit_Window(self, op, *, func, group_by, order_by, **kwargs):
if isinstance(op.func, ops.Analytic):
# spark disallows specifying boundaries for lead/lag
if isinstance(op.func, ops.Analytic) and not isinstance(
op.func, (FirstValue, LastValue)
):
# spark disallows specifying boundaries for most window functions
if order_by:
order = sge.Order(expressions=order_by)
else:
# pyspark requires an order by clause for lag/lead
# pyspark requires an order by clause for most window functions
order = sge.Order(expressions=[NULL])
return sge.Window(this=func, partition_by=group_by, order=order)
else:
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
replace_log2,
replace_log10,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)


Expand All @@ -39,8 +37,6 @@ class SnowflakeCompiler(SQLGlotCompiler):
rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
replace_log2,
replace_log10,
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
add_one_to_nth_value_input,
add_order_by_to_empty_ranking_window_functions,
empty_in_values_right_side,
Expand Down Expand Up @@ -235,15 +237,15 @@ class SQLGlotCompiler(abc.ABC):
ops.DenseRank: "dense_rank",
ops.Exp: "exp",
ops.First: "first",
ops.FirstValue: "first_value",
FirstValue: "first_value",
ops.GroupConcat: "group_concat",
ops.IfElse: "if",
ops.IsInf: "isinf",
ops.IsNan: "isnan",
ops.JSONGetItem: "json_extract",
ops.LPad: "lpad",
ops.Last: "last",
ops.LastValue: "last_value",
LastValue: "last_value",
ops.Levenshtein: "levenshtein",
ops.Ln: "ln",
ops.Log10: "log",
Expand Down
59 changes: 37 additions & 22 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ def schema(self):
return Schema({k: v.dtype for k, v in self.selections.items()})


@public
class FirstValue(ops.Analytic):
"""Retrieve the first element."""

arg: ops.Column[dt.Any]

@attribute
def dtype(self):
return self.arg.dtype


@public
class LastValue(ops.Analytic):
"""Retrieve the last element."""

arg: ops.Column[dt.Any]

@attribute
def dtype(self):
return self.arg.dtype


@public
class Window(ops.Value):
"""Window modelled after SQL's window statements."""
Expand Down Expand Up @@ -100,10 +122,23 @@ def sort_to_select(_, **kwargs):

@replace(p.WindowFunction)
def window_function_to_window(_, **kwargs):
"""Convert a WindowFunction node to a Window node."""
"""Convert a WindowFunction node to a Window node.
Also rewrites first -> first_value, last -> last_value.
"""
func = _.func
if isinstance(func, (ops.First, ops.Last)):
if func.where is not None:
raise com.UnsupportedOperationError(
f"`{type(func).__name__.lower()}` with `where` is unsupported "
"in a window function"
)
cls = FirstValue if isinstance(func, ops.First) else LastValue
func = cls(func.arg)

return Window(
how=_.frame.how,
func=_.func,
func=func,
start=_.frame.start,
end=_.frame.end,
group_by=_.frame.group_by,
Expand Down Expand Up @@ -270,26 +305,6 @@ def rewrite_sample_as_filter(_, **kwargs):
return ops.Filter(_.parent, (ops.LessEqual(ops.RandomScalar(), _.fraction),))


@replace(p.WindowFunction(p.First(x, where=y)))
def rewrite_first_to_first_value(_, x, y, **kwargs):
"""Rewrite Ibis's first to first_value when used in a window function."""
if y is not None:
raise com.UnsupportedOperationError(
"`first` with `where` is unsupported in a window function"
)
return _.copy(func=ops.FirstValue(x))


@replace(p.WindowFunction(p.Last(x, where=y)))
def rewrite_last_to_last_value(_, x, y, **kwargs):
"""Rewrite Ibis's last to last_value when used in a window function."""
if y is not None:
raise com.UnsupportedOperationError(
"`last` with `where` is unsupported in a window function"
)
return _.copy(func=ops.LastValue(x))


@replace(p.WindowFunction(frame=y @ p.WindowFrame(order_by=())))
def rewrite_empty_order_by_window(_, y, **kwargs):
return _.copy(frame=y.copy(order_by=(ops.NULL,)))
Expand Down
13 changes: 2 additions & 11 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
from ibis.backends.sql.datatypes import SQLiteType
from ibis.backends.sql.dialects import SQLite
from ibis.backends.sql.rewrites import (
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.backends.sql.rewrites import rewrite_sample_as_filter
from ibis.common.temporal import DateUnit, IntervalUnit


Expand All @@ -24,12 +20,7 @@ class SQLiteCompiler(SQLGlotCompiler):

dialect = SQLite
type_mapper = SQLiteType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
*SQLGlotCompiler.rewrites,
)
rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites)

NAN = NULL
POS_INF = sge.Literal.number("1e999")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
any("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
anyLast("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
first_value("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
last_value("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
FIRST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FIRST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
FIRST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FIRST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
Loading

0 comments on commit 6ed2e39

Please sign in to comment.