Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(sql): simplify FirstValue/LastValue usage #8568

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit
Expand All @@ -33,8 +31,6 @@ class BigQueryCompiler(SQLGlotCompiler):
udf_type_mapper = BigQueryUDFType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_rank,
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -30,8 +28,6 @@ class FlinkCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
*SQLGlotCompiler.rewrites,
)

Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from ibis.backends.sql.dialects import Impala
from ibis.backends.sql.rewrites import (
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -25,8 +23,6 @@ class ImpalaCompiler(SQLGlotCompiler):
type_mapper = ImpalaType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
*SQLGlotCompiler.rewrites,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ SELECT
`t0`.`k`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `lag`,
LEAD(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) - `t0`.`f` AS `fwd_diff`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g`) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g`) AS `last`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `last`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`d` ASC NULLS LAST) AS `lag2`
FROM `alltypes` AS `t0`
4 changes: 0 additions & 4 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
exclude_unsupported_window_frame_from_row_number,
p,
replace,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.deferred import var
Expand Down Expand Up @@ -66,8 +64,6 @@ class MSSQLCompiler(SQLGlotCompiler):
type_mapper = MSSQLType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
rewrite_rows_range_order_by_window,
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.patterns import replace
Expand Down Expand Up @@ -53,8 +51,6 @@ class MySQLCompiler(SQLGlotCompiler):
rewrites = (
rewrite_limit,
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from ibis.backends.sql.datatypes import OracleType
from ibis.backends.sql.dialects import Oracle
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
replace_log2,
replace_log10,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)

Expand All @@ -31,8 +31,6 @@ class OracleCompiler(SQLGlotCompiler):
rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
rewrite_sample_as_filter,
replace_log2,
Expand Down Expand Up @@ -409,8 +407,8 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
ops.Correlation, # "corr",
ops.Count, # "count",
ops.Covariance, # "covar_pop", "covar_samp",
ops.FirstValue, # "first_value",
ops.LastValue, # "last_value",
FirstValue, # "first_value",
LastValue, # "last_value",
ops.Max, # "max",
ops.Min, # "min",
ops.NthValue, # "nth_value",
Expand Down
9 changes: 0 additions & 9 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,6 @@ def agg(df, order_keys):

return agg

@classmethod
def visit(cls, op: ops.FirstValue | ops.LastValue, arg):
    """Build the aggregation callable for a first-value/last-value op.

    The returned callable selects a single element from the ``arg.name``
    column of the frame it is given: position ``0`` when ``op`` is an
    ``ops.FirstValue`` and position ``-1`` otherwise.
    """
    if isinstance(op, ops.FirstValue):
        position = 0
    else:
        position = -1

    def agg(df, order_keys):
        # ``order_keys`` is part of the callable's signature but is not
        # used here; the element is picked purely by position.
        column = df[arg.name]
        return column.iat[position]

    return agg

@classmethod
def visit(
cls, op: ops.AnalyticVectorizedUDF, func, func_args, input_type, return_type
Expand Down
16 changes: 12 additions & 4 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ibis.backends.sql.compiler import FALSE, NULL, STAR, TRUE, SQLGlotCompiler
from ibis.backends.sql.datatypes import PySparkType
from ibis.backends.sql.dialects import PySpark
from ibis.backends.sql.rewrites import p
from ibis.backends.sql.rewrites import FirstValue, LastValue, p
from ibis.common.patterns import replace
from ibis.config import options
from ibis.util import gen_name
Expand Down Expand Up @@ -222,6 +222,12 @@ def visit_CountDistinctStar(self, op, *, arg, where):
]
return self.f.count(sge.Distinct(expressions=cols))

def visit_FirstValue(self, op, *, arg):
    """Compile a ``FirstValue`` node to a ``first(arg, TRUE)`` call.

    NOTE(review): the second argument is presumably Spark's
    "ignore nulls" flag for ``first`` -- confirm against the Spark docs.
    """
    first = self.f.first
    return first(arg, TRUE)

def visit_LastValue(self, op, *, arg):
    """Compile a ``LastValue`` node to a ``last(arg, TRUE)`` call.

    NOTE(review): the second argument is presumably Spark's
    "ignore nulls" flag for ``last`` -- confirm against the Spark docs.
    """
    last = self.f.last
    return last(arg, TRUE)

def visit_First(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
Expand Down Expand Up @@ -406,12 +412,14 @@ def visit_JSONGetItem(self, op, *, arg, index):
return self.f.get_json_object(arg, path)

def visit_Window(self, op, *, func, group_by, order_by, **kwargs):
if isinstance(op.func, ops.Analytic):
# spark disallows specifying boundaries for lead/lag
if isinstance(op.func, ops.Analytic) and not isinstance(
op.func, (FirstValue, LastValue)
):
# spark disallows specifying boundaries for most window functions
if order_by:
order = sge.Order(expressions=order_by)
else:
# pyspark requires an order by clause for lag/lead
# pyspark requires an order by clause for most window functions
order = sge.Order(expressions=[NULL])
return sge.Window(this=func, partition_by=group_by, order=order)
else:
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
replace_log2,
replace_log10,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)


