Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Improve unique pred-pd #20569

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 2 additions & 48 deletions crates/polars-plan/src/plans/aexpr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ mod hash;
mod scalar;
mod schema;
mod traverse;
mod utils;

use std::hash::{Hash, Hasher};

Expand All @@ -18,8 +17,8 @@ pub use scalar::is_scalar_ae;
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
pub use traverse::*;
pub(crate) use utils::permits_filter_pushdown;
pub use utils::*;
mod properties;
pub use properties::*;

use crate::constants::LEN;
use crate::plans::Context;
Expand Down Expand Up @@ -212,43 +211,6 @@ impl AExpr {
AExpr::Column(name)
}

/// Checks whether this expression is elementwise. This only checks the top level expression.
///
/// The inputs of the expression are not inspected here; only the variant of
/// `self` (and, for function nodes, its `options`) decides the result.
pub(crate) fn is_elementwise_top_level(&self) -> bool {
use AExpr::*;

match self {
// Function-like nodes report their own elementwise-ness through their options.
AnonymousFunction { options, .. } => options.is_elementwise(),

// Non-strict strptime must be done in-memory to ensure the format
// is consistent across the entire dataframe.
// NOTE: this arm must stay ahead of the general `Function` arm below,
// otherwise it would never be reached.
#[cfg(all(feature = "strings", feature = "temporal"))]
Function {
options,
function: FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)),
..
} => {
// The function options are expected to be elementwise; the actual
// answer depends only on whether parsing is strict.
assert!(options.is_elementwise());
opts.strict
},

Function { options, .. } => options.is_elementwise(),

// A literal is treated as elementwise only when it projects as a scalar.
Literal(v) => v.projects_as_scalar(),

// These variants are always elementwise at the top level.
Alias(_, _) | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,

// These variants are never elementwise at the top level.
Agg { .. }
| Explode(_)
| Filter { .. }
| Gather { .. }
| Len
| Slice { .. }
| Sort { .. }
| SortBy { .. }
| Window { .. } => false,
}
}

/// This should be a 1 on 1 copy of the get_type method of Expr until Expr is completely phased out.
pub fn get_type(
&self,
Expand All @@ -259,12 +221,4 @@ impl AExpr {
self.to_field(schema, ctxt, arena)
.map(|f| f.dtype().clone())
}

/// Returns `true` if this expression is a `Column`, `Literal` or `Len` node.
pub(crate) fn is_leaf(&self) -> bool {
matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
}

/// Returns `true` if this expression is a `Column` node.
pub(crate) fn is_col(&self) -> bool {
matches!(self, AExpr::Column(_))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,53 @@ use polars_utils::unitvec;

use super::*;

impl AExpr {
pub(crate) fn is_leaf(&self) -> bool {
matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
}

pub(crate) fn is_col(&self) -> bool {
matches!(self, AExpr::Column(_))
}

/// Checks whether this expression is elementwise. This only checks the top level expression.
pub(crate) fn is_elementwise_top_level(&self) -> bool {
use AExpr::*;

match self {
AnonymousFunction { options, .. } => options.is_elementwise(),

// Non-strict strptime must be done in-memory to ensure the format
// is consistent across the entire dataframe.
#[cfg(all(feature = "strings", feature = "temporal"))]
Function {
options,
function: FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)),
..
} => {
assert!(options.is_elementwise());
opts.strict
},

Function { options, .. } => options.is_elementwise(),

Literal(v) => v.projects_as_scalar(),

Alias(_, _) | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,

Agg { .. }
| Explode(_)
| Filter { .. }
| Gather { .. }
| Len
| Slice { .. }
| Sort { .. }
| SortBy { .. }
| Window { .. } => false,
}
}
}

/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will
/// be extended further with any nested expression nodes.
pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ pub(super) fn process_group_by(
for (pred_name, predicate) in acc_predicates {
// Counts change due to groupby's
// TODO! handle aliases, so that the predicate that is pushed down refers to the column before alias.
let mut push_down = !has_aexpr(predicate.node(), expr_arena, |ae| {
matches!(ae, AExpr::Len | AExpr::Alias(_, _))
});
let mut push_down = !has_aexpr(predicate.node(), expr_arena, |ae| matches!(ae, AExpr::Len));

for name in aexpr_to_leaf_names_iter(predicate.node(), expr_arena) {
push_down &= key_schema.contains(name.as_ref());
Expand Down
47 changes: 30 additions & 17 deletions crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,25 +488,38 @@ impl PredicatePushDown<'_> {
Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena))
},
Distinct { input, options } => {
if let Some(ref subset) = options.subset {
// Predicates on the subset can pass.
let subset = subset.clone();
let mut names_set = PlHashSet::<PlSmallStr>::with_capacity(subset.len());
for name in subset.iter() {
names_set.insert(name.clone());
}

let condition = |name: &PlSmallStr| !names_set.contains(name.as_str());
let local_predicates =
transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition);

self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?;
let lp = Distinct { input, options };
Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena))
let subset = if let Some(ref subset) = options.subset {
subset.as_ref()
} else {
let lp = Distinct { input, options };
self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena)
&[]
};
let mut names_set = PlHashSet::<PlSmallStr>::with_capacity(subset.len());
for name in subset.iter() {
names_set.insert(name.clone());
}

