From bec36391d85152fa478222403d06beffa8d6ddfb Mon Sep 17 00:00:00 2001 From: Toby Mao Date: Fri, 30 Dec 2022 09:30:28 -0800 Subject: [PATCH] recursive cte scope fixes #856 (#860) --- sqlglot/expressions.py | 4 +++ sqlglot/optimizer/eliminate_subqueries.py | 4 ++- sqlglot/optimizer/pushdown_predicates.py | 3 ++ sqlglot/optimizer/scope.py | 30 ++++++++++++---- sqlglot/parser.py | 8 +++-- tests/fixtures/optimizer/optimizer.sql | 43 +++++++++++++++++++++++ tests/test_optimizer.py | 26 ++++++++++++++ tests/test_parser.py | 3 ++ 8 files changed, 112 insertions(+), 9 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index b7d46752ca..cc69eb183b 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -686,6 +686,10 @@ class CharacterSet(Expression): class With(Expression): arg_types = {"expressions": True, "recursive": False} + @property + def recursive(self) -> bool: + return bool(self.args.get("recursive")) + class WithinGroup(Expression): arg_types = {"this": True, "expression": False} diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 39e252c790..2245cc2761 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -58,7 +58,9 @@ def eliminate_subqueries(expression): existing_ctes = {} with_ = root.expression.args.get("with") + recursive = False if with_: + recursive = with_.args.get("recursive") for cte in with_.expressions: existing_ctes[cte.this] = cte.alias new_ctes = [] @@ -88,7 +90,7 @@ def eliminate_subqueries(expression): new_ctes.append(new_cte) if new_ctes: - expression.set("with", exp.With(expressions=new_ctes)) + expression.set("with", exp.With(expressions=new_ctes, recursive=recursive)) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index f92e5c3457..a9cd45fc53 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -148,6 +148,9 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): + with_ = source.parent.expression.args.get("with") + if with_ and with_.recursive: + return {} node = source.expression if isinstance(node, exp.Join): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 18848f393d..6125e4e61d 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -511,9 +511,20 @@ def _traverse_union(scope): def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} + is_cte = scope_type == ScopeType.CTE for derived_table in derived_tables: - top = None + recursive_scope = None + + # if the scope is a recursive cte, it must be in the form of + # base_case UNION recursive. thus the recursive scope is the first + # section of the union. + if is_cte and scope.expression.args["with"].recursive: + union = derived_table.this + + if isinstance(union, exp.Union): + recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) + for child_scope in _traverse_scope( scope.branch( derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, @@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): ) ): yield child_scope - top = child_scope + # Tables without aliases will be set as "" # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. - sources[derived_table.alias] = child_scope - if scope_type == ScopeType.CTE: - scope.cte_scopes.append(top) + alias = derived_table.alias + sources[alias] = child_scope + + if recursive_scope: + child_scope.add_source(alias, recursive_scope) + + # append the final child_scope yielded + if is_cte: + scope.cte_scopes.append(child_scope) else: - scope.derived_table_scopes.append(top) + scope.derived_table_scopes.append(child_scope) + scope.sources.update(sources) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 67690118ce..308f36385e 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1176,8 +1176,12 @@ def _parse_select(self, nested=False, table=False): elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) self._parse_query_modifiers(this) + this = self._parse_set_operations(this) self._match_r_paren() - this = self._parse_subquery(this) + # early return so that subquery unions aren't parsed again + # SELECT * FROM (SELECT 1) UNION ALL SELECT 1 + # Union ALL should be a property of the top select node, not the subquery + return self._parse_subquery(this) elif self._match(TokenType.VALUES): if self._curr.token_type == TokenType.L_PAREN: # We don't consume the left paren because it's consumed in _parse_value @@ -1197,7 +1201,7 @@ def _parse_select(self, nested=False, table=False): else: this = None - return self._parse_set_operations(this) if this else None + return self._parse_set_operations(this) def _parse_with(self, skip_with_token=False): if not skip_with_token and not self._match(TokenType.WITH): diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index a692c7dd0d..79c2f1ec35 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -291,3 +291,46 @@ SELECT a1 FROM cte1; SELECT "x"."a" AS "a1" FROM "x" AS "x"; + +# title: recursive cte +WITH RECURSIVE cte1 AS ( + SELECT * + FROM ( + SELECT 1 AS a, 2 AS b + ) base + CROSS JOIN (SELECT 3 c) y + UNION ALL + SELECT * + FROM cte1 + WHERE a < 1 +) +SELECT * +FROM cte1; +WITH RECURSIVE "base" AS ( + SELECT + 1 AS "a", + 2 AS "b" +), "y" AS ( + SELECT + 3 AS "c" +), "cte1" AS ( + SELECT + "base"."a" AS "a", + "base"."b" AS "b", + "y"."c" AS "c" + FROM "base" AS "base" + CROSS JOIN "y" AS "y" + UNION ALL + SELECT + "cte1"."a" AS "a", + "cte1"."b" AS "b", + "cte1"."c" AS "c" + FROM "cte1" + WHERE + "cte1"."a" < 1 +) +SELECT + "cte1"."a" AS "a", + "cte1"."b" AS "b", + "cte1"."c" AS "c" +FROM "cte1"; diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 0c5f6cd241..fb2b289ab6 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -555,3 +555,29 @@ def test_aggfunc_annotation(self): parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema ) self.assertEqual(expression.expressions[0].type.this, target_type) + + def test_recursive_cte(self): + query = parse_one( + """ + with recursive t(n) AS + ( + select 1 + union all + select n + 1 + FROM t + where n < 3 + ), y AS ( + select n + FROM t + union all + select n + 1 + FROM y + where n < 2 + ) + select * from y + """ + ) + + scope_t, scope_y = build_scope(query).cte_scopes + self.assertEqual(set(scope_t.cte_sources), {"t"}) + self.assertEqual(set(scope_y.cte_sources), {"t", "y"}) diff --git a/tests/test_parser.py b/tests/test_parser.py index 0be15e487f..ae2e4cdc20 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -76,6 +76,9 @@ def test_table(self): tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] self.assertEqual(tables, ["a", "b.c", "d"]) + def test_union_order(self): + self.assertIsInstance(parse_one("SELECT * FROM (SELECT 1) UNION SELECT 2"), exp.Union) + def test_select(self): self.assertIsNotNone(parse_one("select 1 natural")) self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])