From c31412b16113eef868af4fe47825fbefdde5c1de Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 26 Aug 2024 10:35:10 -0500 Subject: [PATCH] feat(sql): fuse `distinct` with other select nodes when possible --- ibis/backends/sql/compilers/base.py | 20 ++-- .../sql/compilers/bigquery/__init__.py | 2 + ibis/backends/sql/compilers/datafusion.py | 3 + ibis/backends/sql/compilers/mssql.py | 11 ++- ibis/backends/sql/compilers/postgres.py | 2 + ibis/backends/sql/compilers/pyspark.py | 8 +- ibis/backends/sql/compilers/trino.py | 2 + ibis/backends/sql/rewrites.py | 99 ++++++++++++++++++- .../test_column_distinct/out.sql | 8 +- .../test_compiler/test_table_distinct/out.sql | 10 +- .../distinct-filter-order_by/out.sql | 9 ++ .../distinct-filter/out.sql | 7 ++ .../out.sql | 12 +++ .../distinct-non-trivial-select/out.sql | 12 +++ .../distinct-select-distinct/out.sql | 12 +++ .../distinct-select/out.sql | 12 +++ .../test_fuse_distinct/distinct/out.sql | 7 ++ .../non-trivial-select-distinct/out.sql | 6 ++ .../order_by-distinct-drop/out.sql | 14 +++ .../order_by-distinct/out.sql | 9 ++ .../order_by-drop-distinct/out.sql | 8 ++ .../select-distinct/out.sql | 6 ++ .../test_distinct/projection_distinct/out.sql | 10 +- .../single_column_projection_distinct/out.sql | 8 +- ibis/backends/tests/sql/test_select_sql.py | 61 ++++++++++++ ibis/backends/tests/test_generic.py | 59 +++++++++++ 26 files changed, 379 insertions(+), 38 deletions(-) create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-filter-order_by/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-filter/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-non-trivial-select-distinct/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-non-trivial-select/out.sql create mode 100644 
ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-select-distinct/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-select/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/non-trivial-select-distinct/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-distinct-drop/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-distinct/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-drop-distinct/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/select-distinct/out.sql diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 2558a0e65766..687b44b270f6 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -266,7 +266,10 @@ class SQLGlotCompiler(abc.ABC): one_to_zero_index, add_one_to_nth_value_input, ) - """A sequence of rewrites to apply to the expression tree before compilation.""" + """A sequence of rewrites to apply to the expression tree before SQL-specific transforms.""" + + post_rewrites: tuple[type[pats.Replace], ...] 
= () + """A sequence of rewrites to apply to the expression tree after SQL-specific transforms.""" no_limit_value: sge.Null | None = None """The value to use to indicate no limit.""" @@ -606,6 +609,7 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression: op, params=params, rewrites=self.rewrites, + post_rewrites=self.post_rewrites, fuse_selects=options.sql.fuse_selects, ) @@ -1257,9 +1261,11 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]): else: yield value.as_(name, quoted=self.quoted, copy=False) - def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys): + def visit_Select( + self, op, *, parent, selections, predicates, qualified, sort_keys, distinct + ): # if we've constructed a useless projection return the parent relation - if not (selections or predicates or qualified or sort_keys): + if not (selections or predicates or qualified or sort_keys or distinct): return parent result = parent @@ -1286,6 +1292,9 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke if sort_keys: result = result.order_by(*sort_keys, copy=False) + if distinct: + result = result.distinct() + return result def visit_DummyTable(self, op, *, values): @@ -1470,11 +1479,6 @@ def visit_Limit(self, op, *, parent, n, offset): return result.subquery(alias, copy=False) return result - def visit_Distinct(self, op, *, parent): - return ( - sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False) - ) - def visit_CTE(self, op, *, parent): return sg.table(parent.alias_or_name, quoted=self.quoted) diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index e803734d1f59..fcb3f3c0d873 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -22,6 +22,7 @@ exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_rank, 
exclude_unsupported_window_frame_from_row_number, + split_select_distinct_with_order_by, ) from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit @@ -113,6 +114,7 @@ class BigQueryCompiler(SQLGlotCompiler): exclude_unsupported_window_frame_from_rank, *SQLGlotCompiler.rewrites, ) + post_rewrites = (split_select_distinct_with_order_by,) supports_qualify = True diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index c095753a9777..155bf45d6584 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -14,6 +14,7 @@ from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DataFusionType from ibis.backends.sql.dialects import DataFusion +from ibis.backends.sql.rewrites import split_select_distinct_with_order_by from ibis.common.temporal import IntervalUnit, TimestampUnit from ibis.expr.operations.udf import InputType @@ -26,6 +27,8 @@ class DataFusionCompiler(SQLGlotCompiler): agg = AggGen(supports_filter=True, supports_order_by=True) + post_rewrites = (split_select_distinct_with_order_by,) + UNSUPPORTED_OPS = ( ops.ArgMax, ops.ArgMin, diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index 02cb368ce19e..5316a43c2866 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -24,6 +24,7 @@ exclude_unsupported_window_frame_from_row_number, p, replace, + split_select_distinct_with_order_by, ) from ibis.common.deferred import var @@ -69,6 +70,7 @@ class MSSQLCompiler(SQLGlotCompiler): rewrite_rows_range_order_by_window, *SQLGlotCompiler.rewrites, ) + post_rewrites = (split_select_distinct_with_order_by,) copy_func_args = True UNSUPPORTED_OPS = ( @@ -479,9 +481,11 @@ def visit_All(self, op, *, arg, where): arg = self.if_(where, arg, NULL) return sge.Min(this=arg) - def visit_Select(self, op, *, parent, 
selections, predicates, qualified, sort_keys): + def visit_Select( + self, op, *, parent, selections, predicates, qualified, sort_keys, distinct + ): # if we've constructed a useless projection return the parent relation - if not (selections or predicates or qualified or sort_keys): + if not (selections or predicates or qualified or sort_keys or distinct): return parent result = parent @@ -500,6 +504,9 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke if sort_keys: result = result.order_by(*sort_keys, copy=False) + if distinct: + result = result.distinct() + return result def visit_TimestampAdd(self, op, *, left, right): diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index 250abeba18dd..643b834d5fc7 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -17,6 +17,7 @@ from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import PostgresType from ibis.backends.sql.dialects import Postgres +from ibis.backends.sql.rewrites import split_select_distinct_with_order_by from ibis.common.exceptions import InvalidDecoratorError from ibis.util import gen_name @@ -41,6 +42,7 @@ class PostgresCompiler(SQLGlotCompiler): dialect = Postgres type_mapper = PostgresType + post_rewrites = (split_select_distinct_with_order_by,) agg = AggGen(supports_filter=True, supports_order_by=True) diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index 8db4d1542632..1555b8bc9502 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -15,7 +15,12 @@ from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, SQLGlotCompiler from ibis.backends.sql.datatypes import PySparkType from ibis.backends.sql.dialects import PySpark -from ibis.backends.sql.rewrites import FirstValue, LastValue, p +from ibis.backends.sql.rewrites 
import ( + FirstValue, + LastValue, + p, + split_select_distinct_with_order_by, +) from ibis.common.patterns import replace from ibis.config import options from ibis.expr.operations.udf import InputType @@ -51,6 +56,7 @@ class PySparkCompiler(SQLGlotCompiler): dialect = PySpark type_mapper = PySparkType rewrites = (offset_to_filter, *SQLGlotCompiler.rewrites) + post_rewrites = (split_select_distinct_with_order_by,) UNSUPPORTED_OPS = ( ops.RowID, diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index db3199830d48..eb2454d12279 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -23,6 +23,7 @@ from ibis.backends.sql.dialects import Trino from ibis.backends.sql.rewrites import ( exclude_unsupported_window_frame_from_ops, + split_select_distinct_with_order_by, ) from ibis.util import gen_name @@ -39,6 +40,7 @@ class TrinoCompiler(SQLGlotCompiler): exclude_unsupported_window_frame_from_ops, *SQLGlotCompiler.rewrites, ) + post_rewrites = (split_select_distinct_with_order_by,) quoted = True NAN = sg.func("nan") diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py index 067c9bf70b91..52f8ceba19df 100644 --- a/ibis/backends/sql/rewrites.py +++ b/ibis/backends/sql/rewrites.py @@ -53,6 +53,7 @@ class Select(ops.Relation): predicates: VarTuple[ops.Value[dt.Boolean]] = () qualified: VarTuple[ops.Value[dt.Boolean]] = () sort_keys: VarTuple[ops.SortKey] = () + distinct: bool = False def is_star_selection(self): return tuple(self.values.items()) == tuple(self.parent.fields.items()) @@ -128,6 +129,12 @@ def sort_to_select(_, **kwargs): return Select(_.parent, selections=_.values, sort_keys=_.keys) +@replace(p.Distinct) +def distinct_to_select(_, **kwargs): + """Convert a Distinct node to a Select node.""" + return Select(_.parent, selections=_.values, distinct=True) + + @replace(p.DropColumns) def drop_columns_to_select(_, **kwargs): """Convert a DropColumns node to a Select 
node.""" @@ -244,6 +251,48 @@ def merge_select_select(_, **kwargs): if _.parent.find_below(blocking, filter=ops.Value): return _ + if _.parent.distinct: + # The inner query is distinct. + # + # If the outer query is distinct, it's only safe to merge if it's a simple subselection: + # - Fusing in the presence of non-deterministic calls in the select would lead to + # incorrect results + # - Fusing in the presence of expensive calls in the select would lead to potential + # performance pitfalls + if _.distinct and not all( + isinstance(v, ops.Field) for v in _.selections.values() + ): + return _ + + # If the outer query isn't distinct, it's only safe to merge if the outer is a SELECT *: + # - If new columns are added, they might be non-distinct, changing the distinctness + # - If previous columns are removed, that would also change the distinctness + if not _.distinct and not _.is_star_selection(): + return _ + + distinct = True + elif _.distinct: + # The outer query is distinct and the inner isn't. It's only safe to merge if either + # - The inner query isn't ordered + # - The outer query is a SELECT * + # + # Otherwise we run the risk that the outer query drops columns needed for the ordering of + # the inner query - many backends don't allow select distinct queries to order by columns + # that aren't present in their selection, like + # + # SELECT DISTINCT a, b FROM t ORDER BY c --- some backends will explode at this + # + # An alternate solution would be to drop the inner ORDER BY clause, since the backend will + # ignore it anyway since it's a subquery. That feels potentially risky though, better + # to generate the SQL as written. 
+ if _.parent.sort_keys and not _.is_star_selection(): + return _ + + distinct = True + else: + # Neither query is distinct, safe to merge + distinct = False + subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()} selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()} @@ -266,6 +315,7 @@ def merge_select_select(_, **kwargs): predicates=unique_predicates, qualified=unique_qualified, sort_keys=unique_sort_keys, + distinct=distinct, ) return result if complexity(result) <= complexity(_) else _ @@ -289,6 +339,7 @@ def sqlize( node: ops.Node, params: Mapping[ops.ScalarParameter, Any], rewrites: Sequence[Pattern] = (), + post_rewrites: Sequence[Pattern] = (), fuse_selects: bool = True, ) -> tuple[ops.Node, list[ops.Node]]: """Lower the ibis expression graph to a SQL-like relational algebra. @@ -300,7 +351,9 @@ def sqlize( params A mapping of scalar parameters to their values. rewrites - Supplementary rewrites to apply to the expression graph. + Supplementary rewrites to apply before SQL-specific transforms. + post_rewrites + Supplementary rewrites to apply after SQL-specific transforms. fuse_selects Whether to merge subsequent Select nodes into one where possible. @@ -322,6 +375,7 @@ def sqlize( | project_to_select | filter_to_select | sort_to_select + | distinct_to_select | fill_null_to_select | drop_null_to_select | drop_columns_to_select @@ -335,6 +389,9 @@ def sqlize( else: simplified = sqlized + if post_rewrites: + simplified = simplified.replace(reduce(operator.or_, post_rewrites)) + # extract common table expressions while wrapping them in a CTE node ctes = extract_ctes(simplified) @@ -351,6 +408,46 @@ def wrap(node, _, **kwargs): # supplemental rewrites selectively used on a per-backend basis +@replace(Select) +def split_select_distinct_with_order_by(_): + """Split a `SELECT DISTINCT ... ORDER BY` query when needed. + + Some databases (postgres, pyspark, ...) 
have issues with two types of + ordered select distinct statements: + + ``` + --- ORDER BY with an expression instead of a name in the select list + SELECT DISTINCT a, b FROM t ORDER BY a + 1 + + --- ORDER BY using a qualified column name, rather than the alias in the select list + SELECT DISTINCT a, b as x FROM t ORDER BY b --- or t.b + ``` + + We solve both these cases by splitting everything except the `ORDER BY` + into a subquery. + + ``` + SELECT DISTINCT a, b FROM t WHERE a > 10 ORDER BY a + 1 + --- is rewritten as -> + SELECT * FROM (SELECT DISTINCT a, b FROM t WHERE a > 10) ORDER BY a + 1 + ``` + """ + # risingwave and pyspark also don't allow qualified names as sort keys, like + # SELECT DISTINCT t.a FROM t ORDER BY t.a + # To avoid having specific rewrite rules for these backends to use only + # local names, we always split SELECT DISTINCT from ORDER BY here. Otherwise we + # could also avoid splitting if all sort keys appear in the select list. + if _.distinct and _.sort_keys: + inner = _.copy(sort_keys=()) + subs = {v: ops.Field(inner, k) for k, v in inner.values.items()} + sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys) + selections = { + k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items() + } + return Select(inner, selections=selections, sort_keys=sort_keys) + return _ + + @replace(p.WindowFunction(func=p.NTile(y), order_by=())) def add_order_by_to_empty_ranking_window_functions(_, **kwargs): """Add an ORDER BY clause to rank window functions that don't have one.""" diff --git a/ibis/backends/tests/sql/snapshots/test_compiler/test_column_distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_compiler/test_column_distinct/out.sql index 0913c0727447..8b092e4de69a 100644 --- a/ibis/backends/tests/sql/snapshots/test_compiler/test_column_distinct/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_compiler/test_column_distinct/out.sql @@ -1,7 +1,3 @@ SELECT DISTINCT - * -FROM ( - SELECT - "t0"."string_col" - 
FROM "functional_alltypes" AS "t0" -) AS "t1" \ No newline at end of file + "t0"."string_col" +FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_compiler/test_table_distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_compiler/test_table_distinct/out.sql index 098405dd6f82..2f0922df5d69 100644 --- a/ibis/backends/tests/sql/snapshots/test_compiler/test_table_distinct/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_compiler/test_table_distinct/out.sql @@ -1,8 +1,4 @@ SELECT DISTINCT - * -FROM ( - SELECT - "t0"."string_col", - "t0"."int_col" - FROM "functional_alltypes" AS "t0" -) AS "t1" \ No newline at end of file + "t0"."string_col", + "t0"."int_col" +FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-filter-order_by/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-filter-order_by/out.sql new file mode 100644 index 000000000000..542df2c4792e --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-filter-order_by/out.sql @@ -0,0 +1,9 @@ +SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" +FROM "test" AS "t0" +WHERE + "t0"."c" > 10 AND "t0"."a" > 10 +ORDER BY + "t0"."a" ASC \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-filter/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-filter/out.sql new file mode 100644 index 000000000000..c5b48291e080 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-filter/out.sql @@ -0,0 +1,7 @@ +SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" +FROM "test" AS "t0" +WHERE + "t0"."c" > 10 AND "t0"."a" > 10 \ No newline at end of file diff --git 
a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-non-trivial-select-distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-non-trivial-select-distinct/out.sql new file mode 100644 index 000000000000..c31db60d3e08 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-non-trivial-select-distinct/out.sql @@ -0,0 +1,12 @@ +SELECT DISTINCT + "t1"."a", + "t1"."b" % 2 AS "d" +FROM ( + SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" + FROM "test" AS "t0" + WHERE + "t0"."c" > 10 +) AS "t1" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-non-trivial-select/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-non-trivial-select/out.sql new file mode 100644 index 000000000000..a5cf0fb0b0ac --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-non-trivial-select/out.sql @@ -0,0 +1,12 @@ +SELECT + "t1"."a", + "t1"."b" % 2 AS "d" +FROM ( + SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" + FROM "test" AS "t0" + WHERE + "t0"."c" > 10 +) AS "t1" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-select-distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-select-distinct/out.sql new file mode 100644 index 000000000000..5cd0def1e534 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-select-distinct/out.sql @@ -0,0 +1,12 @@ +SELECT DISTINCT + "t1"."a", + "t1"."b" +FROM ( + SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" + FROM "test" AS "t0" + WHERE + "t0"."c" > 10 +) AS "t1" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-select/out.sql 
b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-select/out.sql new file mode 100644 index 000000000000..045d99bf0951 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct-select/out.sql @@ -0,0 +1,12 @@ +SELECT + "t1"."a", + "t1"."b" +FROM ( + SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" + FROM "test" AS "t0" + WHERE + "t0"."c" > 10 +) AS "t1" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct/out.sql new file mode 100644 index 000000000000..54cac377cdd1 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/distinct/out.sql @@ -0,0 +1,7 @@ +SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" +FROM "test" AS "t0" +WHERE + "t0"."c" > 10 \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/non-trivial-select-distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/non-trivial-select-distinct/out.sql new file mode 100644 index 000000000000..2de65e2da77f --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/non-trivial-select-distinct/out.sql @@ -0,0 +1,6 @@ +SELECT DISTINCT + "t0"."a", + "t0"."b" % 2 AS "d" +FROM "test" AS "t0" +WHERE + "t0"."c" > 10 \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-distinct-drop/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-distinct-drop/out.sql new file mode 100644 index 000000000000..fabbc6d9bcb9 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-distinct-drop/out.sql @@ -0,0 +1,14 @@ +SELECT + "t1"."b", + "t1"."c" +FROM ( + SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" + FROM "test" 
AS "t0" + WHERE + "t0"."c" > 10 + ORDER BY + "t0"."a" ASC +) AS "t1" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-distinct/out.sql new file mode 100644 index 000000000000..e3ed76524e1b --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-distinct/out.sql @@ -0,0 +1,9 @@ +SELECT DISTINCT + "t0"."a", + "t0"."b", + "t0"."c" +FROM "test" AS "t0" +WHERE + "t0"."c" > 10 +ORDER BY + "t0"."a" ASC \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-drop-distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-drop-distinct/out.sql new file mode 100644 index 000000000000..fbb50382986c --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/order_by-drop-distinct/out.sql @@ -0,0 +1,8 @@ +SELECT DISTINCT + "t0"."b", + "t0"."c" +FROM "test" AS "t0" +WHERE + "t0"."c" > 10 +ORDER BY + "t0"."a" ASC \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/select-distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/select-distinct/out.sql new file mode 100644 index 000000000000..e47d47c2de40 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_fuse_distinct/select-distinct/out.sql @@ -0,0 +1,6 @@ +SELECT DISTINCT + "t0"."a", + "t0"."b" +FROM "test" AS "t0" +WHERE + "t0"."c" > 10 \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_distinct/projection_distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_distinct/projection_distinct/out.sql index 098405dd6f82..2f0922df5d69 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_distinct/projection_distinct/out.sql +++ 
b/ibis/backends/tests/sql/snapshots/test_sql/test_distinct/projection_distinct/out.sql @@ -1,8 +1,4 @@ SELECT DISTINCT - * -FROM ( - SELECT - "t0"."string_col", - "t0"."int_col" - FROM "functional_alltypes" AS "t0" -) AS "t1" \ No newline at end of file + "t0"."string_col", + "t0"."int_col" +FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_distinct/single_column_projection_distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_distinct/single_column_projection_distinct/out.sql index 0913c0727447..8b092e4de69a 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_distinct/single_column_projection_distinct/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_distinct/single_column_projection_distinct/out.sql @@ -1,7 +1,3 @@ SELECT DISTINCT - * -FROM ( - SELECT - "t0"."string_col" - FROM "functional_alltypes" AS "t0" -) AS "t1" \ No newline at end of file + "t0"."string_col" +FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/sql/test_select_sql.py b/ibis/backends/tests/sql/test_select_sql.py index 94a52017f763..d41eaca70bf3 100644 --- a/ibis/backends/tests/sql/test_select_sql.py +++ b/ibis/backends/tests/sql/test_select_sql.py @@ -254,6 +254,67 @@ def test_projection_filter_fuse(projection_fuse_filter, snapshot): snapshot.assert_match(to_sql(expr3), "out.sql") +@pytest.mark.parametrize( + "transform", + [ + # Fused + param( + lambda t: t.distinct(), + id="distinct", + ), + param( + lambda t: t.select("a", "b").distinct(), + id="select-distinct", + ), + param( + lambda t: t.distinct().select("a", "b").distinct(), + id="distinct-select-distinct", + ), + param( + lambda t: t.distinct().filter(_.a > 10), + id="distinct-filter", + ), + param( + lambda t: t.distinct().filter(_.a > 10).order_by("a"), + id="distinct-filter-order_by", + ), + param( + lambda t: t.order_by("a").distinct(), + id="order_by-distinct", + ), + param( + 
lambda t: t.select("a", d=(_.b % 2)).distinct(), + id="non-trivial-select-distinct", + ), + # Not Fused + param( + lambda t: t.distinct().select("a", "b"), + id="distinct-select", + ), + param( + lambda t: t.distinct().select("a", d=(_.b % 2)), + id="distinct-non-trivial-select", + ), + param( + lambda t: t.distinct().select("a", d=(_.b % 2)).distinct(), + id="distinct-non-trivial-select-distinct", + ), + param( + lambda t: t.order_by("a").drop("a").distinct(), + id="order_by-drop-distinct", + ), + param( + lambda t: t.order_by("a").distinct().drop("a"), + id="order_by-distinct-drop", + ), + ], +) +def test_fuse_distinct(snapshot, transform): + t = ibis.table({"a": "int", "b": "int", "c": "int", "d": "int"}, name="test") + expr = transform(t.select("a", "b", "c").filter(t.c > 10)) + snapshot.assert_match(to_sql(expr), "out.sql") + + def test_bug_project_multiple_times(customer, nation, region, snapshot): # GH: 108 joined = customer.inner_join( diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index c0fefd8ae80c..8dbcefaacdc0 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -4,6 +4,7 @@ import datetime import decimal from collections import Counter +from itertools import permutations from operator import invert, methodcaller, neg import pytest @@ -1395,6 +1396,64 @@ def test_pivot_wider(backend): assert len(df) == diamonds.color.nunique().execute() +def test_select_distinct_order_by(backend, alltypes, df): + res = alltypes.select("int_col").distinct().order_by("int_col").to_pandas() + sol = df[["int_col"]].drop_duplicates().sort_values("int_col") + backend.assert_frame_equal(res, sol) + + +def test_select_distinct_order_by_alias(backend, con): + df = pd.DataFrame({"x": [1, 2, 3, 3], "y": [10, 9, 8, 8]}) + expr = ibis.memtable(df).select(y="x", x="y").distinct().order_by("x", "y") + sol = ( + df.drop_duplicates() + .rename(columns={"x": "y", "y": "x"}) + .sort_values(["x", "y"]) + ) + 
res = con.to_pandas(expr) + backend.assert_frame_equal(res, sol) + + +def test_select_distinct_order_by_expr(backend, alltypes, df): + res = alltypes.select("int_col").distinct().order_by(-_.int_col).to_pandas() + sol = df[["int_col"]].drop_duplicates().sort_values("int_col", ascending=False) + backend.assert_frame_equal(res, sol) + + +@pytest.mark.notimpl( + ["polars", "pandas", "dask"], + reason="We don't fuse these ops yet for non-SQL backends", + strict=False, +) +@pytest.mark.parametrize( + "ops", + [ + param(ops, id="-".join(ops)) + for ops in permutations(("select", "distinct", "filter", "order_by")) + if ops.index("select") < ops.index("distinct") + ], +) +def test_select_distinct_filter_order_by_commute(backend, alltypes, df, ops): + """For simple versions of these ops, the order in which they're called + doesn't matter, they're all handled in a commutative way.""" + expr = alltypes.select("int_col", "float_col", b=alltypes.id % 33) + for op in ops: + if op == "select": + expr = expr.select("int_col", "b") + elif op == "distinct": + expr = expr.distinct() + elif op == "filter": + expr = expr.filter(expr.int_col > 5) + elif op == "order_by": + expr = expr.order_by(-expr.int_col, expr.b) + + sol = df.assign(b=df.id % 33)[["int_col", "b"]] + sol = sol[sol.int_col > 5].drop_duplicates() + sol = sol.set_index([-sol.int_col, sol.b]).sort_index().reset_index(drop=True) + res = expr.to_pandas() + backend.assert_frame_equal(res, sol) + + @pytest.mark.parametrize( "on", [