diff --git a/crates/polars-lazy/src/dsl/functions.rs b/crates/polars-lazy/src/dsl/functions.rs index a08559a9d14d..188f4a78ab84 100644 --- a/crates/polars-lazy/src/dsl/functions.rs +++ b/crates/polars-lazy/src/dsl/functions.rs @@ -75,6 +75,7 @@ pub(crate) fn concat_impl>( else { unreachable!() }; + // TODO! Make this properly lazy. let mut schema = inputs[0].compute_schema()?.as_ref().clone(); let mut changed = false; diff --git a/crates/polars-lazy/src/tests/cse.rs b/crates/polars-lazy/src/tests/cse.rs index b9e23427cde9..f5c584450d73 100644 --- a/crates/polars-lazy/src/tests/cse.rs +++ b/crates/polars-lazy/src/tests/cse.rs @@ -1,4 +1,5 @@ use std::collections::BTreeSet; +use std::ops::Add; use super::*; @@ -351,3 +352,28 @@ fn test_cse_prune_scan_filter_difference() -> PolarsResult<()> { Ok(()) } + +#[test] +fn test_cse_union_filter() { + let q = df![ + "x" => [0], + "y" => [1] + ] + .unwrap() + .lazy(); + + let q = concat( + [ + q.clone().with_columns([col("y").add(lit(0))]), + q.with_columns([col("y").add(lit(1))]), + ], + Default::default(), + ) + .unwrap() + .filter(col("x").eq(lit(0))); + + println!("{}", q.clone().explain(true).unwrap()); + let out = q.collect().unwrap(); + + dbg!(out); +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs index bcc5c243c672..353807ee9095 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs @@ -7,6 +7,7 @@ fn get_upper_projections( lp_arena: &Arena, expr_arena: &Arena, names_scratch: &mut Vec, + found_required_columns: &mut bool, ) -> bool { let parent = lp_arena.get(parent); @@ -16,6 +17,7 @@ fn get_upper_projections( SimpleProjection { columns, .. } => { let iter = columns.iter_names().map(|s| ColumnName::from(s.as_str())); names_scratch.extend(iter); + *found_required_columns = true; false }, Filter { predicate, .. } => { @@ -201,7 +203,7 @@ pub(super) fn set_cache_states( v.parents.push(frame.parent); v.cache_nodes.push(frame.current); - let mut found_columns = false; + let mut found_required_columns = false; for parent_node in frame.parent.into_iter().flatten() { let keep_going = get_upper_projections( @@ -209,9 +211,9 @@ pub(super) fn set_cache_states( lp_arena, expr_arena, &mut names_scratch, + &mut found_required_columns, ); if !names_scratch.is_empty() { - found_columns = true; v.names_union.extend(names_scratch.drain(..)); } // We stop early as we want to find the first projection node above the cache. @@ -241,7 +243,7 @@ pub(super) fn set_cache_states( // There was no explicit projection and we must take // all columns - if !found_columns { + if !found_required_columns { let schema = lp.schema(lp_arena); v.names_union.extend( schema diff --git a/crates/polars-plan/src/logical_plan/schema.rs b/crates/polars-plan/src/logical_plan/schema.rs index 7d7044e498e1..6c4629a80cb0 100644 --- a/crates/polars-plan/src/logical_plan/schema.rs +++ b/crates/polars-plan/src/logical_plan/schema.rs @@ -12,6 +12,10 @@ use super::hive::HivePartitions; use crate::prelude::*; impl DslPlan { + // Warning! This should not be used on the DSL internally. + // All schema resolving should be done during conversion to [`IR`]. + + /// Compute the schema. This requires conversion to [`IR`] and type-resolving. pub fn compute_schema(&self) -> PolarsResult { let opt_state = OptState { eager: true, diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index d400b8127f96..6ce6cfc4621e 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -723,3 +723,12 @@ def test_cse_drop_nulls_15795() -> None: C = A.join(B, on="X").select("X") D = B.select("X") assert C.join(D, on="X").collect().shape == (1, 1) + + +def test_cse_no_projection_15980() -> None: + df = pl.LazyFrame({"x": "a", "y": 1}) + df = pl.concat(df.with_columns(pl.col("y").add(n)) for n in range(2)) + + assert df.filter(pl.col("x").eq("a")).select("x").collect().to_dict( + as_series=False + ) == {"x": ["a", "a"]}