Skip to content

Commit

Permalink
feat(ux): allow window functions in predicates and compile to `QUALIFY` where possible (#9787)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Aug 7, 2024
1 parent 8d4f97f commit 0370bcb
Show file tree
Hide file tree
Showing 28 changed files with 390 additions and 34 deletions.
18 changes: 15 additions & 3 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ class SQLGlotCompiler(abc.ABC):
copy_func_args: bool = False
"""Whether to copy function arguments when generating SQL."""

supports_qualify: bool = False
"""Whether the backend supports the QUALIFY clause."""

NAN: ClassVar[sge.Expression] = sge.Cast(
this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE)
)
Expand Down Expand Up @@ -1249,15 +1252,21 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):
else:
yield value.as_(name, quoted=self.quoted, copy=False)

def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
# if we've constructed a useless projection return the parent relation
if not selections and not predicates and not sort_keys:
if not (selections or predicates or qualified or sort_keys):
return parent

result = parent

if selections:
if op.is_star_selection():
# if there are `qualify` predicates then sqlglot adds a hidden
# column to implement the functionality if the dialect doesn't
# support it
#
# using STAR in that case would lead to an extra column, so in that
# case we have to spell out the columns
if op.is_star_selection() and (not qualified or self.supports_qualify):
fields = [STAR]
else:
fields = self._cleanup_names(selections)
Expand All @@ -1266,6 +1275,9 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
if predicates:
result = result.where(*predicates, copy=False)

if qualified:
result = result.qualify(*qualified, copy=False)

if sort_keys:
result = result.order_by(*sort_keys, copy=False)

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class BigQueryCompiler(SQLGlotCompiler):
*SQLGlotCompiler.rewrites,
)

supports_qualify = True

