diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index a345fa9ca2b5..3b16983ff976 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -181,31 +181,21 @@ def get_mutation_exprs(exprs: list[ir.Expr], table: ir.Table) -> list[ir.Expr | return [table_expr] + exprs -def apply_filter(op, predicates): - # This will attempt predicate pushdown in the cases where we can do it - # easily and safely, to make both cleaner SQL and fewer referential errors - # for users +def pushdown_selection_filters(parent, predicates): if not predicates: - return op - - if isinstance(op, ops.Selection): - return pushdown_selection_filters(op, predicates) - elif isinstance(op, ops.Aggregation): - return pushdown_aggregation_filters(op, predicates) - else: - return ops.Selection(op, [], predicates) + return parent - -def pushdown_selection_filters(parent, predicates): default = ops.Selection(parent, selections=[], predicates=predicates) + if not isinstance(parent, (ops.Selection, ops.Aggregation)): + return default projected_column_names = set() - for value in parent.selections: + for value in parent._projection.selections: if isinstance(value, (ops.Relation, ops.TableColumn)): # we are only interested in projected value expressions, not tables # nor column references which are not changing the projection continue - elif value.find((ops.Reduction, ops.Analytic), filter=ops.Value): + elif value.find((ops.WindowFunction, ops.ExistsSubquery), filter=ops.Value): # the parent has analytic projections like window functions so we # can't push down filters to that level return default @@ -231,32 +221,6 @@ def pushdown_selection_filters(parent, predicates): return parent.copy(predicates=parent.predicates + tuple(simplified)) -def pushdown_aggregation_filters(op, predicates): - # Potential fusion opportunity - # GH1344: We can't sub in things with correlated subqueries - simplified_predicates = tuple( - # Originally this line tried substituting op.table in for expr, but - # that is too aggressive in the presence of filters that occur - # after aggregations. - # - # See https://github.com/ibis-project/ibis/pull/3341 for details - sub_for(predicate, {op.table: op}) if not is_reduction(predicate) else predicate - for predicate in predicates - ) - - if shares_all_roots(simplified_predicates, op.table): - return ops.Aggregation( - op.table, - op.metrics, - by=op.by, - having=op.having, - predicates=op.predicates + simplified_predicates, - sort_keys=op.sort_keys, - ) - else: - return ops.Selection(op, [], predicates) - - def windowize_function(expr, default_frame, merge_frames=False): func, frame = var("func"), var("frame") @@ -464,6 +428,8 @@ def _find_projections(node): return g.proceed, node._projection elif isinstance(node, ops.SelfReference): return g.proceed, node + elif isinstance(node, ops.Aggregation): + return g.proceed, node._projection elif isinstance(node, ops.Join): return g.proceed, None elif isinstance(node, ops.TableNode): diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 7d293a6969e9..ec0f5404007e 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -531,6 +531,10 @@ def __init__(self, table, metrics, by, having, predicates, sort_keys): sort_keys=sort_keys, ) + @attribute + def _projection(self): + return Projection(self.table, self.metrics + self.by) + @attribute def schema(self): names, types = [], [] diff --git a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt index 0648ffe1b86d..69c7d1add031 100644 --- a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt @@ -10,18 +10,21 @@ r1 := Aggregation[r0] by: region: r0.region kind: r0.kind - -r2 := Selection[r1] predicates: - r1.kind == 'foo' + r0.kind == 'foo' -r3 := Selection[r1] +r2 := Aggregation[r0] + metrics: + total: Sum(r0.amount) + by: + region: r0.region + kind: r0.kind predicates: - r1.kind == 'bar' + r0.kind == 'bar' -r4 := InnerJoin[r2, r3] r2.region == r3.region +r3 := InnerJoin[r1, r2] r1.region == r2.region -Selection[r4] +Selection[r3] selections: - r2 - right_total: r3.total \ No newline at end of file + r1 + right_total: r2.total \ No newline at end of file diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index efd0d6480c55..1fad6a3d224e 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -2281,7 +2281,8 @@ def filter( import ibis.expr.analysis as an resolved_predicates = _resolve_predicates(self, predicates) - return an.apply_filter(self.op(), resolved_predicates).to_expr() + relation = an.pushdown_selection_filters(self.op(), resolved_predicates) + return relation.to_expr() def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of unique rows in the table. diff --git a/ibis/tests/sql/snapshots/test_select_sql/test_filter_self_join_analysis_bug/result.sql b/ibis/tests/sql/snapshots/test_select_sql/test_filter_self_join_analysis_bug/result.sql index 6e1a22e50844..84266c91887a 100644 --- a/ibis/tests/sql/snapshots/test_select_sql/test_filter_self_join_analysis_bug/result.sql +++ b/ibis/tests/sql/snapshots/test_select_sql/test_filter_self_join_analysis_bug/result.sql @@ -1,19 +1,16 @@ WITH t0 AS ( - SELECT t3.`region`, t3.`kind`, sum(t3.`amount`) AS `total` - FROM purchases t3 + SELECT t2.`region`, t2.`kind`, sum(t2.`amount`) AS `total` + FROM purchases t2 + WHERE t2.`kind` = 'bar' GROUP BY 1, 2 ), t1 AS ( - SELECT t0.* - FROM t0 - WHERE t0.`kind` = 'bar' -), -t2 AS ( - SELECT t0.* - FROM t0 - WHERE t0.`kind` = 'foo' + SELECT t2.`region`, t2.`kind`, sum(t2.`amount`) AS `total` + FROM purchases t2 + WHERE t2.`kind` = 'foo' + GROUP BY 1, 2 ) -SELECT t2.`region`, t2.`total` - t1.`total` AS `diff` -FROM t2 - INNER JOIN t1 - ON t2.`region` = t1.`region` \ No newline at end of file +SELECT t1.`region`, t1.`total` - t0.`total` AS `diff` +FROM t1 + INNER JOIN t0 + ON t1.`region` = t0.`region` \ No newline at end of file