let local_predicates = match options.keep_strategy {
UniqueKeepStrategy::Any => {
let condition = |e: &ExprIR| {
let ae = expr_arena.get(e.node());
// if not elementwise -> to local
!is_elementwise_rec(ae, expr_arena)
};
transfer_to_local_by_expr_ir(expr_arena, &mut acc_predicates, condition)
},
UniqueKeepStrategy::First
| UniqueKeepStrategy::Last
| UniqueKeepStrategy::None => {
let condition = |name: &PlSmallStr| {
!subset.is_empty() && !names_set.contains(name.as_str())
};
transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition)
},
};

self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?;
let lp = Distinct { input, options };
Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena))
},
Join {
input_left,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,35 @@ pub(super) fn predicate_at_scan(
}
}

/// Evaluates `condition` once on every accumulated predicate expression and
/// transfers each predicate for which it returns `true` to local.
///
/// The transferred predicates are removed from `acc_predicates` and returned,
/// so the caller can apply them at the current node instead of pushing them
/// further down.
///
/// Predicates are removed by their own key in `acc_predicates`. This guarantees
/// that every flagged predicate is transferred, including predicates that have
/// no leaf column at all.
///
/// The arena argument is not needed for this and is kept only so existing
/// call sites keep compiling.
pub(super) fn transfer_to_local_by_expr_ir<F>(
    _expr_arena: &Arena<AExpr>,
    acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
    mut condition: F,
) -> Vec<ExprIR>
where
    F: FnMut(&ExprIR) -> bool,
{
    // Collect the keys first: the map cannot be mutated while it is iterated.
    let mut remove_keys = Vec::with_capacity(acc_predicates.len());
    for (key, predicate) in acc_predicates.iter() {
        if condition(predicate) {
            remove_keys.push(key.clone());
        }
    }

    remove_keys
        .into_iter()
        .filter_map(|key| acc_predicates.remove(&key))
        .collect()
}

/// Evaluates a condition on the column name inputs of every predicate, where if
/// the condition evaluates to true on any column name the predicate is
/// transferred to local.
Expand All @@ -94,7 +123,7 @@ where
let mut remove_keys = Vec::with_capacity(acc_predicates.len());

for (key, predicate) in &*acc_predicates {
let root_names = aexpr_to_leaf_names(predicate.node(), expr_arena);
let root_names = aexpr_to_leaf_names_iter(predicate.node(), expr_arena);
for name in root_names {
if condition(&name) {
remove_keys.push(key.clone());
Expand Down
23 changes: 12 additions & 11 deletions py-polars/tests/unit/operations/unique/test_unique.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import re
from datetime import date
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -47,16 +46,6 @@ def test_unique_predicate_pd() -> None:
.filter(pl.col("x") == "abc")
.filter(pl.col("z"))
)
plan = q.explain()
assert r'FILTER col("z")' in plan
# We can push filters if they only depend on the subset columns of unique()
assert (
re.search(
r"FILTER \[\(col\(\"x\"\)\) == \(String\(abc\)\)\] FROM\n\s*DF",
plan,
)
is not None
)
assert_frame_equal(q.collect(predicate_pushdown=False), q.collect())


Expand Down Expand Up @@ -256,3 +245,15 @@ def test_unique_check_order_20480() -> None:
.item()
== 1
)


def test_predicate_pushdown_unique() -> None:
    """Check which predicates are pushed below ``unique()`` in the query plan."""
    q = (
        pl.LazyFrame({"id": [1, 2, 3]})
        .with_columns(pl.date(2024, 1, 1) + pl.duration(days=[1, 2, 3]))  # type: ignore[arg-type]
        .unique()
    )

    # An elementwise predicate is pushed down, so the plan must not start
    # with the filter node.
    elementwise_plan = q.filter(pl.col("id").is_in([1, 2, 3])).explain()
    assert not elementwise_plan.startswith("FILTER")

    # A predicate containing an aggregation stays on top of the plan.
    aggregating_plan = q.filter(pl.col("id").sum() == pl.col("id")).explain()
    assert aggregating_plan.startswith("FILTER")
Loading