Expand All @@ -39,8 +37,6 @@ class SnowflakeCompiler(SQLGlotCompiler):
rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
replace_log2,
replace_log10,
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
add_one_to_nth_value_input,
add_order_by_to_empty_ranking_window_functions,
empty_in_values_right_side,
Expand Down Expand Up @@ -235,15 +237,15 @@ class SQLGlotCompiler(abc.ABC):
ops.DenseRank: "dense_rank",
ops.Exp: "exp",
ops.First: "first",
ops.FirstValue: "first_value",
FirstValue: "first_value",
ops.GroupConcat: "group_concat",
ops.IfElse: "if",
ops.IsInf: "isinf",
ops.IsNan: "isnan",
ops.JSONGetItem: "json_extract",
ops.LPad: "lpad",
ops.Last: "last",
ops.LastValue: "last_value",
LastValue: "last_value",
ops.Levenshtein: "levenshtein",
ops.Ln: "ln",
ops.Log10: "log",
Expand Down
59 changes: 37 additions & 22 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@
return Schema({k: v.dtype for k, v in self.selections.items()})


@public
class FirstValue(ops.Analytic):
    """Retrieve the first element.

    An analytic operation over a single column ``arg``; the result has
    the same dtype as that column.
    """

    # Input column; any element dtype is accepted.
    arg: ops.Column[dt.Any]

    @attribute
    def dtype(self):
        """Return the dtype of ``arg`` unchanged."""
        column = self.arg
        return column.dtype

@public
class LastValue(ops.Analytic):
    """Retrieve the last element.

    An analytic operation over a single column ``arg``; the result has
    the same dtype as that column.
    """

    # Input column; any element dtype is accepted.
    arg: ops.Column[dt.Any]

    @attribute
    def dtype(self):
        """Return the dtype of ``arg`` unchanged."""
        column = self.arg
        return column.dtype


@public
class Window(ops.Value):
"""Window modelled after SQL's window statements."""
Expand Down Expand Up @@ -100,10 +122,23 @@

@replace(p.WindowFunction)
def window_function_to_window(_, **kwargs):
"""Convert a WindowFunction node to a Window node."""
"""Convert a WindowFunction node to a Window node.

Also rewrites first -> first_value, last -> last_value.
"""
func = _.func
if isinstance(func, (ops.First, ops.Last)):
if func.where is not None:
jcrist marked this conversation as resolved.
Show resolved Hide resolved
raise com.UnsupportedOperationError(

Check warning on line 132 in ibis/backends/sql/rewrites.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/rewrites.py#L132

Added line #L132 was not covered by tests
f"`{type(func).__name__.lower()}` with `where` is unsupported "
"in a window function"
)
cls = FirstValue if isinstance(func, ops.First) else LastValue
func = cls(func.arg)

return Window(
how=_.frame.how,
func=_.func,
func=func,
start=_.frame.start,
end=_.frame.end,
group_by=_.frame.group_by,
Expand Down Expand Up @@ -270,26 +305,6 @@
return ops.Filter(_.parent, (ops.LessEqual(ops.RandomScalar(), _.fraction),))


@replace(p.WindowFunction(p.First(x, where=y)))
def rewrite_first_to_first_value(_, x, y, **kwargs):
    """Swap ``First`` for ``FirstValue`` inside a window function.

    ``x`` is the argument captured from the matched ``First`` and ``y`` is
    its ``where`` filter. A non-``None`` filter raises
    ``UnsupportedOperationError``; otherwise the matched window function
    is copied with ``ops.FirstValue(x)`` as its ``func``.
    """
    if y is None:
        first_value = ops.FirstValue(x)
        return _.copy(func=first_value)

    raise com.UnsupportedOperationError(
        "`first` with `where` is unsupported in a window function"
    )


@replace(p.WindowFunction(p.Last(x, where=y)))
def rewrite_last_to_last_value(_, x, y, **kwargs):
    """Swap ``Last`` for ``LastValue`` inside a window function.

    ``x`` is the argument captured from the matched ``Last`` and ``y`` is
    its ``where`` filter. A non-``None`` filter raises
    ``UnsupportedOperationError``; otherwise the matched window function
    is copied with ``ops.LastValue(x)`` as its ``func``.
    """
    if y is None:
        last_value = ops.LastValue(x)
        return _.copy(func=last_value)

    raise com.UnsupportedOperationError(
        "`last` with `where` is unsupported in a window function"
    )


@replace(p.WindowFunction(frame=y @ p.WindowFrame(order_by=())))
def rewrite_empty_order_by_window(_, y, **kwargs):
    """Give a window frame with an empty ``order_by`` a single NULL key.

    Matches window functions whose frame ``y`` has ``order_by=()`` and
    returns a copy whose frame orders by ``(ops.NULL,)`` instead.
    """
    null_ordered_frame = y.copy(order_by=(ops.NULL,))
    return _.copy(frame=null_ordered_frame)
Expand Down
13 changes: 2 additions & 11 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
from ibis.backends.sql.datatypes import SQLiteType
from ibis.backends.sql.dialects import SQLite
from ibis.backends.sql.rewrites import (
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.backends.sql.rewrites import rewrite_sample_as_filter
from ibis.common.temporal import DateUnit, IntervalUnit


Expand All @@ -24,12 +20,7 @@ class SQLiteCompiler(SQLGlotCompiler):

dialect = SQLite
type_mapper = SQLiteType
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
*SQLGlotCompiler.rewrites,
)
rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites)

NAN = NULL
POS_INF = sge.Literal.number("1e999")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
any("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
anyLast("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
first_value("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
last_value("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
FIRST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FIRST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
FIRST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FIRST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees",
LAST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees"
FROM (
SELECT
"t1"."field_of_study",
Expand Down
Loading
Loading