From d2c9fc7dda349cb02bce8f6abb981ff3c9b86c92 Mon Sep 17 00:00:00 2001 From: ritchie Date: Wed, 21 Aug 2024 13:32:41 +0200 Subject: [PATCH] fix: Check groups in group-by filter --- crates/polars-expr/src/expressions/apply.rs | 5 +-- crates/polars-expr/src/expressions/binary.rs | 16 ++++---- crates/polars-expr/src/expressions/filter.rs | 9 ++++- crates/polars-expr/src/expressions/gather.rs | 26 ++++++------- .../polars-expr/src/expressions/group_iter.rs | 5 +-- crates/polars-expr/src/expressions/mod.rs | 4 +- crates/polars-expr/src/expressions/ternary.rs | 37 +++++++++---------- .../tests/unit/operations/test_filter.py | 13 +++++++ 8 files changed, 61 insertions(+), 54 deletions(-) diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 4d13d784540e..802e130d15f2 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -240,10 +240,7 @@ impl ApplyExpr { // then unpack the lists and finally create iterators from this list chunked arrays. let mut iters = acs .iter_mut() - .map(|ac| { - // SAFETY: unstable series never lives longer than the iterator. - unsafe { ac.iter_groups(self.pass_name_to_apply) } - }) + .map(|ac| ac.iter_groups(self.pass_name_to_apply)) .collect::>(); // Length of the items to iterate over. diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index ce26c1a57c77..55caf00ad69a 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -151,15 +151,13 @@ impl BinaryExpr { mut ac_r: AggregationContext<'a>, ) -> PolarsResult> { let name = ac_l.series().name().to_string(); - // SAFETY: unstable series never lives longer than the iterator. - let ca = unsafe { - ac_l.iter_groups(false) - .zip(ac_r.iter_groups(false)) - .map(|(l, r)| Some(apply_operator(l?.as_ref(), r?.as_ref(), self.op))) - .map(|opt_res| opt_res.transpose()) - .collect::>()? - .with_name(&name) - }; + let ca = ac_l + .iter_groups(false) + .zip(ac_r.iter_groups(false)) + .map(|(l, r)| Some(apply_operator(l?.as_ref(), r?.as_ref(), self.op))) + .map(|opt_res| opt_res.transpose()) + .collect::>()? + .with_name(&name); ac_l.with_update_groups(UpdateGroups::WithSeriesLen); ac_l.with_agg_state(AggState::AggregatedList(ca.into_series())); diff --git a/crates/polars-expr/src/expressions/filter.rs b/crates/polars-expr/src/expressions/filter.rs index d9df88419ae7..db9ee0cf120e 100644 --- a/crates/polars-expr/src/expressions/filter.rs +++ b/crates/polars-expr/src/expressions/filter.rs @@ -45,10 +45,15 @@ impl PhysicalExpr for FilterExpr { let (ac_s, ac_predicate) = POOL.install(|| rayon::join(ac_s_f, ac_predicate_f)); let (mut ac_s, mut ac_predicate) = (ac_s?, ac_predicate?); + // Check if the groups are still equal, otherwise aggregate. + // TODO! create a special group iters that don't materialize + if ac_s.groups.as_ref() as *const _ != ac_predicate.groups.as_ref() as *const _ { + let _ = ac_s.aggregated(); + let _ = ac_predicate.aggregated(); + } if ac_predicate.is_aggregated() || ac_s.is_aggregated() { - // SAFETY: unstable series never lives longer than the iterator. - let preds = unsafe { ac_predicate.iter_groups(false) }; + let preds = ac_predicate.iter_groups(false); let s = ac_s.aggregated(); let ca = s.list()?; let out = if ca.is_empty() { diff --git a/crates/polars-expr/src/expressions/gather.rs b/crates/polars-expr/src/expressions/gather.rs index c54f8b9e8262..951833717a33 100644 --- a/crates/polars-expr/src/expressions/gather.rs +++ b/crates/polars-expr/src/expressions/gather.rs @@ -253,21 +253,19 @@ impl GatherExpr { ac.series().name(), )?; - unsafe { - let iter = ac.iter_groups(false).zip(idx.iter_groups(false)); - for (s, idx) in iter { - match (s, idx) { - (Some(s), Some(idx)) => { - let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?; - let out = s.as_ref().take(&idx)?; - builder.append_series(&out)?; - }, - _ => builder.append_null(), - }; - } - let out = builder.finish().into_series(); - ac.with_agg_state(AggState::AggregatedList(out)); + let iter = ac.iter_groups(false).zip(idx.iter_groups(false)); + for (s, idx) in iter { + match (s, idx) { + (Some(s), Some(idx)) => { + let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?; + let out = s.as_ref().take(&idx)?; + builder.append_series(&out)?; + }, + _ => builder.append_null(), + }; } + let out = builder.finish().into_series(); + ac.with_agg_state(AggState::AggregatedList(out)); Ok(ac) } } diff --git a/crates/polars-expr/src/expressions/group_iter.rs b/crates/polars-expr/src/expressions/group_iter.rs index 8c921a519bd1..26c68fdae3d2 100644 --- a/crates/polars-expr/src/expressions/group_iter.rs +++ b/crates/polars-expr/src/expressions/group_iter.rs @@ -5,10 +5,7 @@ use polars_core::series::amortized_iter::AmortSeries; use super::*; impl<'a> AggregationContext<'a> { - /// # Safety - /// The lifetime of [AmortSeries] is bound to the iterator. Keeping it alive - /// longer than the iterator is UB. - pub(super) unsafe fn iter_groups( + pub(super) fn iter_groups( &mut self, keep_names: bool, ) -> Box> + '_> { diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 17179f89cbdd..266d577b22ee 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -421,7 +421,9 @@ impl<'a> AggregationContext<'a> { self.groups(); let rows = self.groups.len(); let s = s.new_from_index(0, rows); - s.reshape_list(&[rows as i64, -1]).unwrap() + let out = s.reshape_list(&[rows as i64, -1]).unwrap(); + self.state = AggState::AggregatedList(out.clone()); + out }, } } diff --git a/crates/polars-expr/src/expressions/ternary.rs b/crates/polars-expr/src/expressions/ternary.rs index b84e868efd35..e3c2f9e833a2 100644 --- a/crates/polars-expr/src/expressions/ternary.rs +++ b/crates/polars-expr/src/expressions/ternary.rs @@ -37,26 +37,23 @@ fn finish_as_iters<'a>( mut ac_falsy: AggregationContext<'a>, mut ac_mask: AggregationContext<'a>, ) -> PolarsResult> { - // SAFETY: unstable series never lives longer than the iterator. - let ca = unsafe { - ac_truthy - .iter_groups(false) - .zip(ac_falsy.iter_groups(false)) - .zip(ac_mask.iter_groups(false)) - .map(|((truthy, falsy), mask)| { - match (truthy, falsy, mask) { - (Some(truthy), Some(falsy), Some(mask)) => Some( - truthy - .as_ref() - .zip_with(mask.as_ref().bool()?, falsy.as_ref()), - ), - _ => None, - } - .transpose() - }) - .collect::>()? - .with_name(ac_truthy.series().name()) - }; + let ca = ac_truthy + .iter_groups(false) + .zip(ac_falsy.iter_groups(false)) + .zip(ac_mask.iter_groups(false)) + .map(|((truthy, falsy), mask)| { + match (truthy, falsy, mask) { + (Some(truthy), Some(falsy), Some(mask)) => Some( + truthy + .as_ref() + .zip_with(mask.as_ref().bool()?, falsy.as_ref()), + ), + _ => None, + } + .transpose() + }) + .collect::>()? + .with_name(ac_truthy.series().name()); // Aggregation leaves only a single chunk. let arr = ca.downcast_iter().next().unwrap(); diff --git a/py-polars/tests/unit/operations/test_filter.py b/py-polars/tests/unit/operations/test_filter.py index df796b44b991..eed550fac516 100644 --- a/py-polars/tests/unit/operations/test_filter.py +++ b/py-polars/tests/unit/operations/test_filter.py @@ -285,3 +285,16 @@ def test_filter_group_aware_17030() -> None: (group_count > 2) & (group_cum_count > 1) & (group_cum_count < group_count) ) assert df.filter(filter_expr)["foo"].to_list() == ["1", "2"] + + +def test_invalid_filter_18295() -> None: + codes = ["a"] * 5 + ["b"] * 5 + values = list(range(-2, 3)) + list(range(2, -3, -1)) + df = pl.DataFrame({"code": codes, "value": values}) + with pytest.raises(pl.exceptions.ShapeError): + df.group_by("code").agg( + pl.col("value") + .ewm_mean(span=2, ignore_nulls=True) + .tail(3) + .filter(pl.col("value") > 0), + ).sort("code")