From cb06b1b12e7a14c3c2122968f31980afe3d5ee77 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Tue, 7 Nov 2023 19:48:41 +1100 Subject: [PATCH] refactor: simplify expr checking in predicate push down --- .../polars-plan/src/logical_plan/aexpr/mod.rs | 2 +- .../optimizer/predicate_pushdown/join.rs | 48 ++-- .../optimizer/predicate_pushdown/mod.rs | 41 +++- .../optimizer/predicate_pushdown/utils.rs | 227 ++++++++---------- py-polars/tests/unit/test_predicates.py | 23 ++ 5 files changed, 170 insertions(+), 171 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index b8b279da6eba6..56785ee1b5b62 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -229,7 +229,7 @@ impl AExpr { | Take { .. } | Nth(_) => true, - | Alias(_, _) + Alias(_, _) | Explode(_) | Column(_) | Literal(_) diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs index 30db762210224..644e35821c535 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs @@ -115,33 +115,31 @@ pub(super) fn process_join( let mut filter_left = false; let mut filter_right = false; - // predicate should not have an aggregation or window function as that would - // be influenced by join - #[allow(clippy::suspicious_else_formatting)] - if !predicate_is_pushdown_boundary(predicate, expr_arena) { - if check_input_node(predicate, &schema_left, expr_arena) && !block_pushdown_left { - insert_and_combine_predicate(&mut pushdown_left, predicate, expr_arena); - filter_left = true; - } + assert_aexpr_allows_predicate_pushdown(predicate, expr_arena); - // if the predicate is in the left hand side - // the right hand side should be renamed with the suffix. - // in that case we should not push down as the user wants to filter on `x` - // not on `x_rhs`. - if !filter_left - && check_input_node(predicate, &schema_right, expr_arena) - && !block_pushdown_right - // However, if we push down to the left and all predicate columns are also - // join columns, we also push down right - || filter_left - && all_pred_cols_in_left_on(predicate, expr_arena, &left_on) - // TODO: Restricting to Inner and Left Join is probably too conservative - && matches!(&options.args.how, JoinType::Inner | JoinType::Left) - { - insert_and_combine_predicate(&mut pushdown_right, predicate, expr_arena); - filter_right = true; - } + if check_input_node(predicate, &schema_left, expr_arena) && !block_pushdown_left { + insert_and_combine_predicate(&mut pushdown_left, predicate, expr_arena); + filter_left = true; } + + // if the predicate is in the left hand side + // the right hand side should be renamed with the suffix. + // in that case we should not push down as the user wants to filter on `x` + // not on `x_rhs`. + if !filter_left + && check_input_node(predicate, &schema_right, expr_arena) + && !block_pushdown_right + // However, if we push down to the left and all predicate columns are also + // join columns, we also push down right + || filter_left + && all_pred_cols_in_left_on(predicate, expr_arena, &left_on) + // TODO: Restricting to Inner and Left Join is probably too conservative + && matches!(&options.args.how, JoinType::Inner | JoinType::Left) + { + insert_and_combine_predicate(&mut pushdown_right, predicate, expr_arena); + filter_right = true; + } + match (filter_left, filter_right, &options.args.how) { // if not pushed down on one of the tables we have to do it locally. (false, false, _) | diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index 40ca5d66fb6ee..45d43b996fe1f 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -75,10 +75,10 @@ impl<'a> PredicatePushDown<'a> { let exprs = lp.get_exprs(); if has_projections { - // we should not pass these projections + // This checks the exprs in the projections at this level. if exprs .iter() - .any(|e_n| projection_is_definite_pushdown_boundary(*e_n, expr_arena)) + .any(|e_n| aexpr_blocks_predicate_pushdown(*e_n, expr_arena)) { return self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena); } @@ -211,12 +211,13 @@ impl<'a> PredicatePushDown<'a> { // filter(y > 1) --> filter(x == min(x)) & filter(y > 2) // pushdown of filter(y > 2) is correctly stopped at the boundary // - // Performing this step here should guarantee that acc_predicates - // in all other contexts do not contain a mix of boundary and - // non-boundary predicates. + // Assuming all predicates originate from the `Selection` node + // at the beginning of optimization, applying this step here + // guarantees that boundary predicates will not appear in other + // contexts. Note boundary projections are handled elsewhere. let local_predicates = if acc_predicates .values() - .any(|node| predicate_is_pushdown_boundary(*node, expr_arena)) + .any(|node| aexpr_blocks_predicate_pushdown(*node, expr_arena)) { let local_predicates = acc_predicates.values().copied().collect::>(); acc_predicates.clear(); @@ -260,13 +261,29 @@ impl<'a> PredicatePushDown<'a> { file_options: options, output_schema } => { - let mut local_predicates = partition_by_full_context(&mut acc_predicates, expr_arena); - if let Some(ref row_count) = options.row_count{ - let row_count_predicates = transfer_to_local_by_name(expr_arena, &mut acc_predicates, |name| { - name.as_ref() == row_count.name - }); - local_predicates.extend_from_slice(&row_count_predicates); + for node in acc_predicates.values() { + assert_aexpr_allows_predicate_pushdown(*node, expr_arena); } + + let local_predicates = match &scan_type { + #[cfg(feature = "parquet")] + FileScan::Parquet { .. } => vec![], + #[cfg(feature = "ipc")] + FileScan::Ipc { .. } => vec![], + _ => { + // Disallow row-count pushdown of other scans as they may + // not update the row counts properly before applying the + // predicate (e.g. FileScan::Csv doesn't). + if let Some(ref row_count) = options.row_count { + let row_count_predicates = transfer_to_local_by_name(expr_arena, &mut acc_predicates, |name| { + name.as_ref() == row_count.name + }); + row_count_predicates + } else { + vec![] + } + } + }; let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); if let (true, Some(predicate)) = (file_info.hive_parts.is_some(), predicate) { diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs index e1b4d50804042..365000cb0b794 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs @@ -104,100 +104,25 @@ pub(super) fn predicate_is_sort_boundary(node: Node, expr_arena: &Arena) has_aexpr(node, expr_arena, matches) } -// this checks if a predicate from a node upstream can pass -// the predicate in this filter -// Cases where this cannot be the case: -// -// .filter(a > 1) # filter 2 -///.filter(a == min(a)) # filter 1 +/// Predicates can be renamed during pushdown to support being pushed through +/// aliases, however this is permitted only if the alias is not preceded by any +/// operations that change the column values. For example: /// -/// the min(a) is influenced by filter 2 so min(a) should not pass -pub(super) fn predicate_is_pushdown_boundary(node: Node, expr_arena: &Arena) -> bool { - let matches = |e: &AExpr| { - matches!( - e, - AExpr::Sort { .. } | AExpr::SortBy { .. } - | AExpr::Take{..} // A take needs all rows - | AExpr::Agg(_) // an aggregation needs all rows - // Apply groups can be something like shift, sort, or an aggregation like skew - // both need all values - | AExpr::AnonymousFunction {options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, .. }, ..} - | AExpr::Function {options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, .. }, ..} - | AExpr::Explode {..} - // A group_by needs all rows for aggregation - | AExpr::Window {..} - ) - }; - has_aexpr(node, expr_arena, matches) -} - -/// Some predicates should not pass a projection if they would influence results of other columns. -/// For instance shifts | sorts results are influenced by a filter so we do all predicates before the shift | sort -/// The rule of thumb is any operation that changes the order of a column w/r/t other columns should be a -/// predicate pushdown blocker. +/// `col(A).alias(B)` - predicates referring to column B can be re-written to +/// use column A, since they have the same values. /// -/// This checks the boundary of other columns -pub(super) fn projection_is_definite_pushdown_boundary( - node: Node, - expr_arena: &Arena, -) -> bool { - let matches = |e: &AExpr| { - use AExpr::*; - // any result that will change due to rows filtered before the projection - - // explicit match is more readable in this case - #[allow(clippy::match_like_matches_macro)] - match e { - Agg(_) // an aggregation needs all rows - // Apply groups can be something like shift, sort, or an aggregation like skew - // both need all values - | AnonymousFunction {options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, .. }, ..} - | Function {options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, .. }, ..} - // still need to investigate this one - | Explode {..} - | Count - | Nth(_) - | Slice {..} - | Take {..} - // A group_by needs all rows for aggregation - | Window {..} - | Literal(LiteralValue::Range {..}) => true, - // The series might be used in a comparison with exactly the right length - Literal(LiteralValue::Series(s)) => s.len() > 1, - _ => false - } - }; - has_aexpr(node, expr_arena, matches) -} - -/// This is only a boundary if a predicate refers to the projection output name. -/// This checks the boundary of same columns. -/// So that means columns that are referred in the predicate -/// for instance `predicate = col(A) == col(B).` -/// and `col().some_func().alias(B)` is projected. -/// then the projection can not pass, as column `B` maybe -/// changed by `some_func` -pub(super) fn projection_is_optional_pushdown_boundary( +/// `col(A).sort().alias(B)` - predicates referring to column B cannot be +/// re-written to use column A as they have different values. +pub(super) fn projection_allows_aliased_predicate_pushdown( node: Node, expr_arena: &Arena, ) -> bool { - let matches = |e: &AExpr| { - use AExpr::*; - // anything that changes output values modifies the predicate result - // and is not captured by function above: `projection_is_definite_pushdown_boundary` - - // explicit match is more readable in this case - #[allow(clippy::match_like_matches_macro)] - match e { - AnonymousFunction { .. } - | Function { .. } - | BinaryExpr { .. } - | Ternary { .. } - | Cast { .. } => true, - _ => false, - } - }; - has_aexpr(node, expr_arena, matches) + for (_, ae) in expr_arena.iter(node) { + if !matches!(ae, AExpr::Column(_) | AExpr::Alias(_, _)) { + return false; + }; + } + true } enum LoopBehavior { @@ -209,7 +134,7 @@ fn rename_predicate_columns_due_to_aliased_projection( expr_arena: &mut Arena, acc_predicates: &mut PlHashMap, Node>, projection_node: Node, - projection_maybe_boundary: bool, + allow_aliased_pushdown: bool, local_predicates: &mut Vec, ) -> LoopBehavior { let projection_aexpr = expr_arena.get(projection_node); @@ -225,7 +150,7 @@ fn rename_predicate_columns_due_to_aliased_projection( // if this alias refers to one of the predicates in the upper nodes // we rename the column of the predicate before we push it downwards. if let Some(predicate) = acc_predicates.remove(&alias_name) { - if projection_maybe_boundary { + if !allow_aliased_pushdown { local_predicates.push(predicate); remove_predicate_refers_to_alias(acc_predicates, local_predicates, &alias_name); return LoopBehavior::Continue; @@ -289,8 +214,8 @@ where // this may be problematic as the aliased column may not yet exist. for projection_node in &projections { // only if a predicate refers to this projection's output column. - let projection_maybe_boundary = - projection_is_optional_pushdown_boundary(*projection_node, expr_arena); + let allow_aliased_pushdown = + projection_allows_aliased_predicate_pushdown(*projection_node, expr_arena); { // if this alias refers to one of the predicates in the upper nodes @@ -299,7 +224,7 @@ where expr_arena, acc_predicates, *projection_node, - projection_maybe_boundary, + allow_aliased_pushdown, &mut local_predicates, ) { LoopBehavior::Continue => continue, @@ -312,27 +237,24 @@ where .to_field(&input_schema, Context::Default, expr_arena) .unwrap(); - // we check if predicates can be done on the input above - // this can only be done if the current projection is not a projection boundary - let is_boundary = projection_is_definite_pushdown_boundary(*projection_node, expr_arena); + // should have been handled earlier by `pushdown_and_continue`. + assert_aexpr_allows_predicate_pushdown(*projection_node, expr_arena); // remove predicates that cannot be done on the input above let to_local = acc_predicates .iter() .filter_map(|(name, predicate)| { - // there are some conditions we need to check for every predicate we try to push down - // 1. does the column exist on the node above - // 2. if the projection is a computation/transformation and the predicate is based on that column - // we must block because the predicate would be incorrect. - // 3. if applying the predicate earlier does not influence the result of this projection - // this is the case for instance with a sum operation (filtering out rows influences the result) - - // checks 1. - if check_input_node(*predicate, &input_schema, expr_arena) - // checks 2. - && !(key_has_name(name, output_field.name()) && projection_maybe_boundary) - // checks 3. - && !is_boundary + if !key_has_name(name, output_field.name()) { + // Predicate has nothing to do with this projection. + return None; + } + + if + // checks that the column does not change value compared to the + // node above + allow_aliased_pushdown + // checks that the column exists in the node above + && check_input_node(*predicate, &input_schema, expr_arena) { None } else { @@ -409,29 +331,68 @@ where local_predicates } -/// predicates that need the full context should not be pushed down to the scans -/// example: min(..) == null_count -pub(super) fn partition_by_full_context( - acc_predicates: &mut PlHashMap, Node>, - expr_arena: &Arena, -) -> Vec { - // TODO! - // Assert that acc_predicates does not contain a mix of groups sensitive and - // non-groups sensitive predicates, as this should have been handled - // earlier under push_down::match::Selection. - if acc_predicates.values().any(|node| { - has_aexpr(*node, expr_arena, |ae| match ae { - AExpr::BinaryExpr { left, right, .. } => { - expr_arena.get(*left).groups_sensitive() - || expr_arena.get(*right).groups_sensitive() - }, +/// An expression blocks predicates from being pushed past it if its results for +/// the subset where the predicate evaluates as true becomes different compared +/// to if it was performed before the predicate was applied. This is in general +/// any expression that produces outputs based on groups of values +/// (i.e. groups-wise) rather than individual values (i.e. element-wise). +/// +/// Examples of expressions whose results would change, and thus block push-down: +/// - any aggregation - sum, mean, first, last, min, max etc. +/// - sorting - as the sort keys would change between filters +pub(super) fn aexpr_blocks_predicate_pushdown(node: Node, expr_arena: &Arena) -> bool { + let mut stack = Vec::::with_capacity(4); + stack.push(node); + + // Cannot use `has_aexpr` because we need to ignore any literals in the RHS + // of an `is_in` operation. + while let Some(node) = stack.pop() { + let ae = expr_arena.get(node); + + if match ae { + // These literals do not come from the RHS of an is_in, meaning that + // they are projected as either columns or predicates, both of which + // rely on the height of the dataframe at this level and thus need + // to block pushdown. + AExpr::Literal(LiteralValue::Range { .. }) => true, + AExpr::Literal(LiteralValue::Series(s)) => s.len() > 1, ae => ae.groups_sensitive(), - }) - }) { - let local_predicates = acc_predicates.values().copied().collect::>(); - acc_predicates.clear(); - local_predicates - } else { - vec![] + } { + return true; + } + + match ae { + AExpr::Function { + function: FunctionExpr::Boolean(BooleanFunction::IsIn), + input, + .. + } => { + // Handles a special case where the expr contains a series, but it is being + // used as part of an `is_in`, so it can be pushed down. + let values = input.get(1).unwrap(); + if matches!(expr_arena.get(*values), AExpr::Literal { .. }) { + // Still need to check the Expr on the LHS of the is_in. + let node = *input.get(0).unwrap(); + stack.push(node); + expr_arena.get(node).nodes(&mut stack); + } else { + ae.nodes(&mut stack); + } + }, + ae => { + ae.nodes(&mut stack); + }, + }; } + + false +} + +/// Used in places that previously handled blocking exprs before refactoring. +/// Can probably be eventually removed if it isn't catching anything. +pub(super) fn assert_aexpr_allows_predicate_pushdown(node: Node, expr_arena: &Arena) { + assert!( + !aexpr_blocks_predicate_pushdown(node, expr_arena), + "Predicate pushdown: Did not expect blocking exprs at this point, please open an issue." + ); } diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index ec16a4067eb3e..88b852005b222 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -267,3 +267,26 @@ def test_take_can_block_predicate_pushdown() -> None: result = lf.collect(predicate_pushdown=True) expected = {"x": [2], "y": [True]} assert result.to_dict(as_series=False) == expected + + +def test_literal_series_expr_predicate_pushdown() -> None: + # No pushdown should occur in this case, because otherwise the filter will + # attempt to filter 3 rows with a boolean mask of 2 rows. + lf = ( + pl.LazyFrame({"x": [0, 1, 2]}) + .filter(pl.col("x") > 0) + .filter(pl.Series([True, True])) + ) + + assert lf.collect().to_series().to_list() == [1, 2] + + # Pushdown should occur here, because the series is being used as part of + # an `is_in`. + lf = ( + pl.LazyFrame({"x": [0, 1, 2]}) + .filter(pl.col("x") > 0) + .filter(pl.col("x").is_in([0, 1])) + ) + + assert "FILTER" not in lf.explain() + assert lf.collect().to_series().to_list() == [1]