Skip to content

Commit

Permalink
fix(sql): look for CTEs under value expressions as well (#8633)
Browse files (browse the repository at this point in the history)
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
  • Loading branch information
kszucs and cpcloud authored Mar 14, 2024
1 parent 220085e commit 14358fe
Show file tree
Hide file tree
Showing 10 changed files with 543 additions and 952 deletions.
5 changes: 3 additions & 2 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ibis.common.collections import FrozenDict # noqa: TCH001
from ibis.common.deferred import var
from ibis.common.graph import Graph
from ibis.common.patterns import Object, Pattern, _, replace
from ibis.common.patterns import InstanceOf, Object, Pattern, _, replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import d, p, replace_parameter
from ibis.expr.schema import Schema
Expand Down Expand Up @@ -184,8 +184,9 @@ def merge_select_select(_, **kwargs):
def extract_ctes(node):
result = []
cte_types = (Select, ops.Aggregate, ops.JoinChain, ops.Set, ops.Limit, ops.Sample)
dont_count = (ops.Field, ops.CountStar, ops.CountDistinctStar)

g = Graph.from_bfs(node, filter=(ops.Relation, ops.Subquery, ops.JoinLink))
g = Graph.from_bfs(node, filter=~InstanceOf(dont_count))
for node, dependents in g.invert().items():
if isinstance(node, ops.View) or (
len(dependents) > 1 and isinstance(node, cte_types)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
SELECT
"t8"."c_custkey",
"t8"."c_name",
"t8"."c_address",
"t8"."c_nationkey",
"t8"."c_phone",
"t8"."c_acctbal",
"t8"."c_mktsegment",
"t8"."c_comment",
"t8"."region",
"t8"."amount",
"t8"."odate"
FROM (
WITH "t8" AS (
SELECT
"t6"."c_custkey",
"t6"."c_name",
Expand All @@ -30,47 +18,40 @@ FROM (
ON "t6"."c_nationkey" = "t5"."n_nationkey"
INNER JOIN "tpch_orders" AS "t7"
ON "t7"."o_custkey" = "t6"."c_custkey"
) AS "t8"
)
SELECT
"t9"."c_custkey",
"t9"."c_name",
"t9"."c_address",
"t9"."c_nationkey",
"t9"."c_phone",
"t9"."c_acctbal",
"t9"."c_mktsegment",
"t9"."c_comment",
"t9"."region",
"t9"."amount",
"t9"."odate"
FROM "t8" AS "t9"
WHERE
"t8"."amount" > (
"t9"."amount" > (
SELECT
AVG("t10"."amount") AS "Mean(amount)"
AVG("t11"."amount") AS "Mean(amount)"
FROM (
SELECT
"t9"."c_custkey",
"t9"."c_name",
"t9"."c_address",
"t9"."c_nationkey",
"t9"."c_phone",
"t9"."c_acctbal",
"t9"."c_mktsegment",
"t9"."c_comment",
"t9"."region",
"t9"."amount",
"t9"."odate"
FROM (
SELECT
"t6"."c_custkey",
"t6"."c_name",
"t6"."c_address",
"t6"."c_nationkey",
"t6"."c_phone",
"t6"."c_acctbal",
"t6"."c_mktsegment",
"t6"."c_comment",
"t4"."r_name" AS "region",
"t7"."o_totalprice" AS "amount",
CAST("t7"."o_orderdate" AS TIMESTAMP) AS "odate"
FROM "tpch_region" AS "t4"
INNER JOIN "tpch_nation" AS "t5"
ON "t4"."r_regionkey" = "t5"."n_regionkey"
INNER JOIN "tpch_customer" AS "t6"
ON "t6"."c_nationkey" = "t5"."n_nationkey"
INNER JOIN "tpch_orders" AS "t7"
ON "t7"."o_custkey" = "t6"."c_custkey"
) AS "t9"
"t10"."c_custkey",
"t10"."c_name",
"t10"."c_address",
"t10"."c_nationkey",
"t10"."c_phone",
"t10"."c_acctbal",
"t10"."c_mktsegment",
"t10"."c_comment",
"t10"."region",
"t10"."amount",
"t10"."odate"
FROM "t8" AS "t10"
WHERE
"t9"."region" = "t8"."region"
) AS "t10"
"t10"."region" = "t9"."region"
) AS "t11"
)
LIMIT 10
Loading

0 comments on commit 14358fe

Please sign in to comment.