Skip to content

Commit

Permalink
fix: Fix CSE case where upper plan has no projection
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 2, 2024
1 parent 51f507f commit ac852e4
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 3 deletions.
1 change: 1 addition & 0 deletions crates/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub(crate) fn concat_impl<L: AsRef<[LazyFrame]>>(
else {
unreachable!()
};
// TODO! Make this properly lazy.
let mut schema = inputs[0].compute_schema()?.as_ref().clone();

let mut changed = false;
Expand Down
26 changes: 26 additions & 0 deletions crates/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::BTreeSet;
use std::ops::Add;

use super::*;

Expand Down Expand Up @@ -351,3 +352,28 @@ fn test_cse_prune_scan_filter_difference() -> PolarsResult<()> {

Ok(())
}

#[test]
fn test_cse_union_filter() {
let q = df![
"x" => [0],
"y" => [1]
]
.unwrap()
.lazy();

let q = concat(
[
q.clone().with_columns([col("y").add(lit(0))]),
q.with_columns([col("y").add(lit(1))]),
],
Default::default(),
)
.unwrap()
.filter(col("x").eq(lit(0)));

println!("{}", q.clone().explain(true).unwrap());
let out = q.collect().unwrap();

dbg!(out);
}
8 changes: 5 additions & 3 deletions crates/polars-plan/src/logical_plan/optimizer/cache_states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ fn get_upper_projections(
lp_arena: &Arena<IR>,
expr_arena: &Arena<AExpr>,
names_scratch: &mut Vec<ColumnName>,
found_required_columns: &mut bool,
) -> bool {
let parent = lp_arena.get(parent);

Expand All @@ -16,6 +17,7 @@ fn get_upper_projections(
SimpleProjection { columns, .. } => {
let iter = columns.iter_names().map(|s| ColumnName::from(s.as_str()));
names_scratch.extend(iter);
*found_required_columns = true;
false
},
Filter { predicate, .. } => {
Expand Down Expand Up @@ -201,17 +203,17 @@ pub(super) fn set_cache_states(
v.parents.push(frame.parent);
v.cache_nodes.push(frame.current);

let mut found_columns = false;
let mut found_required_columns = false;

for parent_node in frame.parent.into_iter().flatten() {
let keep_going = get_upper_projections(
parent_node,
lp_arena,
expr_arena,
&mut names_scratch,
&mut found_required_columns,
);
if !names_scratch.is_empty() {
found_columns = true;
v.names_union.extend(names_scratch.drain(..));
}
// We stop early as we want to find the first projection node above the cache.
Expand Down Expand Up @@ -241,7 +243,7 @@ pub(super) fn set_cache_states(

// There was no explicit projection and we must take
// all columns
if !found_columns {
if !found_required_columns {
let schema = lp.schema(lp_arena);
v.names_union.extend(
schema
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-plan/src/logical_plan/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use super::hive::HivePartitions;
use crate::prelude::*;

impl DslPlan {
// Warning! This should not be used on the DSL internally.
// All schema resolving should be done during conversion to [`IR`].

/// Compute the schema. This requires conversion to [`IR`] and type-resolving.
pub fn compute_schema(&self) -> PolarsResult<SchemaRef> {
let opt_state = OptState {
eager: true,
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,12 @@ def test_cse_drop_nulls_15795() -> None:
C = A.join(B, on="X").select("X")
D = B.select("X")
assert C.join(D, on="X").collect().shape == (1, 1)


def test_cse_no_projection_15980() -> None:
df = pl.LazyFrame({"x": "a", "y": 1})
df = pl.concat(df.with_columns(pl.col("y").add(n)) for n in range(2))

assert df.filter(pl.col("x").eq("a")).select("x").collect().to_dict(
as_series=False
) == {"x": ["a", "a"]}

0 comments on commit ac852e4

Please sign in to comment.