Skip to content

Commit

Permalink
fix(rust, python): fix predicate pushdown key check (pola-rs#6577)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and vincent committed Jan 30, 2023
1 parent b59d2b5 commit 9cecbfe
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl PredicatePushDown {
// we should not pass these projections
if exprs
.iter()
.any(|e_n| projection_is_dependent_on_predicate_location(*e_n, expr_arena))
.any(|e_n| projection_is_definite_pushdown_boundary(*e_n, expr_arena))
{
return self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@ pub(super) fn predicate_to_key(predicate: Node, expr_arena: &Arena<AExpr>) -> Ar
}
}

fn key_has_name(key: &str, name: &str) -> bool {
if key.contains(HIDDEN_DELIMITER) {
for root_name in key.split(HIDDEN_DELIMITER) {
if root_name == name {
return true;
}
}
}
key == name
}

// This checks whether a predicate from an upstream node can be pushed past
// the predicate in this filter.
// Cases where it cannot:
Expand Down Expand Up @@ -134,7 +145,7 @@ pub(super) fn predicate_is_pushdown_boundary(node: Node, expr_arena: &Arena<AExp
/// predicate pushdown blocker.
///
/// This checks the boundary of other columns
pub(super) fn projection_is_dependent_on_predicate_location(
pub(super) fn projection_is_definite_pushdown_boundary(
node: Node,
expr_arena: &Arena<AExpr>,
) -> bool {
Expand All @@ -157,22 +168,23 @@ pub(super) fn projection_is_dependent_on_predicate_location(
| Window {..}
| Literal(LiteralValue::Range {..}) => true,
Literal(LiteralValue::Series(s)) => s.len() > 1,
#[cfg(all(feature = "strings", feature = "temporal"))]
// strptime is a cast
Function {function: FunctionExpr::StringExpr(StringFunction::Strptime(_)), .. } => true,
_ => false
}
};
has_aexpr(node, expr_arena, matches)
}

/// This is only a boundary if a predicate refers to the projection output name.
/// This checks the boundary of same columns.
/// So that means columns that are referred in the predicate
/// for instance `predicate = col(A) == col(B).`
/// and `col().some_func().alias(B)` is projected.
/// then the predicate cannot be pushed past the projection, as column `B`
/// may be changed by `some_func`.
pub(super) fn predicate_is_dependent_on_projection(node: Node, expr_arena: &Arena<AExpr>) -> bool {
pub(super) fn projection_is_optional_pushdown_boundary(
node: Node,
expr_arena: &Arena<AExpr>,
) -> bool {
let matches = |e: &AExpr| {
use AExpr::*;
match e {
Expand All @@ -198,6 +210,63 @@ pub(super) fn predicate_is_dependent_on_projection(node: Node, expr_arena: &Aren
has_aexpr(node, expr_arena, matches)
}

/// Signal returned by `rename_predicate_columns_due_to_aliased_projection` that tells
/// the calling projection loop how to proceed with the current projection.
enum LoopBehavior {
/// The caller should `continue` with the next projection. Returned when the predicate
/// keyed by the alias was kept local because the projection may be a boundary.
Continue,
/// No special control flow is needed; the caller carries on with the current iteration.
Nothing,
}

/// Reconciles the accumulated predicates with a projection that is an alias.
///
/// If `projection_node` is not an `AExpr::Alias`, nothing is changed.
/// Otherwise predicates that refer to the alias name are handled as follows:
///
/// * A predicate keyed exactly by the alias name is removed from `acc_predicates`.
///   - If `projection_maybe_boundary` is set, it is moved to `local_predicates` and
///     `LoopBehavior::Continue` is returned.
///   - If the projection has exactly one root column, the predicate's leaf names are
///     renamed to that root and it is re-inserted so it can be pushed down further.
///   - Otherwise it is moved to `local_predicates`.
/// * If no predicate is keyed exactly by the alias name, every predicate whose composed
///   key contains the alias name is moved to `local_predicates`.
///
/// # Arguments
/// * `expr_arena` - arena holding the expressions; may receive renamed predicates.
/// * `acc_predicates` - predicates collected from upstream nodes, keyed by column name(s).
/// * `projection_node` - the projection expression under inspection.
/// * `projection_maybe_boundary` - whether this projection blocks predicates that refer
///   to its output column.
/// * `local_predicates` - receives predicates that must be applied at this node.
///
/// # Returns
/// `LoopBehavior::Continue` when the caller should skip the rest of its loop body for
/// this projection, `LoopBehavior::Nothing` otherwise.
fn rename_predicate_columns_due_to_aliased_projection(
    expr_arena: &mut Arena<AExpr>,
    acc_predicates: &mut PlHashMap<Arc<str>, Node>,
    projection_node: Node,
    projection_maybe_boundary: bool,
    local_predicates: &mut Vec<Node>,
) -> LoopBehavior {
    let projection_aexpr = expr_arena.get(projection_node);
    if let AExpr::Alias(_, alias_name) = projection_aexpr {
        let alias_name = alias_name.as_ref();
        let projection_roots = aexpr_to_leaf_names(projection_node, expr_arena);
        // if this alias refers to one of the predicates in the upper nodes
        // we rename the column of the predicate before we push it downwards.
        if let Some(predicate) = acc_predicates.remove(alias_name) {
            if projection_maybe_boundary {
                local_predicates.push(predicate);
                return LoopBehavior::Continue;
            }
            if projection_roots.len() == 1 {
                // we were able to rename the alias column with the root column name
                // before pushing down the predicate
                let predicate =
                    rename_aexpr_leaf_names(predicate, expr_arena, projection_roots[0].clone());

                insert_and_combine_predicate(acc_predicates, predicate, expr_arena);
            } else {
                // this may be a complex binary function. The predicate may only be valid
                // on this projected column so we do filter locally.
                local_predicates.push(predicate)
            }
        } else {
            // we could not find the alias name
            // that could still mean that a predicate that is a complicated binary expression
            // refers to the aliased name. If we find it, we remove it for now
            // TODO! rename the expression.
            //
            // Collect *all* matching keys: more than one composed predicate may refer to
            // the alias, and each of them has to stay local. The keys are gathered first
            // because the map cannot be mutated while it is being iterated.
            let mut remove_names = vec![];
            for (composed_name, _) in acc_predicates.iter() {
                if key_has_name(composed_name, alias_name) {
                    remove_names.push(composed_name.clone());
                }
            }

            for composed_name in remove_names {
                // the key was just read from the map, so it must still be present
                let predicate = acc_predicates.remove(&composed_name).unwrap();
                local_predicates.push(predicate)
            }
        }
    }
    LoopBehavior::Nothing
}

/// Implementation for both Hstack and Projection
pub(super) fn rewrite_projection_node(
expr_arena: &mut Arena<AExpr>,
Expand All @@ -209,77 +278,38 @@ pub(super) fn rewrite_projection_node(
where
{
let mut local_predicates = Vec::with_capacity(acc_predicates.len());
let input_schema = lp_arena.get(input).schema(lp_arena);

// maybe update predicate name if a projection is an alias
// aliases change the column names and because we push the predicates downwards
// this may be problematic as the aliased column may not yet exist.
for projection_node in &projections {
// only if a predicate refers to this projection's output column.
let projection_maybe_boundary =
predicate_is_dependent_on_projection(*projection_node, expr_arena);
projection_is_optional_pushdown_boundary(*projection_node, expr_arena);

{
// if this alias refers to one of the predicates in the upper nodes
// we rename the column of the predicate before we push it downwards.
match rename_predicate_columns_due_to_aliased_projection(
expr_arena,
acc_predicates,
*projection_node,
projection_maybe_boundary,
&mut local_predicates,
) {
LoopBehavior::Continue => continue,
LoopBehavior::Nothing => {}
}
}
let input_schema = lp_arena.get(input).schema(lp_arena);
let projection_expr = expr_arena.get(*projection_node);
let output_field = projection_expr
.to_field(&input_schema, Context::Default, expr_arena)
.unwrap();
let projection_roots = aexpr_to_leaf_names(*projection_node, expr_arena);

{
let projection_aexpr = expr_arena.get(*projection_node);
if let AExpr::Alias(_, alias_name) = projection_aexpr {
// if this alias refers to one of the predicates in the upper nodes
// we rename the column of the predicate before we push it downwards.

if let Some(predicate) = acc_predicates.remove(alias_name) {
if projection_maybe_boundary {
local_predicates.push(predicate);
continue;
}
if projection_roots.len() == 1 {
// we were able to rename the alias column with the root column name
// before pushing down the predicate
let predicate = rename_aexpr_leaf_names(
predicate,
expr_arena,
projection_roots[0].clone(),
);

insert_and_combine_predicate(acc_predicates, predicate, expr_arena);
} else {
// this may be a complex binary function. The predicate may only be valid
// on this projected column so we do filter locally.
local_predicates.push(predicate)
}
} else {
// we could not find the alias name
// that could still mean that a predicate that is a complicated binary expression
// refers to the aliased name. If we find it, we remove it for now
// TODO! rename the expression.
let mut remove_names = vec![];
for (composed_name, _) in acc_predicates.iter() {
if composed_name.contains(HIDDEN_DELIMITER) {
for root_name in composed_name.as_ref().split(HIDDEN_DELIMITER) {
if root_name == alias_name.as_ref() {
remove_names.push(composed_name.clone());
break;
}
}
}
}

for composed_name in remove_names {
let predicate = acc_predicates.remove(&composed_name).unwrap();
local_predicates.push(predicate)
}
}
}
}

// we check if predicates can be done on the input above
// this can only be done if the current projection is not a projection boundary
let is_boundary =
projection_is_dependent_on_predicate_location(*projection_node, expr_arena);
let is_boundary = projection_is_definite_pushdown_boundary(*projection_node, expr_arena);

// remove predicates that cannot be done on the input above
let to_local = acc_predicates
Expand All @@ -295,7 +325,7 @@ where
// checks 1.
if check_input_node(*predicate, &input_schema, expr_arena)
// checks 2.
&& !(output_field.name().as_str() == &**name && projection_maybe_boundary)
&& !(key_has_name(name, output_field.name()) && projection_maybe_boundary)
// checks 3.
&& !is_boundary
{
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,20 @@ def test_predicate_strptime_6558() -> None:
.filter((pl.col("date").dt.year() == 2022) & (pl.col("date").dt.month() == 1))
.collect()
).to_dict(False) == {"date": [date(2022, 1, 3)]}


def test_predicate_arr_first_6573() -> None:
    """Filter on a column that was rebuilt by two stacked ``with_columns`` calls.

    Column ``a`` is first turned into a list and then reduced back with
    ``arr.first()``; the filter comparing ``a`` with ``b`` must see the
    rewritten column.
    """
    frame = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6],
            "b": [6, 5, 4, 3, 2, 1],
        }
    )

    query = (
        frame.lazy()
        .with_columns(pl.col("a").list())
        .with_columns(pl.col("a").arr.first())
        .filter(pl.col("a") == pl.col("b"))
    )
    expected = {"a": [1], "b": [1]}

    assert query.collect().to_dict(False) == expected

0 comments on commit 9cecbfe

Please sign in to comment.