Feat(optimizer): optimize pivots (#1617)
* Feat(optimizer): optimize pivots

* Fixup

* Simplify

* Cleanup

* Fix pivot sql generation

* Fixed snowflake pivot column names, add another optimizer test

* Fixed issue with pivoted cte source, added bigquery test

* Factor out some computations

* Cleanup

* Add transform to unalias pivot in spark, more tests

* Typo

* Comment fixup
georgesittas authored May 16, 2023
1 parent 409f13d commit 4b1aa02
Showing 19 changed files with 268 additions and 49 deletions.
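
For context, a rough sketch of what this commit enables: running the optimizer over a query that selects from a pivoted source. This example is not taken from the commit's tests; the schema and identifiers are invented for illustration, and the exact output depends on the dialect and sqlglot version.

import sqlglot
from sqlglot.optimizer import optimize

# Hypothetical example (schema and identifiers invented for illustration).
sql = """
SELECT *
FROM sales
PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2'))
"""

optimized = optimize(
    sqlglot.parse_one(sql, read="snowflake"),
    schema={"sales": {"empid": "INT", "amount": "INT", "quarter": "VARCHAR"}},
    dialect="snowflake",
)

# After this commit, the star should expand into the non-pivoted column(s) plus the
# pivot's output columns, all qualified by an auto-generated alias for the pivot,
# instead of the optimizer failing on or dropping the PIVOT clause.
print(optimized.sql(dialect="snowflake", pretty=True))
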
2 changes: 1 addition & 1 deletion sqlglot/dialects/dialect.py
@@ -275,7 +275,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:

def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
self.unsupported("PIVOT unsupported")
return self.sql(expression)
return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
2 changes: 1 addition & 1 deletion sqlglot/dialects/snowflake.py
@@ -187,7 +187,7 @@ class Snowflake(Dialect):
}

class Parser(parser.Parser):
QUOTED_PIVOT_COLUMNS = True
IDENTIFY_PIVOT_STRINGS = True

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
49 changes: 48 additions & 1 deletion sqlglot/dialects/spark2.py
@@ -53,6 +53,52 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
raise ValueError("Improper scale for timestamp")


def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
"""
Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
pivoted source in a subquery with the same alias to preserve the query's semantics.
Example:
>>> from sqlglot import parse_one
>>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
>>> print(_unalias_pivot(expr).sql(dialect="spark"))
SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
"""
if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
pivot = expression.this.args["pivots"][0]
if pivot.alias:
alias = pivot.args["alias"].pop()
return exp.From(
this=expression.this.replace(
exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
)
)

return expression


def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
"""
Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
so we need to unqualify it.
Example:
>>> from sqlglot import parse_one
>>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
>>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
"""
if isinstance(expression, exp.Pivot):
expression.args["field"].transform(
lambda node: exp.column(node.output_name, quoted=node.this.quoted)
if isinstance(node, exp.Column)
else node,
copy=False,
)

return expression


class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
@@ -188,11 +234,12 @@ class Generator(Hive.Generator):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
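
A hedged sketch of these two Spark transforms working together when generating Spark SQL (the query is invented and the exact output may differ slightly):

import sqlglot

# The pivot alias should be moved onto a wrapping subquery, and the qualified
# column in the FOR clause should be unqualified, yielding something close to:
# SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
sql = "SELECT piv.x FROM tbl PIVOT (SUM(a) FOR tbl.b IN ('x')) piv"
print(sqlglot.transpile(sql, write="spark")[0])
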
9 changes: 4 additions & 5 deletions sqlglot/expressions.py
@@ -2037,10 +2037,10 @@ def subquery(self, alias=None, copy=True) -> Subquery:
Alias: the subquery
"""
instance = _maybe_copy(self, copy)
return Subquery(
this=instance,
alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
if not isinstance(alias, Expression):
alias = TableAlias(this=to_identifier(alias)) if alias else None

return Subquery(this=instance, alias=alias)

def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
raise NotImplementedError
@@ -2964,7 +2964,6 @@ class Tag(Expression):

class Pivot(Expression):
arg_types = {
"this": False,
"alias": False,
"expressions": True,
"field": True,
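
A small illustration of the relaxed subquery() signature (a sketch, not from the commit's tests): the alias may now be passed as a prebuilt expression instead of a string.

from sqlglot import exp, parse_one

select = parse_one("SELECT * FROM tbl")

# A plain string still works as before...
print(select.subquery("sub").sql())  # (SELECT * FROM tbl) AS sub

# ...and a TableAlias expression is now accepted as-is.
alias = exp.TableAlias(this=exp.to_identifier("piv"))
print(select.subquery(alias=alias).sql())  # (SELECT * FROM tbl) AS piv
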
19 changes: 8 additions & 11 deletions sqlglot/generator.py
@@ -1176,9 +1176,10 @@ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:

alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
hints = self.expressions(expression, key="hints", flat=True)
hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep="")
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="")
laterals = self.expressions(expression, key="laterals", sep="")
system_time = expression.args.get("system_time")
@@ -1217,14 +1218,13 @@ def tablesample_sql(
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"

def pivot_sql(self, expression: exp.Pivot) -> str:
this = self.sql(expression, "this")
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
expressions = self.expressions(expression, key="expressions")
expressions = self.expressions(expression, flat=True)
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field}){alias}"
return f"{direction}({expressions} FOR {field}){alias}"

def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
@@ -1582,13 +1582,10 @@ def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""

sql = self.query_modifiers(
expression,
self.wrap(expression),
alias,
self.expressions(expression, key="pivots", sep=" "),
)
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""

sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
return self.prepend_ctes(expression, sql)

def qualify_sql(self, expression: exp.Qualify) -> str:
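
A rough sketch of the net effect of these generator changes: since Pivot no longer carries its own "this", the PIVOT clause is rendered as a suffix of the table or subquery it applies to, and a pivoted query should round-trip cleanly.

from sqlglot import parse_one

sql = "SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x', 'y')) AS piv"
# Parsing and regenerating should keep the PIVOT attached to its source table,
# producing something close to the input string.
print(parse_one(sql).sql())
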
4 changes: 4 additions & 0 deletions sqlglot/optimizer/eliminate_subqueries.py
@@ -133,6 +133,10 @@ def _eliminate_union(scope, existing_ctes, taken):


def _eliminate_derived_table(scope, existing_ctes, taken):
# This ensures we don't drop the "pivot" arg from a pivoted subquery
if scope.parent.pivots:
return None

parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)

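
A hedged sketch of what this guard is for (query invented for illustration): a pivoted derived table should not be lifted into a CTE, because the replacement table reference would lose the pivots attached to the subquery.

from sqlglot import parse_one
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

sql = "SELECT * FROM (SELECT a, b FROM tbl) PIVOT(SUM(a) FOR b IN ('x')) AS piv"
# With the guard above, the pivoted subquery is expected to be left in place
# rather than rewritten into a CTE reference.
print(eliminate_subqueries(parse_one(sql)).sql())
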
1 change: 1 addition & 0 deletions sqlglot/optimizer/merge_subqueries.py
@@ -167,6 +167,7 @@ def _outer_select_joins_on_inner_select_join():
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
and not (
2 changes: 2 additions & 0 deletions sqlglot/optimizer/optimizer.py
@@ -79,11 +79,13 @@ def optimize(
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)

for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)

return expression
5 changes: 3 additions & 2 deletions sqlglot/optimizer/pushdown_projections.py
@@ -39,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})

if scope.expression.args.get("distinct"):
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT
if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
# We can't remove columns from SELECT DISTINCT nor UNION DISTINCT. The same holds if
# we select from a pivoted source in the parent scope.
parent_selections = {SELECT_ALL}

if isinstance(scope.expression, exp.Union):
41 changes: 39 additions & 2 deletions sqlglot/optimizer/qualify_columns.py
@@ -5,6 +5,7 @@

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema

@@ -65,7 +66,7 @@ def validate_qualify_columns(expression):
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
if scope.external_columns and not scope.is_correlated_subquery:
if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
column = scope.external_columns[0]
raise OptimizeError(
f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
@@ -249,6 +250,12 @@ def _qualify_columns(scope, resolver):
raise OptimizeError(f"Unknown column: {column_name}")

if not column_table:
if scope.pivots and not column.find_ancestor(exp.Pivot):
# If the column is not under the Pivot expression itself, we need to qualify it
# using the pivot's alias instead of the name of the pivoted source
column.set("table", exp.to_identifier(scope.pivots[0].alias))
continue

column_table = resolver.get_table(column_name)

# column_table can be a '' because bigquery unnest has no table alias
@@ -272,6 +279,13 @@ def _qualify_columns(scope, resolver):
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

for pivot in scope.pivots:
for column in pivot.find_all(exp.Column):
if not column.table and column.name in resolver.all_columns:
column_table = resolver.get_table(column.name)
if column_table:
column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
@@ -281,6 +295,19 @@ def _expand_stars(scope, resolver, using_column_tables):
replace_columns = {}
coalesced_columns = set()

# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
pivot_columns = None
pivot_output_columns = None
pivot = seq_get(scope.pivots, 0)

has_pivoted_source = pivot and not pivot.args.get("unpivot")
if has_pivoted_source:
pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
if not pivot_output_columns:
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

for expression in scope.selects:
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
@@ -297,9 +324,18 @@ def _expand_stars(scope, resolver, using_column_tables):
for table in tables:
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")

columns = resolver.get_source_columns(table, only_visible=True)

if columns and "*" not in columns:
if has_pivoted_source:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
)
continue

table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
@@ -319,12 +355,13 @@ def _expand_stars(scope, resolver, using_column_tables):
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
column = exp.column(name, table=table)
new_selections.append(
alias(column, alias_, copy=False) if alias_ != name else column
)
else:
return

scope.expression.set("expressions", new_selections)


12 changes: 10 additions & 2 deletions sqlglot/optimizer/qualify_tables.py
@@ -44,10 +44,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
alias_ = next_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)

pivots = derived_table.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))

for name, source in scope.sources.items():
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
@@ -60,12 +64,16 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source,
name if name else next_name(),
name or source.name or next_name(),
copy=True,
table=True,
)
)

pivots = source.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))

if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
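
A short sketch of the new aliasing behavior (identifiers invented): unaliased pivots now receive a generated alias, so downstream rules can refer to the pivoted source by name.

from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

sql = "SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x', 'y'))"
# Expected to yield something like:
# SELECT * FROM tbl AS tbl PIVOT(SUM(a) FOR b IN ('x', 'y')) AS _q_0
print(qualify_tables(parse_one(sql)).sql())
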
21 changes: 19 additions & 2 deletions sqlglot/optimizer/scope.py
@@ -83,6 +83,7 @@ def clear_cache(self):
self._columns = None
self._external_columns = None
self._join_hints = None
self._pivots = None

def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
@@ -372,6 +373,17 @@ def join_hints(self):
return []
return self._join_hints

@property
def pivots(self):
if not self._pivots:
self._pivots = [
pivot
for node in self.tables + self.derived_tables
for pivot in node.args.get("pivots") or []
]

return self._pivots

def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.
@@ -603,8 +615,13 @@ def _traverse_tables(scope):
source_name = expression.alias_or_name

if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
sources[source_name] = scope.sources[table_name]
# This is a reference to a parent source (e.g. a CTE), not an actual table, unless
# it is pivoted, because then we get back a new table and hence a new source.
pivots = expression.args.get("pivots")
if pivots:
sources[pivots[0].alias] = expression
else:
sources[source_name] = scope.sources[table_name]
elif source_name in sources:
sources[find_new_name(sources, table_name)] = expression
else:
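
A combined sketch of the two scope changes above (example query invented): the new pivots property exposes the Pivot nodes attached to a scope's sources, and a pivoted reference to a CTE is registered as a new source under the pivot's alias.

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

sql = "WITH t AS (SELECT a, b FROM tbl) SELECT * FROM t PIVOT(SUM(a) FOR b IN ('x')) AS piv"
scope = build_scope(parse_one(sql))

print([pivot.alias for pivot in scope.pivots])  # expected: ['piv']
print(list(scope.sources))                      # expected to include both 't' and 'piv'
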