From acd7d82ec3778b593c4aadfba6f62dd0119c7d28 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 30 Jul 2024 16:05:05 -0400
Subject: [PATCH] fix(internals): ensure that CTEs are emitted in topological order (#9726)

CTEs were previously collected by reversing the breadth-first result of
`find(CTE)`, which does not guarantee that a CTE comes after the CTEs it
depends on.

- `find` in ibis/common/graph.py gains an `ordered` flag; when set, the
  graph is topologically sorted before matching nodes are collected.
- `sqlize` collects CTE parents with `find(CTE, ordered=True)` instead of
  reversing the unordered result.
- `extract_ctes` builds and returns a `set` directly, so `sqlize` no longer
  wraps its result in `frozenset`.
- `test_ctes_in_order` checks that "first" is defined before "second". It
  uses `str.index`, so a missing CTE raises instead of `str.find` returning
  -1 and letting the comparison pass vacuously.

Co-authored-by: Jim Crist-Harif
---
NOTE(review): this patch text was re-flowed from a whitespace-collapsed copy.
Blank lines and indentation were reconstructed from the hunk line counts;
run `git apply --check` against the parent tree before relying on it. The
test_sql.py post-image blob id below predates the find -> index change.

 ibis/backends/sql/rewrites.py       | 10 +++++-----
 ibis/backends/tests/sql/test_sql.py | 14 ++++++++++++++
 ibis/common/graph.py                |  5 +++++
 3 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py
index 584cf748bbcd..511052309305 100644
--- a/ibis/backends/sql/rewrites.py
+++ b/ibis/backends/sql/rewrites.py
@@ -249,17 +249,17 @@ def merge_select_select(_, **kwargs):
     return result if complexity(result) <= complexity(_) else _
 
 
-def extract_ctes(node):
-    result = []
+def extract_ctes(node: ops.Relation) -> set[ops.Relation]:
     cte_types = (Select, ops.Aggregate, ops.JoinChain, ops.Set, ops.Limit, ops.Sample)
     dont_count = (ops.Field, ops.CountStar, ops.CountDistinctStar)
 
     g = Graph.from_bfs(node, filter=~InstanceOf(dont_count))
+    result = set()
     for op, dependents in g.invert().items():
         if isinstance(op, ops.View) or (
             len(dependents) > 1 and isinstance(op, cte_types)
         ):
-            result.append(op)
+            result.add(op)
 
     return result
 
@@ -315,14 +315,14 @@ def sqlize(
         simplified = sqlized
 
     # extract common table expressions while wrapping them in a CTE node
-    ctes = frozenset(extract_ctes(simplified))
+    ctes = extract_ctes(simplified)
 
     def wrap(node, _, **kwargs):
         new = node.__recreate__(kwargs)
         return CTE(new) if node in ctes else new
 
     result = simplified.replace(wrap)
-    ctes = reversed([cte.parent for cte in result.find(CTE)])
+    ctes = [cte.parent for cte in result.find(CTE, ordered=True)]
 
     return result, ctes
 
diff --git a/ibis/backends/tests/sql/test_sql.py b/ibis/backends/tests/sql/test_sql.py
index 991cd02b7f3e..00aac59d5771 100644
--- a/ibis/backends/tests/sql/test_sql.py
+++ b/ibis/backends/tests/sql/test_sql.py
@@ -590,3 +590,17 @@ def test_no_cartesian_join(snapshot):
         ]
     )
     snapshot.assert_match(ibis.to_sql(final, dialect="duckdb"), "out.sql")
+
+
+def test_ctes_in_order():
+    table1 = ibis.table({"id": "int"}, name="table1")
+    table2 = ibis.table({"id": "int"}, name="table2")
+    table3 = ibis.table({"id": "int"}, name="table3")
+
+    ids_table = table1.union(table2).alias("first")
+    info_table = ids_table.union(table3).alias("second")
+
+    expr = ids_table.union(info_table)
+
+    sql = ibis.to_sql(expr, dialect="duckdb")
+    assert sql.index('"first" AS (') < sql.index('"second" AS (')
diff --git a/ibis/common/graph.py b/ibis/common/graph.py
index 184dbaa0490b..3835b7b51c16 100644
--- a/ibis/common/graph.py
+++ b/ibis/common/graph.py
@@ -338,6 +338,7 @@ def find(
         finder: FinderLike,
         filter: Optional[FinderLike] = None,
         context: Optional[dict] = None,
+        ordered: bool = False,
     ) -> list[Node]:
         """Find all nodes matching a given pattern or type in the graph.
 
@@ -355,6 +356,8 @@ def find(
             the given filter and stop otherwise.
         context
             Optional context to use if `finder` or `filter` is a pattern.
+        ordered
+            Emit nodes in topological order if `True`.
 
         Returns
         -------
@@ -364,6 +367,8 @@ def find(
         """
         graph = Graph.from_bfs(self, filter=filter, context=context)
         finder = _coerce_finder(finder, context)
+        if ordered:
+            graph, _ = graph.toposort()
         return [node for node in graph.nodes() if finder(node)]
 
     @experimental