From ac0e0861bc2f379772da9a7e75e8c7bbe27d1b75 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Mon, 27 Nov 2023 19:20:51 +1100 Subject: [PATCH] c --- .../optimizer/predicate_pushdown/utils.rs | 133 +++++++++--------- py-polars/tests/unit/test_predicates.py | 1 - 2 files changed, 67 insertions(+), 67 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs index 77fb1f3b95307..8ae44b2f9bc91 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs @@ -266,76 +266,77 @@ pub fn pushdown_eligibility( // Important: Names inserted into any data structure by this function are // all non-aliased. // This function returns false if pushdown cannot be performed. - let process_node = |ae_nodes_stack: &mut Vec, - has_window: &mut bool, - common_window_inputs: &mut PlHashSet>| { - debug_assert_eq!(ae_nodes_stack.len(), 1); - - while let Some(node) = ae_nodes_stack.pop() { - let ae = expr_arena.get(node); - - match ae { - AExpr::Window { - partition_by, - #[cfg(feature = "dynamic_group_by")] - options, - // The function is not checked for groups-sensitivity because - // it is applied over the windows. - .. - } => { - #[cfg(feature = "dynamic_group_by")] - if matches!(options, WindowType::Rolling(..)) { - return false; - }; - - let mut partition_by_names = - PlHashSet::>::with_capacity(partition_by.len()); - - for node in partition_by.iter() { - // Only accept col() or col().alias() - if let Some((_, name)) = - get_maybe_aliased_projection_to_input_name_map(*node, expr_arena) - { - partition_by_names.insert(name.clone()); - } else { - // Nested windows can also qualify for push down. - // e.g.: - // * expr1 = min().over(A) - // * expr2 = sum().over(A, expr1) - // Both exprs window over A, so predicates referring - // to A can still be pushed. - ae_nodes_stack.push(*node); - } - } + let process_projection_or_predicate = + |ae_nodes_stack: &mut Vec, + has_window: &mut bool, + common_window_inputs: &mut PlHashSet>| { + debug_assert_eq!(ae_nodes_stack.len(), 1); + + while let Some(node) = ae_nodes_stack.pop() { + let ae = expr_arena.get(node); - if !*has_window { - for name in partition_by_names.into_iter() { - common_window_inputs.insert(name); + match ae { + AExpr::Window { + partition_by, + #[cfg(feature = "dynamic_group_by")] + options, + // The function is not checked for groups-sensitivity because + // it is applied over the windows. + .. + } => { + #[cfg(feature = "dynamic_group_by")] + if matches!(options, WindowType::Rolling(..)) { + return false; + }; + + let mut partition_by_names = + PlHashSet::>::with_capacity(partition_by.len()); + + for node in partition_by.iter() { + // Only accept col() or col().alias() + if let Some((_, name)) = + get_maybe_aliased_projection_to_input_name_map(*node, expr_arena) + { + partition_by_names.insert(name.clone()); + } else { + // Nested windows can also qualify for push down. + // e.g.: + // * expr1 = min().over(A) + // * expr2 = sum().over(A, expr1) + // Both exprs window over A, so predicates referring + // to A can still be pushed. + ae_nodes_stack.push(*node); + } } - *has_window = true; - } else { - common_window_inputs.retain(|k| partition_by_names.contains(k)) - } + if !*has_window { + for name in partition_by_names.into_iter() { + common_window_inputs.insert(name); + } - // Cannot push into disjoint windows: - // e.g.: - // * sum().over(A) - // * sum().over(B) - if common_window_inputs.is_empty() { - return false; - } - }, - _ => { - if !check_and_extend_predicate_pd_nodes(ae_nodes_stack, ae, expr_arena) { - return false; - } - }, + *has_window = true; + } else { + common_window_inputs.retain(|k| partition_by_names.contains(k)) + } + + // Cannot push into disjoint windows: + // e.g.: + // * sum().over(A) + // * sum().over(B) + if common_window_inputs.is_empty() { + return false; + } + }, + _ => { + if !check_and_extend_predicate_pd_nodes(ae_nodes_stack, ae, expr_arena) { + return false; + } + }, + } } - } - true - }; + true + }; for node in projection_nodes.iter() { if let Some((alias, column_name)) = @@ -359,7 +360,7 @@ pub fn pushdown_eligibility( debug_assert!(ae_nodes_stack.is_empty()); ae_nodes_stack.push(*node); - if !process_node( + if !process_projection_or_predicate( &mut ae_nodes_stack, &mut has_window, &mut common_window_inputs, @@ -396,7 +397,7 @@ pub fn pushdown_eligibility( debug_assert!(ae_nodes_stack.is_empty()); ae_nodes_stack.push(*node); - if !process_node( + if !process_projection_or_predicate( &mut ae_nodes_stack, &mut has_window, &mut common_window_inputs, diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 6015244c81021..2d454484ec8fc 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -355,7 +355,6 @@ def test_predicate_pushdown_with_window_projections_12637() -> None: ) plan = actual.explain() - print(plan) assert "FILTER" in plan assert r'SELECTION: "[(col(\"key\")) == (5)]"' in plan