From 4d4fe71a7671479cf07283ea8014d34788fe7e46 Mon Sep 17 00:00:00 2001 From: tobymao Date: Thu, 24 Nov 2022 23:09:38 -0800 Subject: [PATCH] unnest scalar subqueries as cross joins fixes #748 --- sqlglot/executor/python.py | 11 ++- sqlglot/optimizer/unnest_subqueries.py | 36 ++++++-- sqlglot/planner.py | 36 ++++---- tests/fixtures/optimizer/tpc-h/tpc-h.sql | 57 +++++++------ .../fixtures/optimizer/unnest_subqueries.sql | 84 ++++++++++--------- tests/test_executor.py | 24 +++++- 6 files changed, 156 insertions(+), 92 deletions(-) diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index b0fe6d1fef..9ebef24dac 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -288,6 +288,11 @@ def aggregate(self, step, context): end = 1 length = len(context.table) table = self.table(list(step.group) + step.aggregations) + condition = self.generate(step.condition) + + def add_row(): + if not condition or context.eval(condition): + table.append(group + context.eval_tuple(aggregations)) for i in range(length): context.set_index(i) @@ -296,12 +301,12 @@ def aggregate(self, step, context): end += 1 if key != group: context.set_range(start, end - 2) - table.append(group + context.eval_tuple(aggregations)) + add_row() group = key start = end - 2 if i == length - 1: context.set_range(start, end - 1) - table.append(group + context.eval_tuple(aggregations)) + add_row() context = self.context({step.name: table, **{name: table for name in context.tables}}) @@ -400,7 +405,7 @@ class Generator(generator.Generator): exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", - exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}", + exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})", exp.Is: lambda self, e: self.binary(e, "is"), exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index dbd680bd00..2046917ad6 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -1,16 +1,15 @@ import itertools from sqlglot import exp -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import ScopeType, traverse_scope def unnest_subqueries(expression): """ Rewrite sqlglot AST to convert some predicates with subqueries into joins. - Convert the subquery into a group by so it is not a many to many left join. - Unnesting can only occur if the subquery does not have LIMIT or OFFSET. - Unnesting non correlated subqueries only happens on IN statements or = ANY statements. + Convert scalar subqueries into cross joins. + Convert correlated or vectorized subqueries into a group by so it is not a many to many left join. Example: >>> import sqlglot @@ -29,21 +28,43 @@ def unnest_subqueries(expression): for scope in traverse_scope(expression): select = scope.expression parent = select.parent_select + if not parent: + continue if scope.external_columns: decorrelate(select, parent, scope.external_columns, sequence) - else: + elif scope.scope_type == ScopeType.SUBQUERY: unnest(select, parent, sequence) return expression def unnest(select, parent_select, sequence): - predicate = select.find_ancestor(exp.In, exp.Any) + if len(select.selects) > 1: + return + + predicate = select.find_ancestor(exp.Condition) + alias = _alias(sequence) if not predicate or parent_select is not predicate.parent_select: return - if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset): + # this subquery returns a scalar and can just be converted to a cross join + if not isinstance(predicate, (exp.In, exp.Any)): + having = predicate.find_ancestor(exp.Having) + column = exp.column(select.selects[0].alias_or_name, alias) + if having and having.parent_select is parent_select: + column = exp.Max(this=column) + _replace(select.parent, column) + + parent_select.join( + select, + join_type="CROSS", + join_alias=alias, + copy=False, + ) + return + + if select.find(exp.Limit, exp.Offset): return if isinstance(predicate, exp.Any): @@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence): column = _other_operand(predicate) value = select.selects[0] - alias = _alias(sequence) on = exp.condition(f'{column} = "{alias}"."{value.alias}"') _replace(predicate, f"NOT {on.right} IS NULL") diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 3e96ea5707..a036b7c277 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -130,19 +130,20 @@ def from_expression( aggregations = [] sequence = itertools.count() - for e in expression.expressions: - aggs = list(e.find_all(exp.AggFunc)) + def extract_agg_operands(expression): + for agg in expression.find_all(exp.AggFunc): + for operand in agg.unnest_operands(): + if isinstance(operand, exp.Column): + continue + if operand not in operands: + operands[operand] = f"_a_{next(sequence)}" + operand.replace(exp.column(operands[operand], quoted=True)) - if aggs: + for e in expression.expressions: + if e.find(exp.AggFunc): projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) aggregations.append(e) - for agg in aggs: - for operand in agg.unnest_operands(): - if isinstance(operand, exp.Column): - continue - if operand not in operands: - operands[operand] = f"_a_{next(sequence)}" - operand.replace(exp.column(operands[operand], quoted=True)) + extract_agg_operands(e) else: projections.append(e) @@ -157,6 +158,13 @@ def from_expression( aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name + + having = expression.args.get("having") + + if having: + extract_agg_operands(having) + aggregate.condition = having.this + aggregate.operands = tuple( alias(operand, alias_) for operand, alias_ in operands.items() ) @@ -173,11 +181,6 @@ def from_expression( aggregate.add_dependency(step) step = aggregate - having = expression.args.get("having") - - if having: - step.condition = having.this - order = expression.args.get("order") if order: @@ -339,6 +342,9 @@ def _to_s(self, indent: str) -> t.List[str]: lines.append(f"{indent}Group:") for expression in self.group.values(): lines.append(f"{indent} - {expression.sql()}") + if self.condition: + lines.append(f"{indent}Having:") + lines.append(f"{indent} - {self.condition.sql()}") if self.operands: lines.append(f"{indent}Operands:") for expression in self.operands: diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 8138b11cb8..48937433fd 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -666,11 +666,20 @@ WITH "supplier_2" AS ( FROM "nation" AS "nation" WHERE "nation"."n_name" = 'GERMANY' +), "_u_0" AS ( + SELECT + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" + FROM "partsupp" AS "partsupp" + JOIN "supplier_2" AS "supplier" + ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" + JOIN "nation_2" AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" ) SELECT "partsupp"."ps_partkey" AS "ps_partkey", SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value" FROM "partsupp" AS "partsupp" +CROSS JOIN "_u_0" AS "_u_0" JOIN "supplier_2" AS "supplier" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" JOIN "nation_2" AS "nation" @@ -678,15 +687,7 @@ JOIN "nation_2" AS "nation" GROUP BY "partsupp"."ps_partkey" HAVING - SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > ( - SELECT - SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" - FROM "partsupp" AS "partsupp" - JOIN "supplier_2" AS "supplier" - ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" - JOIN "nation_2" AS "nation" - ON "supplier"."s_nationkey" = "nation"."n_nationkey" - ) + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > MAX("_u_0"."_col_0") ORDER BY "value" DESC; @@ -880,6 +881,10 @@ WITH "revenue" AS ( AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE) GROUP BY "lineitem"."l_suppkey" +), "_u_0" AS ( + SELECT + MAX("revenue"."total_revenue") AS "_col_0" + FROM "revenue" ) SELECT "supplier"."s_suppkey" AS "s_suppkey", @@ -889,12 +894,9 @@ SELECT "revenue"."total_revenue" AS "total_revenue" FROM "supplier" AS "supplier" JOIN "revenue" - ON "revenue"."total_revenue" = ( - SELECT - MAX("revenue"."total_revenue") AS "_col_0" - FROM "revenue" - ) - AND "supplier"."s_suppkey" = "revenue"."supplier_no" + ON "supplier"."s_suppkey" = "revenue"."supplier_no" +JOIN "_u_0" AS "_u_0" + ON "revenue"."total_revenue" = "_u_0"."_col_0" ORDER BY "s_suppkey"; @@ -1395,7 +1397,14 @@ order by cntrycode; WITH "_u_0" AS ( SELECT - "orders"."o_custkey" AS "_u_1" + AVG("customer"."c_acctbal") AS "_col_0" + FROM "customer" AS "customer" + WHERE + "customer"."c_acctbal" > 0.00 + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') +), "_u_1" AS ( + SELECT + "orders"."o_custkey" AS "_u_2" FROM "orders" AS "orders" GROUP BY "orders"."o_custkey" @@ -1405,18 +1414,12 @@ SELECT COUNT(*) AS "numcust", SUM("customer"."c_acctbal") AS "totacctbal" FROM "customer" AS "customer" -LEFT JOIN "_u_0" AS "_u_0" - ON "_u_0"."_u_1" = "customer"."c_custkey" +JOIN "_u_0" AS "_u_0" + ON "customer"."c_acctbal" > "_u_0"."_col_0" +LEFT JOIN "_u_1" AS "_u_1" + ON "_u_1"."_u_2" = "customer"."c_custkey" WHERE - "_u_0"."_u_1" IS NULL - AND "customer"."c_acctbal" > ( - SELECT - AVG("customer"."c_acctbal") AS "_col_0" - FROM "customer" AS "customer" - WHERE - "customer"."c_acctbal" > 0.00 - AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') - ) + "_u_1"."_u_2" IS NULL AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') GROUP BY SUBSTRING("customer"."c_phone", 1, 2) diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index f53121a902..dc373a0cad 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -1,10 +1,12 @@ +--SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x; -------------------------------------- -- Unnest Subqueries -------------------------------------- SELECT * FROM x AS x WHERE - x.a IN (SELECT y.a AS a FROM y) + x.a = (SELECT SUM(y.a) AS a FROM y) + AND x.a IN (SELECT y.a AS a FROM y) AND x.a IN (SELECT y.b AS b FROM y) AND x.a = ANY (SELECT y.a AS a FROM y) AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) @@ -24,52 +26,57 @@ WHERE SELECT * FROM x AS x +CROSS JOIN ( + SELECT + SUM(y.a) AS a + FROM y +) AS "_u_0" LEFT JOIN ( SELECT y.a AS a FROM y GROUP BY y.a -) AS "_u_0" - ON x.a = "_u_0"."a" +) AS "_u_1" + ON x.a = "_u_1"."a" LEFT JOIN ( SELECT y.b AS b FROM y GROUP BY y.b -) AS "_u_1" - ON x.a = "_u_1"."b" +) AS "_u_2" + ON x.a = "_u_2"."b" LEFT JOIN ( SELECT y.a AS a FROM y GROUP BY y.a -) AS "_u_2" - ON x.a = "_u_2"."a" +) AS "_u_3" + ON x.a = "_u_3"."a" LEFT JOIN ( SELECT SUM(y.b) AS b, - y.a AS _u_4 + y.a AS _u_5 FROM y WHERE TRUE GROUP BY y.a -) AS "_u_3" - ON x.a = "_u_3"."_u_4" +) AS "_u_4" + ON x.a = "_u_4"."_u_5" LEFT JOIN ( SELECT SUM(y.b) AS b, - y.a AS _u_6 + y.a AS _u_7 FROM y WHERE TRUE GROUP BY y.a -) AS "_u_5" - ON x.a = "_u_5"."_u_6" +) AS "_u_6" + ON x.a = "_u_6"."_u_7" LEFT JOIN ( SELECT y.a AS a @@ -78,8 +85,8 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_7" - ON "_u_7".a = x.a +) AS "_u_8" + ON "_u_8".a = x.a LEFT JOIN ( SELECT y.a AS a @@ -88,31 +95,31 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_8" - ON "_u_8".a = x.a +) AS "_u_9" + ON "_u_9".a = x.a LEFT JOIN ( SELECT ARRAY_AGG(y.a) AS a, - y.b AS _u_10 + y.b AS _u_11 FROM y WHERE TRUE GROUP BY y.b -) AS "_u_9" - ON "_u_9"."_u_10" = x.a +) AS "_u_10" + ON "_u_10"."_u_11" = x.a LEFT JOIN ( SELECT SUM(y.a) AS a, - y.a AS _u_12, - ARRAY_AGG(y.b) AS _u_13 + y.a AS _u_13, + ARRAY_AGG(y.b) AS _u_14 FROM y WHERE TRUE AND TRUE AND TRUE GROUP BY y.a -) AS "_u_11" - ON "_u_11"."_u_12" = x.a AND "_u_11"."_u_12" = x.b +) AS "_u_12" + ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b LEFT JOIN ( SELECT y.a AS a @@ -121,37 +128,38 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_14" - ON x.a = "_u_14".a +) AS "_u_15" + ON x.a = "_u_15".a WHERE - NOT "_u_0"."a" IS NULL - AND NOT "_u_1"."b" IS NULL - AND NOT "_u_2"."a" IS NULL + x.a = "_u_0".a + AND NOT "_u_1"."a" IS NULL + AND NOT "_u_2"."b" IS NULL + AND NOT "_u_3"."a" IS NULL AND ( - x.a = "_u_3".b AND NOT "_u_3"."_u_4" IS NULL + x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL ) AND ( - x.a > "_u_5".b AND NOT "_u_5"."_u_6" IS NULL + x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL ) AND ( - None = "_u_7".a AND NOT "_u_7".a IS NULL + None = "_u_8".a AND NOT "_u_8".a IS NULL ) AND NOT ( - x.a = "_u_8".a AND NOT "_u_8".a IS NULL + x.a = "_u_9".a AND NOT "_u_9".a IS NULL ) AND ( - ARRAY_ANY("_u_9".a, _x -> _x = x.a) AND NOT "_u_9"."_u_10" IS NULL + ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL ) AND ( ( ( - x.a < "_u_11".a AND NOT "_u_11"."_u_12" IS NULL - ) AND NOT "_u_11"."_u_12" IS NULL + x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL + ) AND NOT "_u_12"."_u_13" IS NULL ) - AND ARRAY_ANY("_u_11"."_u_13", "_x" -> "_x" <> x.d) + AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d) ) AND ( - NOT "_u_14".a IS NULL AND NOT "_u_14".a IS NULL + NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL ) AND x.a IN ( SELECT diff --git a/tests/test_executor.py b/tests/test_executor.py index b596a7d6d7..caf3b2ac03 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -74,7 +74,7 @@ def to_csv(expression): ) return expression - for i, (sql, _) in enumerate(self.sqls[0:10]): + for i, (sql, _) in enumerate(self.sqls[0:14]): with self.subTest(f"tpch-h {i + 1}"): a = self.cached_execute(sql) sql = parse_one(sql).transform(to_csv).sql(pretty=True) @@ -346,6 +346,28 @@ def test_execute_tables(self): ], ) + def test_execute_subqueries(self): + tables = { + "table": [ + {"a": 1, "b": 1}, + {"a": 2, "b": 2}, + ], + } + + self.assertEqual( + execute( + """ + SELECT * + FROM table + WHERE a = (SELECT MAX(a) FROM table) + """, + tables=tables, + ).rows, + [ + (2, 2), + ], + ) + def test_table_depth_mismatch(self): tables = {"table": []} schema = {"db": {"table": {"col": "VARCHAR"}}}