Skip to content

Commit

Permalink
unnest scalar subqueries as cross joins (#761)
Browse files Browse the repository at this point in the history
fixes #748
  • Loading branch information
tobymao committed Nov 25, 2022
1 parent 4851ce1 commit 4373ad8
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 92 deletions.
11 changes: 8 additions & 3 deletions sqlglot/executor/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ def aggregate(self, step, context):
end = 1
length = len(context.table)
table = self.table(list(step.group) + step.aggregations)
condition = self.generate(step.condition)

def add_row():
if not condition or context.eval(condition):
table.append(group + context.eval_tuple(aggregations))

for i in range(length):
context.set_index(i)
Expand All @@ -296,12 +301,12 @@ def aggregate(self, step, context):
end += 1
if key != group:
context.set_range(start, end - 2)
table.append(group + context.eval_tuple(aggregations))
add_row()
group = key
start = end - 2
if i == length - 1:
context.set_range(start, end - 1)
table.append(group + context.eval_tuple(aggregations))
add_row()

context = self.context({step.name: table, **{name: table for name in context.tables}})

Expand Down Expand Up @@ -400,7 +405,7 @@ class Generator(generator.Generator):
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
Expand Down
36 changes: 28 additions & 8 deletions sqlglot/optimizer/unnest_subqueries.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import itertools

from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
"""
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert the subquery into a group by so it is not a many to many left join.
Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
Convert scalar subqueries into cross joins.
Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
Example:
>>> import sqlglot
Expand All @@ -29,21 +28,43 @@ def unnest_subqueries(expression):
for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
if not parent:
continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
else:
elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)

return expression


def unnest(select, parent_select, sequence):
predicate = select.find_ancestor(exp.In, exp.Any)
if len(select.selects) > 1:
return

predicate = select.find_ancestor(exp.Condition)
alias = _alias(sequence)

if not predicate or parent_select is not predicate.parent_select:
return

if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
# this subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
having = predicate.find_ancestor(exp.Having)
column = exp.column(select.selects[0].alias_or_name, alias)
if having and having.parent_select is parent_select:
column = exp.Max(this=column)
_replace(select.parent, column)

parent_select.join(
select,
join_type="CROSS",
join_alias=alias,
copy=False,
)
return

if select.find(exp.Limit, exp.Offset):
return

if isinstance(predicate, exp.Any):
Expand All @@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):

column = _other_operand(predicate)
value = select.selects[0]
alias = _alias(sequence)

on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")
Expand Down
36 changes: 21 additions & 15 deletions sqlglot/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,20 @@ def from_expression(
aggregations = []
sequence = itertools.count()

for e in expression.expressions:
aggs = list(e.find_all(exp.AggFunc))
def extract_agg_operands(expression):
for agg in expression.find_all(exp.AggFunc):
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True))

if aggs:
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
for agg in aggs:
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True))
extract_agg_operands(e)
else:
projections.append(e)

Expand All @@ -157,6 +158,13 @@ def from_expression(
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name

having = expression.args.get("having")

if having:
extract_agg_operands(having)
aggregate.condition = having.this

aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
Expand All @@ -173,11 +181,6 @@ def from_expression(
aggregate.add_dependency(step)
step = aggregate

having = expression.args.get("having")

if having:
step.condition = having.this

order = expression.args.get("order")

if order:
Expand Down Expand Up @@ -339,6 +342,9 @@ def _to_s(self, indent: str) -> t.List[str]:
lines.append(f"{indent}Group:")
for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
if self.condition:
lines.append(f"{indent}Having:")
lines.append(f"{indent} - {self.condition.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
for expression in self.operands:
Expand Down
57 changes: 30 additions & 27 deletions tests/fixtures/optimizer/tpc-h/tpc-h.sql
Original file line number Diff line number Diff line change
Expand Up @@ -666,27 +666,28 @@ WITH "supplier_2" AS (
FROM "nation" AS "nation"
WHERE
"nation"."n_name" = 'GERMANY'
), "_u_0" AS (
SELECT
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
FROM "partsupp" AS "partsupp"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
)
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
FROM "partsupp" AS "partsupp"
CROSS JOIN "_u_0" AS "_u_0"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
GROUP BY
"partsupp"."ps_partkey"
HAVING
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > (
SELECT
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
FROM "partsupp" AS "partsupp"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
)
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > MAX("_u_0"."_col_0")
ORDER BY
"value" DESC;

Expand Down Expand Up @@ -880,6 +881,10 @@ WITH "revenue" AS (
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
GROUP BY
"lineitem"."l_suppkey"
), "_u_0" AS (
SELECT
MAX("revenue"."total_revenue") AS "_col_0"
FROM "revenue"
)
SELECT
"supplier"."s_suppkey" AS "s_suppkey",
Expand All @@ -889,12 +894,9 @@ SELECT
"revenue"."total_revenue" AS "total_revenue"
FROM "supplier" AS "supplier"
JOIN "revenue"
ON "revenue"."total_revenue" = (
SELECT
MAX("revenue"."total_revenue") AS "_col_0"
FROM "revenue"
)
AND "supplier"."s_suppkey" = "revenue"."supplier_no"
ON "supplier"."s_suppkey" = "revenue"."supplier_no"
JOIN "_u_0" AS "_u_0"
ON "revenue"."total_revenue" = "_u_0"."_col_0"
ORDER BY
"s_suppkey";

Expand Down Expand Up @@ -1395,7 +1397,14 @@ order by
cntrycode;
WITH "_u_0" AS (
SELECT
"orders"."o_custkey" AS "_u_1"
AVG("customer"."c_acctbal") AS "_col_0"
FROM "customer" AS "customer"
WHERE
"customer"."c_acctbal" > 0.00
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
), "_u_1" AS (
SELECT
"orders"."o_custkey" AS "_u_2"
FROM "orders" AS "orders"
GROUP BY
"orders"."o_custkey"
Expand All @@ -1405,18 +1414,12 @@ SELECT
COUNT(*) AS "numcust",
SUM("customer"."c_acctbal") AS "totacctbal"
FROM "customer" AS "customer"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "customer"."c_custkey"
JOIN "_u_0" AS "_u_0"
ON "customer"."c_acctbal" > "_u_0"."_col_0"
LEFT JOIN "_u_1" AS "_u_1"
ON "_u_1"."_u_2" = "customer"."c_custkey"
WHERE
"_u_0"."_u_1" IS NULL
AND "customer"."c_acctbal" > (
SELECT
AVG("customer"."c_acctbal") AS "_col_0"
FROM "customer" AS "customer"
WHERE
"customer"."c_acctbal" > 0.00
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
)
"_u_1"."_u_2" IS NULL
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
GROUP BY
SUBSTRING("customer"."c_phone", 1, 2)
Expand Down
Loading

1 comment on commit 4373ad8

@georgesittas
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Please sign in to comment.