Skip to content

Commit

Permalink
recursive cte scope fixes #856 (#860)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao authored Dec 30, 2022
1 parent a3503fb commit bec3639
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 9 deletions.
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,10 @@ class CharacterSet(Expression):
class With(Expression):
    """Expression node for a WITH clause.

    Carries the clause's ``expressions`` and an optional ``recursive`` arg.
    """

    arg_types = {"expressions": True, "recursive": False}

    @property
    def recursive(self) -> bool:
        """Return True when the ``recursive`` arg is present and truthy."""
        flag = self.args.get("recursive")
        return bool(flag)


class WithinGroup(Expression):
arg_types = {"this": True, "expression": False}
Expand Down
4 changes: 3 additions & 1 deletion sqlglot/optimizer/eliminate_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
existing_ctes = {}

with_ = root.expression.args.get("with")
recursive = False
if with_:
recursive = with_.args.get("recursive")
for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias
new_ctes = []
Expand Down Expand Up @@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
new_ctes.append(new_cte)

if new_ctes:
expression.set("with", exp.With(expressions=new_ctes))
expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))

return expression

Expand Down
3 changes: 3 additions & 0 deletions sqlglot/optimizer/pushdown_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):

# a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
with_ = source.parent.expression.args.get("with")
if with_ and with_.recursive:
return {}
node = source.expression

if isinstance(node, exp.Join):
Expand Down
30 changes: 24 additions & 6 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,20 @@ def _traverse_union(scope):

def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
is_cte = scope_type == ScopeType.CTE

for derived_table in derived_tables:
top = None
recursive_scope = None

# If the scope is a recursive CTE, it must be in the form of
# base_case UNION recursive. Thus the recursive scope (the source that
# the CTE's own alias resolves to) is the first section of the union.
if is_cte and scope.expression.args["with"].recursive:
union = derived_table.this

if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
Expand All @@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
top = child_scope

# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
sources[derived_table.alias] = child_scope
if scope_type == ScopeType.CTE:
scope.cte_scopes.append(top)
alias = derived_table.alias
sources[alias] = child_scope

if recursive_scope:
child_scope.add_source(alias, recursive_scope)

# append the final child_scope yielded
if is_cte:
scope.cte_scopes.append(child_scope)
else:
scope.derived_table_scopes.append(top)
scope.derived_table_scopes.append(child_scope)

scope.sources.update(sources)


Expand Down
8 changes: 6 additions & 2 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,8 +1176,12 @@ def _parse_select(self, nested=False, table=False):
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
self._parse_query_modifiers(this)
this = self._parse_set_operations(this)
self._match_r_paren()
this = self._parse_subquery(this)
# early return so that subquery unions aren't parsed again
# SELECT * FROM (SELECT 1) UNION ALL SELECT 1
# Union ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this)
elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
Expand All @@ -1197,7 +1201,7 @@ def _parse_select(self, nested=False, table=False):
else:
this = None

return self._parse_set_operations(this) if this else None
return self._parse_set_operations(this)

def _parse_with(self, skip_with_token=False):
if not skip_with_token and not self._match(TokenType.WITH):
Expand Down
43 changes: 43 additions & 0 deletions tests/fixtures/optimizer/optimizer.sql
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,46 @@ SELECT a1 FROM cte1;
SELECT
"x"."a" AS "a1"
FROM "x" AS "x";

# title: recursive cte
WITH RECURSIVE cte1 AS (
SELECT *
FROM (
SELECT 1 AS a, 2 AS b
) base
CROSS JOIN (SELECT 3 c) y
UNION ALL
SELECT *
FROM cte1
WHERE a < 1
)
SELECT *
FROM cte1;
WITH RECURSIVE "base" AS (
SELECT
1 AS "a",
2 AS "b"
), "y" AS (
SELECT
3 AS "c"
), "cte1" AS (
SELECT
"base"."a" AS "a",
"base"."b" AS "b",
"y"."c" AS "c"
FROM "base" AS "base"
CROSS JOIN "y" AS "y"
UNION ALL
SELECT
"cte1"."a" AS "a",
"cte1"."b" AS "b",
"cte1"."c" AS "c"
FROM "cte1"
WHERE
"cte1"."a" < 1
)
SELECT
"cte1"."a" AS "a",
"cte1"."b" AS "b",
"cte1"."c" AS "c"
FROM "cte1";
26 changes: 26 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,29 @@ def test_aggfunc_annotation(self):
parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
)
self.assertEqual(expression.expressions[0].type.this, target_type)

def test_recursive_cte(self):
    """Check which CTE sources each CTE scope of a recursive WITH exposes.

    ``t`` only references itself, while ``y`` references both ``t`` and
    itself, so their ``cte_sources`` should be {"t"} and {"t", "y"}.
    """
    sql = """
with recursive t(n) AS
(
select 1
union all
select n + 1
FROM t
where n < 3
), y AS (
select n
FROM t
union all
select n + 1
FROM y
where n < 2
)
select * from y
"""
    cte_scopes = build_scope(parse_one(sql)).cte_scopes

    # Exactly two CTE scopes are expected, in declaration order: t, then y.
    scope_t, scope_y = cte_scopes
    self.assertEqual(set(scope_t.cte_sources), {"t"})
    self.assertEqual(set(scope_y.cte_sources), {"t", "y"})
3 changes: 3 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def test_table(self):
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
self.assertEqual(tables, ["a", "b.c", "d"])

def test_union_order(self):
    """A UNION after a parenthesized FROM subquery must parse to a top-level Union."""
    parsed = parse_one("SELECT * FROM (SELECT 1) UNION SELECT 2")
    self.assertIsInstance(parsed, exp.Union)

def test_select(self):
self.assertIsNotNone(parse_one("select 1 natural"))
self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
Expand Down

0 comments on commit bec3639

Please sign in to comment.