Skip to content

Commit

Permalink
UDTF scope refactor (#1181)
Browse files Browse the repository at this point in the history
* UDTF scope refactor

* fixup

* extend laterals

* fixup
  • Loading branch information
barakalon authored Feb 15, 2023
1 parent 6860bc9 commit d95317e
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 78 deletions.
6 changes: 2 additions & 4 deletions sqlglot/optimizer/eliminate_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def eliminate_subqueries(expression):
new_ctes.append(cte_scope.expression.parent)

# Now append the rest
for scope in itertools.chain(
root.union_scopes, root.subquery_scopes, root.derived_table_scopes
):
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:
Expand All @@ -99,7 +97,7 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_union:
return _eliminate_union(scope, existing_ctes, taken)

if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
if scope.is_derived_table:
return _eliminate_derived_table(scope, existing_ctes, taken)

if scope.is_cte:
Expand Down
5 changes: 2 additions & 3 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2))
"""
for derived_table in derived_tables:
if isinstance(derived_table.unnest(), exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
table_alias.args.pop("columns", None)
Expand Down Expand Up @@ -396,7 +394,8 @@ def get_source_columns(self, name, only_visible=False):
def _get_all_source_columns(self):
if self._source_columns is None:
self._source_columns = {
k: self.get_source_columns(k) for k in self.scope.selected_sources
k: self.get_source_columns(k)
for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
}
return self._source_columns

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/qualify_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
next_name = lambda: f"_q_{next(sequence)}"

for scope in traverse_scope(expression):
for derived_table in scope.ctes + scope.derived_tables:
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
Expand Down
169 changes: 105 additions & 64 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class Scope:
SELECT * FROM x {"x": Table(this="x")}
SELECT * FROM x AS y {"y": Table(this="x")}
SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
For example:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
Expand All @@ -34,8 +38,10 @@ class Scope:
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to its parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
cte_scopes = (list[Scope]) List of all child scopes for CTEs
derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables
cte_scopes (list[Scope]): List of all child scopes for CTEs
derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
udtf_scopes (list[Scope]): List of all child scopes for user-defined tabular functions
table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
a list of the left and right child scopes.
"""
Expand All @@ -47,22 +53,28 @@ def __init__(
outer_column_list=None,
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
):
self.expression = expression
self.sources = sources or {}
self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
self.sources.update(self.lateral_sources)
self.outer_column_list = outer_column_list or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
self.derived_table_scopes = []
self.table_scopes = []
self.cte_scopes = []
self.union_scopes = []
self.udtf_scopes = []
self.clear_cache()

def clear_cache(self):
self._collected = False
self._raw_columns = None
self._derived_tables = None
self._udtfs = None
self._tables = None
self._ctes = None
self._subqueries = None
Expand All @@ -86,6 +98,7 @@ def _collect(self):
self._ctes = []
self._subqueries = []
self._derived_tables = []
self._udtfs = []
self._raw_columns = []
self._join_hints = []

Expand All @@ -99,7 +112,7 @@ def _collect(self):
elif isinstance(node, exp.JoinHint):
self._join_hints.append(node)
elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
self._udtfs.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
Expand Down Expand Up @@ -199,6 +212,17 @@ def derived_tables(self):
self._ensure_collected()
return self._derived_tables

@property
def udtfs(self):
"""
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
"""
self._ensure_collected()
return self._udtfs

