Skip to content

Commit

Permalink
fix(sql): avoid excessive inlining during Select merge (#8825)
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored Apr 7, 2024
1 parent 1237fe3 commit ba931da
Show file tree
Hide file tree
Showing 8 changed files with 377 additions and 266 deletions.
44 changes: 33 additions & 11 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,30 @@ def first_to_firstvalue(_, **kwargs):
return _.copy(func=klass(_.func.arg))


def complexity(node):
"""Assign a complexity score to a node.
Subsequent projections can be merged into a single projection by replacing
the fields referenced in the outer projection with the computed expressions
from the inner projection. This inlining can result in very complex value
expressions depending on the projections. In order to prevent excessive
inlining, we assign a complexity score to each node.
The complexity score assigns 1 to each value expression and adds up in the
tree hierarchy unless there is a Field node where we don't add up the
complexity of the referenced relation. This way we treat fields kind of like
reusable variables considering them less complex than they were inlined.
"""

def accum(node, *args):
if isinstance(node, ops.Field):
return 1
else:
return 1 + sum(args)

return node.map_nodes(accum)[node]


@replace(Object(Select, Object(Select)))
def merge_select_select(_, **kwargs):
"""Merge subsequent Select relations into one.
Expand All @@ -128,15 +152,11 @@ def merge_select_select(_, **kwargs):
from the inner Select are inlined into the outer Select.
"""
# don't merge if either the outer or the inner select has window functions
for v in _.selections.values():
if v.find(ops.WindowFunction, filter=ops.Value):
return _
for v in _.parent.selections.values():
if v.find((ops.WindowFunction, ops.Unnest), filter=ops.Value):
return _
for v in _.predicates:
if v.find((ops.ExistsSubquery, ops.InSubquery), filter=ops.Value):
return _
blocking = (ops.WindowFunction, ops.ExistsSubquery, ops.InSubquery, ops.Unnest)
if _.find_below(blocking, filter=ops.Value):
return _
if _.parent.find_below(blocking, filter=ops.Value):
return _

subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()}
selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()}
Expand All @@ -151,12 +171,13 @@ def merge_select_select(_, **kwargs):
)
unique_sort_keys = sort_keys + parent_sort_keys

