diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 8704e9047a..39e252c790 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -68,6 +68,9 @@ def eliminate_subqueries(expression): for cte_scope in root.cte_scopes: # Append all the new CTEs from this existing CTE for scope in cte_scope.traverse(): + if scope is cte_scope: + # Don't try to eliminate this CTE itself + continue new_cte = _eliminate(scope, existing_ctes, taken) if new_cte: new_ctes.append(new_cte) @@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): return _eliminate_derived_table(scope, existing_ctes, taken) + if scope.is_cte: + return _eliminate_cte(scope, existing_ctes, taken) + def _eliminate_union(scope, existing_ctes, taken): duplicate_cte_alias = existing_ctes.get(scope.expression) @@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + table = exp.alias_(exp.table_(name), alias=parent.alias or name) + parent.replace(table) + + return cte + + +def _eliminate_cte(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + with_ = parent.parent + parent.pop() + if not with_.expressions: + with_.pop() + + # Rename references to this CTE + for child_scope in scope.parent.traverse(): + for table, source in child_scope.selected_sources.values(): + if source is scope: + new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name) + table.replace(new_table) + + return cte + + +def _new_cte(scope, existing_ctes, taken): + """ + Returns: + tuple of (name, cte) + where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance. + If this CTE duplicates an existing CTE, `cte` will be None. + """ duplicate_cte_alias = existing_ctes.get(scope.expression) parent = scope.expression.parent - name = alias = parent.alias + name = parent.alias - if not alias: - name = alias = find_new_name(taken=taken, base="cte") + if not name: + name = find_new_name(taken=taken, base="cte") if duplicate_cte_alias: name = duplicate_cte_alias - elif taken.get(alias): - name = find_new_name(taken=taken, base=alias) + elif taken.get(name): + name = find_new_name(taken=taken, base=name) taken[name] = scope - table = exp.alias_(exp.table_(name), alias=alias) - parent.replace(table) - if not duplicate_cte_alias: existing_ctes[scope.expression] = name - return exp.CTE( + cte = exp.CTE( this=scope.expression, alias=exp.TableAlias(this=exp.to_identifier(name)), ) + else: + cte = None + return name, cte diff --git a/tests/fixtures/optimizer/eliminate_subqueries.sql b/tests/fixtures/optimizer/eliminate_subqueries.sql index f395c0a9bf..c566657299 100644 --- a/tests/fixtures/optimizer/eliminate_subqueries.sql +++ b/tests/fixtures/optimizer/eliminate_subqueries.sql @@ -77,3 +77,15 @@ WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x -- Existing duplicate CTE WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z; WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z; + +-- Nested CTE +WITH cte1 AS (SELECT a FROM x) SELECT a FROM (WITH cte2 AS (SELECT a FROM cte1) SELECT a FROM cte2); +WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1), cte AS (SELECT a FROM cte2 AS cte2) SELECT a FROM cte AS cte; + +-- Nested CTE inside CTE +WITH cte1 AS (WITH cte2 AS (SELECT a FROM x) SELECT t.a FROM cte2 AS t) SELECT a FROM cte1; +WITH cte2 AS (SELECT a FROM x), cte1 AS (SELECT t.a FROM cte2 AS t) SELECT a FROM cte1; + +-- Duplicate CTE nested in CTE +WITH cte1 AS (SELECT a FROM x), cte2 AS (WITH cte3 AS (SELECT a FROM x) SELECT a FROM cte3) SELECT a FROM cte2; +WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1 AS cte3) SELECT a FROM cte2; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index a1e531be60..a692c7dd0d 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -276,3 +276,18 @@ SELECT /*+ COALESCE(3), FROM `x` AS `x` JOIN `y` AS `y` ON `x`.`b` = `y`.`b`; + +WITH cte1 AS ( + WITH cte2 AS ( + SELECT a, b FROM x + ) + SELECT a1 + FROM ( + WITH cte3 AS (SELECT 1) + SELECT a AS a1, b AS b1 FROM cte2 + ) +) +SELECT a1 FROM cte1; +SELECT + "x"."a" AS "a1" +FROM "x" AS "x";