Skip to content

Commit

Permalink
feat(sql): fuse distinct with other select nodes when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Sep 3, 2024
1 parent 7df4bdd commit c31412b
Show file tree
Hide file tree
Showing 26 changed files with 379 additions and 38 deletions.
20 changes: 12 additions & 8 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ class SQLGlotCompiler(abc.ABC):
one_to_zero_index,
add_one_to_nth_value_input,
)
"""A sequence of rewrites to apply to the expression tree before compilation."""
"""A sequence of rewrites to apply to the expression tree before SQL-specific transforms."""

post_rewrites: tuple[type[pats.Replace], ...] = ()
"""A sequence of rewrites to apply to the expression tree after SQL-specific transforms."""

no_limit_value: sge.Null | None = None
"""The value to use to indicate no limit."""
Expand Down Expand Up @@ -606,6 +609,7 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression:
op,
params=params,
rewrites=self.rewrites,
post_rewrites=self.post_rewrites,
fuse_selects=options.sql.fuse_selects,
)

Expand Down Expand Up @@ -1257,9 +1261,11 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):
else:
yield value.as_(name, quoted=self.quoted, copy=False)

def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
def visit_Select(
self, op, *, parent, selections, predicates, qualified, sort_keys, distinct
):
# if we've constructed a useless projection return the parent relation
if not (selections or predicates or qualified or sort_keys):
if not (selections or predicates or qualified or sort_keys or distinct):
return parent

result = parent
Expand All @@ -1286,6 +1292,9 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke
if sort_keys:
result = result.order_by(*sort_keys, copy=False)

if distinct:
result = result.distinct()

return result

def visit_DummyTable(self, op, *, values):
Expand Down Expand Up @@ -1470,11 +1479,6 @@ def visit_Limit(self, op, *, parent, n, offset):
return result.subquery(alias, copy=False)
return result

def visit_Distinct(self, op, *, parent):
return (
sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False)
)

def visit_CTE(self, op, *, parent):
return sg.table(parent.alias_or_name, quoted=self.quoted)

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
split_select_distinct_with_order_by,
)
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit

Expand Down Expand Up @@ -113,6 +114,7 @@ class BigQueryCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_rank,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)

supports_qualify = True

Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.backends.sql.rewrites import split_select_distinct_with_order_by
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType

Expand All @@ -26,6 +27,8 @@ class DataFusionCompiler(SQLGlotCompiler):

agg = AggGen(supports_filter=True, supports_order_by=True)

post_rewrites = (split_select_distinct_with_order_by,)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
Expand Down
11 changes: 9 additions & 2 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
exclude_unsupported_window_frame_from_row_number,
p,
replace,
split_select_distinct_with_order_by,
)
from ibis.common.deferred import var

Expand Down Expand Up @@ -69,6 +70,7 @@ class MSSQLCompiler(SQLGlotCompiler):
rewrite_rows_range_order_by_window,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)
copy_func_args = True

UNSUPPORTED_OPS = (
Expand Down Expand Up @@ -479,9 +481,11 @@ def visit_All(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return sge.Min(this=arg)

def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
def visit_Select(
self, op, *, parent, selections, predicates, qualified, sort_keys, distinct
):
# if we've constructed a useless projection return the parent relation
if not (selections or predicates or qualified or sort_keys):
if not (selections or predicates or qualified or sort_keys or distinct):
return parent

result = parent
Expand All @@ -500,6 +504,9 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke
if sort_keys:
result = result.order_by(*sort_keys, copy=False)

if distinct:
result = result.distinct()

return result

def visit_TimestampAdd(self, op, *, left, right):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import split_select_distinct_with_order_by
from ibis.common.exceptions import InvalidDecoratorError
from ibis.util import gen_name

Expand All @@ -41,6 +42,7 @@ class PostgresCompiler(SQLGlotCompiler):

dialect = Postgres
type_mapper = PostgresType
post_rewrites = (split_select_distinct_with_order_by,)

agg = AggGen(supports_filter=True, supports_order_by=True)

Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.datatypes import PySparkType
from ibis.backends.sql.dialects import PySpark
from ibis.backends.sql.rewrites import FirstValue, LastValue, p
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
p,
split_select_distinct_with_order_by,
)
from ibis.common.patterns import replace
from ibis.config import options
from ibis.expr.operations.udf import InputType
Expand Down Expand Up @@ -51,6 +56,7 @@ class PySparkCompiler(SQLGlotCompiler):
dialect = PySpark
type_mapper = PySparkType
rewrites = (offset_to_filter, *SQLGlotCompiler.rewrites)
post_rewrites = (split_select_distinct_with_order_by,)

UNSUPPORTED_OPS = (
ops.RowID,
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
split_select_distinct_with_order_by,
)
from ibis.util import gen_name

Expand All @@ -39,6 +40,7 @@ class TrinoCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_ops,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)
quoted = True

NAN = sg.func("nan")
Expand Down
99 changes: 98 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Select(ops.Relation):
predicates: VarTuple[ops.Value[dt.Boolean]] = ()
qualified: VarTuple[ops.Value[dt.Boolean]] = ()
sort_keys: VarTuple[ops.SortKey] = ()
distinct: bool = False

def is_star_selection(self):
return tuple(self.values.items()) == tuple(self.parent.fields.items())
Expand Down Expand Up @@ -128,6 +129,12 @@ def sort_to_select(_, **kwargs):
return Select(_.parent, selections=_.values, sort_keys=_.keys)