@property
def subqueries(self):
"""
Expand Down Expand Up @@ -227,7 +251,9 @@ def columns(self):
columns = self._raw_columns

external_columns = [
column for scope in self.subquery_scopes for column in scope.external_columns
column
for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
for column in scope.external_columns
]

named_selects = set(self.expression.named_selects)
Expand Down Expand Up @@ -262,9 +288,8 @@ def selected_sources(self):

for table in self.tables:
referenced_names.append((table.alias_or_name, table))
for derived_table in self.derived_tables:
referenced_names.append((derived_table.alias, derived_table.unnest()))

for expression in itertools.chain(self.derived_tables, self.udtfs):
referenced_names.append((expression.alias, expression.unnest()))
result = {}

for name, node in referenced_names:
Expand Down Expand Up @@ -414,7 +439,7 @@ def traverse(self):
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
Expand Down Expand Up @@ -480,24 +505,23 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.UDTF):
_set_udtf_scope(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
else:
raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
yield scope


def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
yield from _traverse_ctes(scope)
yield from _traverse_tables(scope)
yield from _traverse_subqueries(scope)
_add_table_sources(scope)


def _traverse_union(scope):
yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
yield from _traverse_ctes(scope)

# The last scope to be yielded should be the topmost scope
left = None
Expand All @@ -511,82 +535,98 @@ def _traverse_union(scope):
scope.union_scopes = [left, right]


def _set_udtf_scope(scope):
parent = scope.expression.parent
from_ = parent.args.get("from")

if not from_:
return

for table in from_.expressions:
if isinstance(table, exp.Table):
scope.tables.append(table)
elif isinstance(table, exp.Subquery):
scope.subqueries.append(table)
_add_table_sources(scope)
_traverse_subqueries(scope)


def _traverse_derived_tables(derived_tables, scope, scope_type):
def _traverse_ctes(scope):
sources = {}
is_cte = scope_type == ScopeType.CTE

for derived_table in derived_tables:
for cte in scope.ctes:
recursive_scope = None

# If the scope is a recursive CTE, it must be in the form of
# base_case UNION recursive. Thus the recursive scope is the first
# section of the union.
if is_cte and scope.expression.args["with"].recursive:
union = derived_table.this
if scope.expression.args["with"].recursive:
union = cte.this

if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
cte.this,
chain_sources=sources,
outer_column_list=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
):
yield child_scope

# Tables without aliases will be set as "".
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
alias = derived_table.alias
alias = cte.alias
sources[alias] = child_scope

if recursive_scope:
child_scope.add_source(alias, recursive_scope)

# append the final child_scope yielded
if is_cte:
scope.cte_scopes.append(child_scope)
else:
scope.derived_table_scopes.append(child_scope)
scope.cte_scopes.append(child_scope)

scope.sources.update(sources)


def _add_table_sources(scope):
def _traverse_tables(scope):
sources = {}
for table in scope.tables:
table_name = table.name

if table.alias:
source_name = table.alias
else:
source_name = table_name
# Traverse FROMs, JOINs, and LATERALs in the order they are defined
expressions = []
from_ = scope.expression.args.get("from")
if from_:
expressions.extend(from_.expressions)

if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
scope.sources[source_name] = scope.sources[table_name]
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)

expressions.extend(scope.expression.args.get("laterals") or [])

for expression in expressions:
if isinstance(expression, exp.Table):
table_name = expression.name
source_name = expression.alias_or_name

if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
sources[source_name] = scope.sources[table_name]
else:
sources[source_name] = expression
continue

if isinstance(expression, exp.UDTF):
lateral_sources = sources
scope_type = ScopeType.UDTF
scopes = scope.udtf_scopes
else:
sources[source_name] = table
lateral_sources = None
scope_type = ScopeType.DERIVED_TABLE
scopes = scope.derived_table_scopes

for child_scope in _traverse_scope(
scope.branch(
expression,
lateral_sources=lateral_sources,
outer_column_list=expression.alias_column_names,
scope_type=scope_type,
)
):
yield child_scope

# Tables without aliases will be set as "".
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
alias = expression.alias
sources[alias] = child_scope

# append the final child_scope yielded
scopes.append(child_scope)
scope.table_scopes.append(child_scope)

scope.sources.update(sources)

Expand Down Expand Up @@ -624,9 +664,10 @@ def walk_in_scope(expression, bfs=True):

if node is expression:
continue
elif isinstance(node, exp.CTE):
prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
prune = True
elif isinstance(node, exp.Subqueryable):
if (
isinstance(node, exp.CTE)
or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
or isinstance(node, exp.UDTF)
or isinstance(node, exp.Subqueryable)
):
prune = True
7 changes: 1 addition & 6 deletions tests/fixtures/optimizer/optimizer.sql
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
# title: lateral
# execute: false
SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m;
WITH "z_2" AS (
SELECT
"z"."a" AS "a"
FROM "z" AS "z"
)
SELECT
"z"."a" AS "a",
"q"."m" AS "m"
FROM "z_2" AS "z"
FROM "z" AS "z"
LATERAL VIEW
EXPLODE(ARRAY(1, 2)) q AS "m";

Expand Down
9 changes: 9 additions & 0 deletions tests/fixtures/optimizer/qualify_columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,15 @@ SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa);
SELECT aa FROM x, UNNEST(a) AS aa;
SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa;

# execute: false
# dialect: presto
SELECT x.a, i.b FROM x CROSS JOIN UNNEST(SPLIT(b, ',')) AS i(b);
SELECT x.a AS a, i.b AS b FROM x AS x CROSS JOIN UNNEST(SPLIT(x.b, ',')) AS i(b);

# execute: false
SELECT c FROM (SELECT 1 a) AS x LATERAL VIEW EXPLODE(a) AS c;
SELECT _q_0.c AS c FROM (SELECT 1 AS a) AS x LATERAL VIEW EXPLODE(x.a) _q_0 AS c;

--------------------------------------
-- Window functions
--------------------------------------
Expand Down

0 comments on commit d95317e

Please sign in to comment.