Skip to content

Commit

Permalink
feat(ux): allow window functions in predicates and compile to `QUALIFY` where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 7, 2024
1 parent 17f8632 commit f30e077
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 27 deletions.
13 changes: 11 additions & 2 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ class SQLGlotCompiler(abc.ABC):
copy_func_args: bool = False
"""Whether to copy function arguments when generating SQL."""

supports_qualify: bool = False
"""Whether the backend supports the QUALIFY clause."""

NAN: ClassVar[sge.Expression] = sge.Cast(
this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE)
)
Expand Down Expand Up @@ -1254,13 +1257,19 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):

def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
# if we've constructed a useless projection return the parent relation
if not selections and not predicates and not sort_keys:
if not (selections or predicates or qualified or sort_keys):
return parent

result = parent

if selections:
if op.is_star_selection():
# if there are `qualify` predicates and the dialect doesn't support
# QUALIFY, sqlglot implements the functionality by adding a hidden
# column
#
# using STAR in that case would lead to an extra column in the
# result, so we have to spell out the columns instead
if op.is_star_selection() and (not qualified or self.supports_qualify):
fields = [STAR]
else:
fields = self._cleanup_names(selections)
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class BigQueryCompiler(SQLGlotCompiler):
*SQLGlotCompiler.rewrites,
)

supports_qualify = True

UNSUPPORTED_OPS = (
ops.DateDiff,
ops.ExtractAuthority,
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class ClickHouseCompiler(SQLGlotCompiler):

agg = ClickhouseAggGen()

supports_qualify = True

UNSUPPORTED_OPS = (
ops.RowID,
ops.CumeDist,
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class DuckDBCompiler(SQLGlotCompiler):
*SQLGlotCompiler.rewrites,
)

supports_qualify = True

LOWERED_OPS = {
ops.Sample: None,
ops.StringSlice: None,
Expand Down
7 changes: 5 additions & 2 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,9 @@ def visit_All(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return sge.Min(this=arg)

def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
# if we've constructed a useless projection return the parent relation
if not selections and not predicates and not sort_keys:
if not (selections or predicates or qualified or sort_keys):
return parent

result = parent
Expand All @@ -495,6 +495,9 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
if predicates:
result = result.where(*predicates, copy=True)

if qualified:
result = result.qualify(*qualified, copy=True)

if sort_keys:
result = result.order_by(*sort_keys, copy=False)

Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class SnowflakeCompiler(SQLGlotCompiler):
dialect = Snowflake
type_mapper = SnowflakeType
no_limit_value = NULL
supports_qualify = True

agg = AggGen(supports_order_by=True)

Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/sql/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,13 @@ def _create_sql(self, expression: sge.Create) -> str:
sge.Stddev: rename_func("stddev_pop"),
sge.ApproxDistinct: rename_func("approx_count_distinct"),
sge.Create: _create_sql,
sge.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]),
sge.Select: transforms.preprocess(
[
transforms.eliminate_semi_and_anti_joins,
transforms.eliminate_distinct_on,
transforms.eliminate_qualify,
]
),
sge.GroupConcat: rename_func("listagg"),
}

Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def partition_predicates(predicates):
unqualified = []

for predicate in predicates:
if predicate.find(ops.WindowFunction):
if predicate.find(ops.WindowFunction, filter=ops.Value):
qualified.append(predicate)
else:
unqualified.append(predicate)
Expand Down Expand Up @@ -250,6 +250,9 @@ def merge_select_select(_, **kwargs):
predicates = tuple(p.replace(subs, filter=ops.Value) for p in _.predicates)
unique_predicates = toolz.unique(_.parent.predicates + predicates)

qualified = tuple(p.replace(subs, filter=ops.Value) for p in _.qualified)
unique_qualified = toolz.unique(_.parent.qualified + qualified)

sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys)
sort_key_exprs = {s.expr for s in sort_keys}
parent_sort_keys = tuple(
Expand All @@ -261,6 +264,7 @@ def merge_select_select(_, **kwargs):
_.parent.parent,
selections=selections,
predicates=unique_predicates,
qualified=unique_qualified,
sort_keys=unique_sort_keys,
)
return result if complexity(result) <= complexity(_) else _
Expand Down
32 changes: 11 additions & 21 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,27 +345,17 @@ def test_filter(backend, alltypes, sorted_df, predicate_fn, expected_fn):


@pytest.mark.notimpl(
[
"bigquery",
"clickhouse",
"datafusion",
"duckdb",
"impala",
"mysql",
"postgres",
"risingwave",
"sqlite",
"snowflake",
"polars",
"mssql",
"trino",
"druid",
"oracle",
"exasol",
"pandas",
"pyspark",
"dask",
]
["impala", "polars", "druid", "exasol", "pandas", "pyspark", "dask"]
)
@pytest.mark.notyet(
["oracle"],
raises=OracleDatabaseError,
reason="sqlglot `eliminate_qualify` transform produces underscores in aliases, which is not allowed by oracle",
)
@pytest.mark.never(
["mssql"],
raises=PyODBCProgrammingError,
reason="sqlglot transform produces an order by in a subquery, which is not allowed by mssql",
)
@pytest.mark.never(
["flink"],
Expand Down

0 comments on commit f30e077

Please sign in to comment.