@replace(p.Distinct)
def distinct_to_select(_, **kwargs):
"""Convert a Distinct node to a Select node."""
return Select(_.parent, selections=_.values, distinct=True)


@replace(p.DropColumns)
def drop_columns_to_select(_, **kwargs):
"""Convert a DropColumns node to a Select node."""
Expand Down Expand Up @@ -244,6 +251,48 @@ def merge_select_select(_, **kwargs):
if _.parent.find_below(blocking, filter=ops.Value):
return _

if _.parent.distinct:
# The inner query is distinct.
#
# If the outer query is distinct, it's only safe to merge if it's a simple subselection:
# - Fusing in the presence of non-deterministic calls in the select would lead to
# incorrect results
# - Fusing in the presence of expensive calls in the select would lead to potential
# performance pitfalls
if _.distinct and not all(
isinstance(v, ops.Field) for v in _.selections.values()
):
return _

# If the outer query isn't distinct, it's only safe to merge if the outer is a SELECT *:
# - If new columns are added, they might be non-distinct, changing the distinctness
# - If previous columns are removed, that would also change the distinctness
if not _.distinct and not _.is_star_selection():
return _

distinct = True
elif _.distinct:
# The outer query is distinct and the inner isn't. It's only safe to merge if either
# - The inner query isn't ordered
# - The outer query is a SELECT *
#
# Otherwise we run the risk that the outer query drops columns needed for the ordering of
# the inner query - many backends don't allow select distinc queries to order by columns
# that aren't present in their selection, like
#
# SELECT DISTINCT a, b FROM t ORDER BY c --- some backends will explode at this
#
# An alternate solution would be to drop the inner ORDER BY clause, since the backend will
# ignore it anyway since it's a subquery. That feels potentially risky though, better
# to generate the SQL as written.
if _.parent.sort_keys and not _.is_star_selection():
return _

distinct = True
else:
# Neither query is distinct, safe to merge
distinct = False

subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()}
selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()}

Expand All @@ -266,6 +315,7 @@ def merge_select_select(_, **kwargs):
predicates=unique_predicates,
qualified=unique_qualified,
sort_keys=unique_sort_keys,
distinct=distinct,
)
return result if complexity(result) <= complexity(_) else _

Expand All @@ -289,6 +339,7 @@ def sqlize(
node: ops.Node,
params: Mapping[ops.ScalarParameter, Any],
rewrites: Sequence[Pattern] = (),
post_rewrites: Sequence[Pattern] = (),
fuse_selects: bool = True,
) -> tuple[ops.Node, list[ops.Node]]:
"""Lower the ibis expression graph to a SQL-like relational algebra.
Expand All @@ -300,7 +351,9 @@ def sqlize(
params
A mapping of scalar parameters to their values.
rewrites
Supplementary rewrites to apply to the expression graph.
Supplementary rewrites to apply before SQL-specific transforms.
post_rewrites
Supplementary rewrites to apply after SQL-specific transforms.
fuse_selects
Whether to merge subsequent Select nodes into one where possible.
Expand All @@ -322,6 +375,7 @@ def sqlize(
| project_to_select
| filter_to_select
| sort_to_select
| distinct_to_select
| fill_null_to_select
| drop_null_to_select
| drop_columns_to_select
Expand All @@ -335,6 +389,9 @@ def sqlize(
else:
simplified = sqlized

if post_rewrites:
simplified = simplified.replace(reduce(operator.or_, post_rewrites))

# extract common table expressions while wrapping them in a CTE node
ctes = extract_ctes(simplified)

Expand All @@ -351,6 +408,46 @@ def wrap(node, _, **kwargs):
# supplemental rewrites selectively used on a per-backend basis


@replace(Select)
def split_select_distinct_with_order_by(_):
"""Split a `SELECT DISTINCT ... ORDER BY` query when needed.
Some databases (postgres, pyspark, ...) have issues with two types of
ordered select distinct statements:
```
--- ORDER BY with an expression instead of a name in the select list
SELECT DISTINCT a, b FROM t ORDER BY a + 1
--- ORDER BY using a qualified column name, rather than the alias in the select list
SELECT DISTINCT a, b as x FROM t ORDER BY b --- or t.b
```
We solve both these cases by splitting everything except the `ORDER BY`
into a subquery.
```
SELECT DISTINCT a, b FROM t WHERE a > 10 ORDER BY a + 1
--- is rewritten as ->
SELECT * FROM (SELECT DISTINCT a, b FROM t WHERE a > 10) ORDER BY a + 1
```
"""
# risingwave and pyspark also don't allow qualified names as sort keys, like
# SELECT DISTINCT t.a FROM t ORDER BY t.a
# To avoid having specific rewrite rules for these backends to use only
# local names, we always split SELECT DISTINCT from ORDER BY here. Otherwise we
# could also avoid splitting if all sort keys appear in the select list.
if _.distinct and _.sort_keys:
inner = _.copy(sort_keys=())
subs = {v: ops.Field(inner, k) for k, v in inner.values.items()}
sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys)
selections = {
k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()
}
return Select(inner, selections=selections, sort_keys=sort_keys)
return _


@replace(p.WindowFunction(func=p.NTile(y), order_by=()))
def add_order_by_to_empty_ranking_window_functions(_, **kwargs):
"""Add an ORDER BY clause to rank window functions that don't have one."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
SELECT DISTINCT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10 AND "t0"."a" > 10
ORDER BY
"t0"."a" ASC
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT DISTINCT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10 AND "t0"."a" > 10
Loading

0 comments on commit c31412b

Please sign in to comment.