Skip to content

Commit

Permalink
fix[rust]: ensure all predicates use same key function when inserting…
Browse files Browse the repository at this point in the history
… in hashmap (#5034)
  • Loading branch information
ritchie46 authored Sep 30, 2022
1 parent c3fe475 commit 0e6bff2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use utils::*;
use super::*;
use crate::dsl::function_expr::FunctionExpr;
use crate::logical_plan::{optimizer, Context};
use crate::utils::{aexpr_to_leaf_names, aexprs_to_schema, check_input_node, has_aexpr};
use crate::utils::{aexprs_to_schema, check_input_node, has_aexpr};

#[derive(Default)]
pub struct PredicatePushDown {}
Expand Down Expand Up @@ -91,12 +91,11 @@ impl PredicatePushDown {
let input_schema = lp_arena.get(node).schema(lp_arena);
let mut pushdown_predicates =
optimizer::init_hashmap(Some(acc_predicates.len()));
for (name, &predicate) in acc_predicates.iter() {
for (_, &predicate) in acc_predicates.iter() {
// we can pushdown the predicate
if check_input_node(predicate, &input_schema, expr_arena) {
insert_and_combine_predicate(
&mut pushdown_predicates,
name.clone(),
predicate,
expr_arena,
)
Expand Down Expand Up @@ -157,7 +156,7 @@ impl PredicatePushDown {
///
/// * `AlogicalPlan` - Arena based logical plan tree representing the query.
/// * `acc_predicates` - The predicates we accumulate during tree traversal.
/// The hashmap maps from root-column name to predicates on that column.
/// The hashmap maps from leaf-column name to predicates on that column.
/// If the key is already taken we combine the predicate with a bitand operation.
/// The `Node`s are indexes in the `expr_arena`
/// * `lp_arena` - The local memory arena for the logical plan.
Expand All @@ -178,8 +177,7 @@ impl PredicatePushDown {
// we remove it and apply it locally
let local_predicates = transfer_to_local_by_node(&mut acc_predicates, |node| predicate_is_pushdown_boundary(node, expr_arena));

let name = roots_to_key(&aexpr_to_leaf_names(predicate, expr_arena));
insert_and_combine_predicate(&mut acc_predicates, name, predicate, expr_arena);
insert_and_combine_predicate(&mut acc_predicates, predicate, expr_arena);
let alp = lp_arena.take(input);
let new_input = self.push_down(alp, acc_predicates, lp_arena, expr_arena)?;

Expand Down Expand Up @@ -451,12 +449,9 @@ impl PredicatePushDown {
// be influenced by join
#[allow(clippy::suspicious_else_formatting)]
if !predicate_is_pushdown_boundary(predicate, expr_arena) {
// no else if. predicate can be in both tables.
if check_input_node(predicate, &schema_left, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, &schema_left);
insert_and_combine_predicate(
&mut pushdown_left,
name,
predicate,
expr_arena,
);
Expand All @@ -467,10 +462,8 @@ impl PredicatePushDown {
// in that case we should not push down as the user wants to filter on `x`
// not on `x_rhs`.
else if check_input_node(predicate, &schema_right, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, &schema_right);
insert_and_combine_predicate(
&mut pushdown_right,
name,
predicate,
expr_arena,
);
Expand Down Expand Up @@ -577,12 +570,7 @@ mod test {

let predicate_expr = col("foo").gt(col("bar"));
let predicate = to_aexpr(predicate_expr.clone(), &mut expr_arena);
insert_and_combine_predicate(
&mut acc_predicates,
Arc::from("foo"),
predicate,
&mut expr_arena,
);
insert_and_combine_predicate(&mut acc_predicates, predicate, &mut expr_arena);
let root = *acc_predicates.get("foo").unwrap();
let expr = node_to_expr(root, &expr_arena);
assert_eq!(format!("{:?}", &expr), format!("{:?}", predicate_expr));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ impl Dsl for Node {
/// Don't overwrite predicates but combine them.
pub(super) fn insert_and_combine_predicate(
acc_predicates: &mut PlHashMap<Arc<str>, Node>,
name: Arc<str>,
predicate: Node,
arena: &mut Arena<AExpr>,
) {
let name = predicate_to_key(predicate, arena);

acc_predicates
.entry(name)
.and_modify(|existing_predicate| {
Expand Down Expand Up @@ -77,37 +78,30 @@ pub(super) fn predicate_at_scan(
// an invisible, non-ASCII Unicode code point (U+1D17A) we use as delimiter
// between the column names that are joined into a single predicate key
const HIDDEN_DELIMITER: char = '\u{1D17A}';

/// Determine the hashmap key by combining all the root column names of a predicate
pub(super) fn roots_to_key(roots: &[Arc<str>]) -> Arc<str> {
if roots.len() == 1 {
roots[0].clone()
} else {
let mut new = String::with_capacity(32 * roots.len());
for (i, name) in roots.iter().enumerate() {
if i > 0 {
new.push(HIDDEN_DELIMITER)
/// Determine the hashmap key by combining all the leaf column names of a predicate
pub(super) fn predicate_to_key(predicate: Node, expr_arena: &Arena<AExpr>) -> Arc<str> {
let mut iter = aexpr_to_leaf_names_iter(predicate, expr_arena);
if let Some(first) = iter.next() {
if let Some(second) = iter.next() {
let mut new = String::with_capacity(32 * iter.size_hint().0);
new.push_str(&first);
new.push(HIDDEN_DELIMITER);
new.push_str(&second);

for name in iter {
new.push(HIDDEN_DELIMITER);
new.push_str(&name);
}
new.push_str(name);
return Arc::from(new);
}
Arc::from(new)
first
} else {
let mut s = String::new();
s.push(HIDDEN_DELIMITER);
Arc::from(s)
}
}

/// Derive the hashmap key for `predicate` from the name of the field it
/// resolves to under `schema`.
///
/// The predicate is looked up in `expr_arena`, converted with `to_field`
/// using `Context::Default`, and the resulting field name is copied into an
/// `Arc<str>`.
///
/// # Panics
///
/// Panics if `to_field` does not produce a field for this schema.
pub(super) fn get_insertion_name(
    expr_arena: &Arena<AExpr>,
    predicate: Node,
    schema: &Schema,
) -> Arc<str> {
    let field = expr_arena
        .get(predicate)
        .to_field(schema, Context::Default, expr_arena)
        .unwrap();
    let key: &str = field.name().as_ref();
    Arc::from(key)
}

// this checks if a predicate from a node upstream can pass
// the predicate in this filter
// Cases where this cannot be the case:
Expand Down Expand Up @@ -255,12 +249,7 @@ where
projection_roots[0].clone(),
);

insert_and_combine_predicate(
acc_predicates,
projection_roots[0].clone(),
predicate,
expr_arena,
);
insert_and_combine_predicate(acc_predicates, predicate, expr_arena);
} else {
// this may be a complex binary function. The predicate may only be valid
// on this projected column so we do filter locally.
Expand Down

0 comments on commit 0e6bff2

Please sign in to comment.