Skip to content

Commit

Permalink
refactor(analysis): simplify and improve pushdown_selection_filters()
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Oct 17, 2023
1 parent 1313e0c commit 2e47738
Show file tree
Hide file tree
Showing 13 changed files with 120 additions and 163 deletions.
9 changes: 4 additions & 5 deletions ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,11 @@ def exists_subquery(translator, op):
ctx = translator.context

dummy = ir.literal(1).name("")

filtered = op.foreign_table.to_expr().filter(
[pred.to_expr() for pred in op.predicates]
node = ops.Selection(
table=op.foreign_table,
selections=[dummy],
predicates=op.predicates,
)
node = filtered.select(dummy).op()

subquery = ctx.get_compiled_expr(node)

return f"EXISTS (\n{util.indent(subquery, ctx.indent)}\n)"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
SELECT t0.*
FROM (
SELECT t1.*
FROM `t0` t1
WHERE t1.`a` < 100
) t0
WHERE t0.`a` = (
SELECT max(t1.`a`) AS `Max(a)`
FROM `t0` t1
WHERE t1.`a` < 100
)
FROM `t0` t0
WHERE (t0.`a` < 100) AND
(t0.`a` = (
SELECT max(t0.`a`) AS `Max(a)`
FROM `t0` t0
WHERE t0.`a` < 100
))
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
WITH t0 AS (
SELECT t1.*
FROM `t0` t1
WHERE t1.`a` < 100
)
SELECT t0.*
FROM t0
WHERE (t0.`a` = (
SELECT max(t1.`a`) AS `Max(a)`
FROM `t0` t1
WHERE t1.`a` < 100
FROM `t0` t0
WHERE (t0.`a` < 100) AND
(t0.`a` = (
SELECT max(t0.`a`) AS `Max(a)`
FROM `t0` t0
WHERE t0.`a` < 100
)) AND
(t0.`b` = 'a')
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ WITH t0 AS (
ON t6.o_orderkey = t2.l_orderkey
JOIN main.nation AS t7
ON t3.s_nationkey = t7.n_nationkey
WHERE
t5.p_name LIKE '%green%'
)
SELECT
t1.nation,
Expand All @@ -28,8 +30,6 @@ FROM (
t0.o_year AS o_year,
SUM(t0.amount) AS sum_profit
FROM t0
WHERE
t0.p_name LIKE '%green%'
GROUP BY
1,
2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ WITH t0 AS (
ON t4."o_orderkey" = t0."l_orderkey"
JOIN t5
ON t1."s_nationkey" = t5."n_nationkey"
WHERE
t3."p_name" LIKE '%green%'
)
SELECT
t7."nation",
Expand All @@ -96,8 +98,6 @@ FROM (
t6."o_year" AS "o_year",
SUM(t6."amount") AS "sum_profit"
FROM t6
WHERE
t6."p_name" LIKE '%green%'
GROUP BY
1,
2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ WITH t0 AS (
ON t6.o_orderkey = t2.l_orderkey
JOIN "hive".ibis_sf1.nation AS t7
ON t3.s_nationkey = t7.n_nationkey
WHERE
t5.p_name LIKE '%green%'
)
SELECT
t1.nation,
Expand All @@ -28,8 +30,6 @@ FROM (
t0.o_year AS o_year,
SUM(t0.amount) AS sum_profit
FROM t0
WHERE
t0.p_name LIKE '%green%'
GROUP BY
1,
2
Expand Down
104 changes: 41 additions & 63 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ibis.common.annotations import ValidationError
from ibis.common.deferred import _, deferred, var
from ibis.common.exceptions import IbisTypeError, IntegrityError
from ibis.common.patterns import pattern
from ibis.common.patterns import Eq, In, pattern
from ibis.util import Namespace

if TYPE_CHECKING:
Expand Down Expand Up @@ -196,69 +196,39 @@ def apply_filter(op, predicates):
return ops.Selection(op, [], predicates)


def pushdown_selection_filters(op, predicates):
default = ops.Selection(op, selections=[], predicates=predicates)
def pushdown_selection_filters(parent, predicates):
default = ops.Selection(parent, selections=[], predicates=predicates)

# We can't push down filters on Unnest or Window because they
# change the shape and potential values of the data.
if any(
isinstance(
sel.arg if isinstance(sel, ops.Alias) else sel,
(ops.Unnest, ops.Window),
)
for sel in op.selections
):
return default

# if any of the filter predicates have the parent expression among
# their roots, then pushdown (at least of that predicate) is not
# possible

# It's not unusual for the filter to reference the projection
# itself. If a predicate can be pushed down, in this case we must
# rewrite replacing the table refs with the roots internal to the
# projection we are referencing
#
# Assuming that the fields referenced by the filter predicate originate
# below the projection, we need to rewrite the predicate referencing
# the parent tables in the join being projected

# Potential fusion opportunity. The predicates may need to be
# rewritten in terms of the child table. This prevents the broken
# ref issue (described in more detail in #59)
try:
simplified_predicates = tuple(
sub_for(predicate, {op: op.table})
if not is_reduction(predicate)
else predicate
for predicate in predicates
)
except IntegrityError:
return default

if not shares_all_roots(simplified_predicates, op.table):
return default

# find spuriously simplified predicates
for predicate in simplified_predicates:
# find columns in the predicate
depends_on = predicate.find((ops.TableColumn, ops.Literal))
for projection in op.selections:
if not isinstance(projection, (ops.TableColumn, ops.Literal)):
# if the projection's table columns overlap with columns
# used in the predicate then we return immediately
#
# this means that we were too aggressive during simplification
# example: t.mutate(a=_.a + 1).filter(_.a > 1)
if projection.find((ops.TableColumn, ops.Literal)) & depends_on:
return default

return ops.Selection(
op.table,
selections=op.selections,
predicates=op.predicates + simplified_predicates,
sort_keys=op.sort_keys,
)
projected_column_names = set()
for value in parent.selections:
if isinstance(value, (ops.Relation, ops.TableColumn)):
# we are only interested in projected value expressions, not tables
# nor column references which are not changing the projection
continue
elif value.find((ops.Reduction, ops.Analytic), filter=ops.Value):
# the parent has analytic projections like window functions so we
# can't push down filters to that level
return default
else:
# otherwise collect the names of newly projected value expressions
# which are not just plain column references
projected_column_names.add(value.name)

conflicting_projection = p.TableColumn(parent, In(projected_column_names))
pushdown_pattern = Eq(parent) >> parent.table

simplified = []
for pred in predicates:
if pred.match(conflicting_projection, filter=p.Value):
return default
try:
simplified.append(pred.replace(pushdown_pattern))
except (IntegrityError, IbisTypeError):
# former happens when there is a duplicate column name in the parent
# which is a join, the latter happens for semi/anti joins
return default

return parent.copy(predicates=parent.predicates + tuple(simplified))


def pushdown_aggregation_filters(op, predicates):
Expand Down Expand Up @@ -391,6 +361,14 @@ def __init__(self, parent, proj_exprs):
def get_result(self):
roots = find_immediate_parent_tables(self.parent.op())
first_root = roots[0]
parent_op = self.parent.op()

# reprojection of the same selections
if len(self.clean_exprs) == 1:
first = self.clean_exprs[0].op()
if isinstance(first, ops.Selection):
if first.selections == parent_op.selections:
return parent_op

if len(roots) == 1 and isinstance(first_root, ops.Selection):
fused_op = self.try_fusion(first_root)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
WITH t0 AS (
SELECT t1.*
FROM my_table t1
WHERE t1.`a` < 100
)
SELECT t0.*
FROM t0
WHERE (t0.`a` = (
SELECT max(t1.`a`) AS `Max(a)`
FROM my_table t1
WHERE t1.`a` < 100
FROM my_table t0
WHERE (t0.`a` < 100) AND
(t0.`a` = (
SELECT max(t0.`a`) AS `Max(a)`
FROM my_table t0
WHERE t0.`a` < 100
)) AND
(t0.`b` = 'a')
22 changes: 9 additions & 13 deletions ibis/tests/sql/snapshots/test_compiler/test_agg_filter/out.sql
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
WITH t0 AS (
SELECT t2.*, t2.`b` * 2 AS `b2`
FROM my_table t2
),
t1 AS (
SELECT t0.`a`, t0.`b2`
FROM t0
WHERE t0.`a` < 100
SELECT t1.*, t1.`b` * 2 AS `b2`
FROM my_table t1
)
SELECT t1.*
FROM t1
WHERE t1.`a` = (
SELECT max(t1.`a`) AS `Max(a)`
FROM t1
)
SELECT t0.`a`, t0.`b2`
FROM t0
WHERE (t0.`a` < 100) AND
(t0.`a` = (
SELECT max(t0.`a`) AS `Max(a)`
FROM t0
))
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
WITH t0 AS (
SELECT t2.*, t2.`b` * 2 AS `b2`
FROM my_table t2
),
t1 AS (
SELECT t0.`a`, t0.`b2`
FROM t0
WHERE t0.`a` < 100
SELECT t1.*, t1.`b` * 2 AS `b2`
FROM my_table t1
)
SELECT t1.*
FROM t1
WHERE t1.`a` = (
SELECT max(t1.`a`) AS `Max(a)`
FROM t1
)
SELECT t0.`a`, t0.`b2`
FROM t0
WHERE (t0.`a` < 100) AND
(t0.`a` = (
SELECT max(t0.`a`) AS `Max(a)`
FROM t0
))
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
SELECT t0.*
FROM (
SELECT t1.*
FROM my_table t1
WHERE t1.`a` < 100
) t0
WHERE t0.`a` = (
SELECT max(t1.`a`) AS `Max(a)`
FROM my_table t1
WHERE t1.`a` < 100
)
FROM my_table t0
WHERE (t0.`a` < 100) AND
(t0.`a` = (
SELECT max(t0.`a`) AS `Max(a)`
FROM my_table t0
WHERE t0.`a` < 100
))
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
SELECT t0.*
FROM (
SELECT *
FROM (
SELECT t2.*
FROM t t2
WHERE (lower(t2.`color`) LIKE '%de%') AND
(locate('de', lower(t2.`color`)) - 1 >= 0)
) t1
) t0
WHERE regexp_like(lower(t0.`color`), '.*ge.*')
FROM t t0
WHERE (lower(t0.`color`) LIKE '%de%') AND
(locate('de', lower(t0.`color`)) - 1 >= 0) AND
(regexp_like(lower(t0.`color`), '.*ge.*'))
Loading

0 comments on commit 2e47738

Please sign in to comment.