UNSUPPORTED_OPS = (
ops.DateDiff,
ops.ExtractAuthority,
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class ClickHouseCompiler(SQLGlotCompiler):

agg = ClickhouseAggGen()

supports_qualify = True

UNSUPPORTED_OPS = (
ops.RowID,
ops.CumeDist,
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class DuckDBCompiler(SQLGlotCompiler):

agg = AggGen(supports_filter=True, supports_order_by=True)

supports_qualify = True

LOWERED_OPS = {
ops.Sample: None,
ops.StringSlice: None,
Expand Down
7 changes: 5 additions & 2 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,9 @@ def visit_All(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return sge.Min(this=arg)

def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
# if we've constructed a useless projection return the parent relation
if not selections and not predicates and not sort_keys:
if not (selections or predicates or qualified or sort_keys):
return parent

result = parent
Expand All @@ -492,6 +492,9 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
if predicates:
result = result.where(*predicates, copy=True)

if qualified:
result = result.qualify(*qualified, copy=True)

if sort_keys:
result = result.order_by(*sort_keys, copy=False)

Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class SnowflakeCompiler(SQLGlotCompiler):
dialect = Snowflake
type_mapper = SnowflakeType
no_limit_value = NULL
supports_qualify = True

agg = AggGen(supports_order_by=True)

Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/sql/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,13 @@ def _create_sql(self, expression: sge.Create) -> str:
sge.Stddev: rename_func("stddev_pop"),
sge.ApproxDistinct: rename_func("approx_count_distinct"),
sge.Create: _create_sql,
sge.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]),
sge.Select: transforms.preprocess(
[
transforms.eliminate_semi_and_anti_joins,
transforms.eliminate_distinct_on,
transforms.eliminate_qualify,
]
),
sge.GroupConcat: rename_func("listagg"),
}

Expand Down
23 changes: 22 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class Select(ops.Relation):
parent: ops.Relation
selections: FrozenDict[str, ops.Value] = {}
predicates: VarTuple[ops.Value[dt.Boolean]] = ()
qualified: VarTuple[ops.Value[dt.Boolean]] = ()
sort_keys: VarTuple[ops.SortKey] = ()

def is_star_selection(self):
Expand Down Expand Up @@ -99,10 +100,26 @@ def project_to_select(_, **kwargs):
return Select(_.parent, selections=_.values)


def partition_predicates(predicates):
    """Split predicates into an unqualified and a qualified group.

    A predicate for which ``find(ops.WindowFunction, filter=ops.Value)``
    returns a truthy result is placed in the qualified group; every other
    predicate is placed in the unqualified group. The relative order of the
    input is preserved within each group.

    Returns a 2-tuple ``(unqualified, qualified)`` of lists.
    """
    unqualified, qualified = [], []

    for pred in predicates:
        window_hits = pred.find(ops.WindowFunction, filter=ops.Value)
        # pick the destination list first, then append exactly once
        bucket = qualified if window_hits else unqualified
        bucket.append(pred)

    return unqualified, qualified


@replace(p.Filter)
def filter_to_select(_, **kwargs):
    """Convert a Filter node to a Select node.

    The filter's predicates are split by ``partition_predicates`` into two
    groups, which are passed to ``Select`` as ``predicates`` and
    ``qualified`` respectively. The filter's parent and values are passed
    through as the Select's parent and selections.
    """
    # The superseded single-argument return that preceded this code made the
    # partitioning below unreachable, so `qualified` was never populated.
    predicates, qualified = partition_predicates(_.predicates)
    return Select(
        _.parent, selections=_.values, predicates=predicates, qualified=qualified
    )


@replace(p.Sort)
Expand Down Expand Up @@ -233,6 +250,9 @@ def merge_select_select(_, **kwargs):
predicates = tuple(p.replace(subs, filter=ops.Value) for p in _.predicates)
unique_predicates = toolz.unique(_.parent.predicates + predicates)

qualified = tuple(p.replace(subs, filter=ops.Value) for p in _.qualified)
unique_qualified = toolz.unique(_.parent.qualified + qualified)

sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys)
sort_key_exprs = {s.expr for s in sort_keys}
parent_sort_keys = tuple(
Expand All @@ -244,6 +264,7 @@ def merge_select_select(_, **kwargs):
_.parent.parent,
selections=selections,
predicates=unique_predicates,
qualified=unique_qualified,
sort_keys=unique_sort_keys,
)
return result if complexity(result) <= complexity(_) else _
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
*
FROM (
SELECT
`t0`.`x`,
SUM(`t0`.`x`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `y`
FROM `t` AS `t0`
) AS `t1`
WHERE
`t1`.`y` <= 37
QUALIFY
AVG(`t1`.`x`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
*
FROM (
SELECT
"t0"."x" AS "x",
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
FROM "t" AS "t0"
) AS "t1"
WHERE
"t1"."y" <= 37
QUALIFY
isNotNull(AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SELECT
"x",
"y"
FROM (
SELECT
"t1"."x",
"t1"."y",
AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
FROM (
SELECT
"t0"."x",
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
FROM "t" AS "t0"
) AS "t1"
WHERE
"t1"."y" <= 37
) AS _t
WHERE
_w IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SELECT
"x",
"y"
FROM (
SELECT
"t1"."x",
"t1"."y",
AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
FROM (
SELECT
"t0"."x",
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
FROM "t" AS "t0"
) AS "t1"
WHERE
"t1"."y" <= 37
) AS _t
WHERE
_w IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
*
FROM (
SELECT
"t0"."x",
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
FROM "t" AS "t0"
) AS "t1"
WHERE
"t1"."y" <= 37
QUALIFY
AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SELECT
"x",
"y"
FROM (
SELECT
"t1"."x",
"t1"."y",
AVG("t1"."x") OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
FROM (
SELECT
"t0"."x",
SUM("t0"."x") OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
FROM "t" AS "t0"
) AS "t1"
WHERE
"t1"."y" <= 37
) AS _t
WHERE
_w IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
SELECT
`t1`.`x`,
`t1`.`y`
FROM (
SELECT
`t0`.`x`,
SUM(`t0`.`x`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `y`
FROM `t` AS `t0`
) AS `t1`
WHERE
`t1`.`y` <= 37
QUALIFY
AVG(`t1`.`x`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SELECT
`x`,
`y`
FROM (
SELECT
`t1`.`x`,
`t1`.`y`,
AVG(`t1`.`x`) OVER (ORDER BY NULL ASC) AS _w
FROM (
SELECT
`t0`.`x`,
SUM(`t0`.`x`) OVER (ORDER BY NULL ASC) AS `y`
FROM `t` AS `t0`
) AS `t1`
WHERE
`t1`.`y` <= 37
) AS _t
WHERE
_w IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SELECT
[x],
[y]
FROM (
SELECT
[t1].[x] AS [x],
[t1].[y] AS [y],
AVG([t1].[x]) OVER (ORDER BY CASE WHEN [t1].[x] IS NULL THEN 1 ELSE 0 END, [t1].[x] ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
FROM (
SELECT
[t0].[x],
SUM([t0].[x]) OVER (ORDER BY CASE WHEN [t0].[x] IS NULL THEN 1 ELSE 0 END, [t0].[x] ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS [y]
FROM [t] AS [t0]
) AS [t1]
WHERE
[t1].[y] <= 37
) AS _t
WHERE
_w IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SELECT
`x`,
`y`
FROM (
SELECT
`t1`.`x`,
`t1`.`y`,
AVG(`t1`.`x`) OVER (ORDER BY CASE WHEN NULL IS NULL THEN 1 ELSE 0 END, NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
FROM (
SELECT
`t0`.`x`,
SUM(`t0`.`x`) OVER (ORDER BY CASE WHEN NULL IS NULL THEN 1 ELSE 0 END, NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `y`
FROM `t` AS `t0`
) AS `t1`
WHERE
`t1`.`y` <= 37
) AS _t
WHERE
_w IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SELECT
"x",
"y"
FROM (
SELECT
"t1"."x",
"t1"."y",
AVG("t1"."x") OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
FROM (
SELECT
"t0"."x",
SUM("t0"."x") OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
FROM "t" "t0"
) "t1"
WHERE
"t1"."y" <= 37
) _t
WHERE
_w IS NOT NULL
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SELECT
"x",
"y"
FROM (
SELECT
"t1"."x",
"t1"."y",
AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
FROM (
SELECT
"t0"."x",
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
FROM "t" AS "t0"
) AS "t1"
WHERE
"t1"."y" <= 37
) AS _t
WHERE
_w IS NOT NULL
Loading

0 comments on commit 0370bcb

Please sign in to comment.