return Select(
result = Select(
_.parent.parent,
selections=selections,
predicates=unique_predicates,
sort_keys=unique_sort_keys,
)
return result if complexity(result) <= complexity(_) else _


def extract_ctes(node):
Expand Down Expand Up @@ -198,7 +219,8 @@ def sqlize(
assert isinstance(node, ops.Relation)

# apply the backend specific rewrites
node = node.replace(reduce(operator.or_, rewrites))
if rewrites:
node = node.replace(reduce(operator.or_, rewrites))

# lower the expression graph to a SQL-like relational algebra
context = {"params": params}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
SELECT
IIF(
IIF([t2].[InSubquery(x)] <> 0, 1, 0) AS [InSubquery(x)]
FROM (
SELECT
[t0].[x] IN (
SELECT
[t0].[x]
FROM [t] AS [t0]
WHERE
[t0].[x] > 2
),
1,
0
) AS [InSubquery(x)]
FROM [t] AS [t0]
) AS [InSubquery(x)]
FROM [t] AS [t0]
) AS [t2]
5 changes: 5 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,11 @@ def test_typeof(con):
reason="https://github.com/risingwavelabs/risingwave/issues/1343",
)
@pytest.mark.xfail_version(dask=["dask<2024.2.0"])
@pytest.mark.notyet(
["mssql"],
raises=PyODBCProgrammingError,
reason="naked IN queries are not supported",
)
def test_isin_uncorrelated(
backend, batting, awards_players, batting_df, awards_players_df
):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,68 +1,90 @@
SELECT
"t9"."s_name",
"t9"."s_address"
"t13"."s_name",
"t13"."s_address"
FROM (
SELECT
"t5"."s_suppkey",
"t5"."s_name",
"t5"."s_address",
"t5"."s_nationkey",
"t5"."s_phone",
"t5"."s_acctbal",
"t5"."s_comment",
"t6"."n_nationkey",
"t6"."n_name",
"t6"."n_regionkey",
"t6"."n_comment"
FROM "supplier" AS "t5"
INNER JOIN "nation" AS "t6"
ON "t5"."s_nationkey" = "t6"."n_nationkey"
) AS "t9"
WHERE
"t9"."n_name" = 'CANADA'
AND "t9"."s_suppkey" IN (
"t9"."s_suppkey",
"t9"."s_name",
"t9"."s_address",
"t9"."s_nationkey",
"t9"."s_phone",
"t9"."s_acctbal",
"t9"."s_comment",
"t9"."n_nationkey",
"t9"."n_name",
"t9"."n_regionkey",
"t9"."n_comment"
FROM (
SELECT
"t1"."ps_suppkey"
FROM "partsupp" AS "t1"
WHERE
"t1"."ps_partkey" IN (
"t5"."s_suppkey",
"t5"."s_name",
"t5"."s_address",
"t5"."s_nationkey",
"t5"."s_phone",
"t5"."s_acctbal",
"t5"."s_comment",
"t6"."n_nationkey",
"t6"."n_name",
"t6"."n_regionkey",
"t6"."n_comment"
FROM "supplier" AS "t5"
INNER JOIN "nation" AS "t6"
ON "t5"."s_nationkey" = "t6"."n_nationkey"
) AS "t9"
WHERE
"t9"."n_name" = 'CANADA'
AND "t9"."s_suppkey" IN (
SELECT
"t11"."ps_suppkey"
FROM (
SELECT
"t3"."p_partkey"
FROM "part" AS "t3"
"t2"."ps_partkey",
"t2"."ps_suppkey",
"t2"."ps_availqty",
"t2"."ps_supplycost",
"t2"."ps_comment"
FROM "partsupp" AS "t2"
WHERE
"t3"."p_name" LIKE 'forest%'
)
AND "t1"."ps_availqty" > (
(
SELECT
SUM("t8"."l_quantity") AS "Sum(l_quantity)"
FROM (
"t2"."ps_partkey" IN (
SELECT
"t4"."l_orderkey",
"t4"."l_partkey",
"t4"."l_suppkey",
"t4"."l_linenumber",
"t4"."l_quantity",
"t4"."l_extendedprice",
"t4"."l_discount",
"t4"."l_tax",
"t4"."l_returnflag",
"t4"."l_linestatus",
"t4"."l_shipdate",
"t4"."l_commitdate",
"t4"."l_receiptdate",
"t4"."l_shipinstruct",
"t4"."l_shipmode",
"t4"."l_comment"
FROM "lineitem" AS "t4"
"t3"."p_partkey"
FROM "part" AS "t3"
WHERE
"t4"."l_partkey" = "t1"."ps_partkey"
AND "t4"."l_suppkey" = "t1"."ps_suppkey"
AND "t4"."l_shipdate" >= MAKE_DATE(1994, 1, 1)
AND "t4"."l_shipdate" < MAKE_DATE(1995, 1, 1)
) AS "t8"
) * CAST(0.5 AS DOUBLE)
)
)
"t3"."p_name" LIKE 'forest%'
)
AND "t2"."ps_availqty" > (
(
SELECT
SUM("t8"."l_quantity") AS "Sum(l_quantity)"
FROM (
SELECT
"t4"."l_orderkey",
"t4"."l_partkey",
"t4"."l_suppkey",
"t4"."l_linenumber",
"t4"."l_quantity",
"t4"."l_extendedprice",
"t4"."l_discount",
"t4"."l_tax",
"t4"."l_returnflag",
"t4"."l_linestatus",
"t4"."l_shipdate",
"t4"."l_commitdate",
"t4"."l_receiptdate",
"t4"."l_shipinstruct",
"t4"."l_shipmode",
"t4"."l_comment"
FROM "lineitem" AS "t4"
WHERE
"t4"."l_partkey" = "t2"."ps_partkey"
AND "t4"."l_suppkey" = "t2"."ps_suppkey"
AND "t4"."l_shipdate" >= MAKE_DATE(1994, 1, 1)
AND "t4"."l_shipdate" < MAKE_DATE(1995, 1, 1)
) AS "t8"
) * CAST(0.5 AS DOUBLE)
)
) AS "t11"
)
) AS "t13"
ORDER BY
"t9"."s_name" ASC
"t13"."s_name" ASC
Loading

0 comments on commit ba931da

Please sign in to comment.