Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Check groups in group-by filter #18300

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,7 @@ impl ApplyExpr {
// then unpack the lists and finally create iterators from this list chunked arrays.
let mut iters = acs
.iter_mut()
.map(|ac| {
// SAFETY: unstable series never lives longer than the iterator.
unsafe { ac.iter_groups(self.pass_name_to_apply) }
})
.map(|ac| ac.iter_groups(self.pass_name_to_apply))
.collect::<Vec<_>>();

// Length of the items to iterate over.
Expand Down
16 changes: 7 additions & 9 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,13 @@ impl BinaryExpr {
mut ac_r: AggregationContext<'a>,
) -> PolarsResult<AggregationContext<'a>> {
let name = ac_l.series().name().to_string();
// SAFETY: unstable series never lives longer than the iterator.
let ca = unsafe {
ac_l.iter_groups(false)
.zip(ac_r.iter_groups(false))
.map(|(l, r)| Some(apply_operator(l?.as_ref(), r?.as_ref(), self.op)))
.map(|opt_res| opt_res.transpose())
.collect::<PolarsResult<ListChunked>>()?
.with_name(&name)
};
let ca = ac_l
.iter_groups(false)
.zip(ac_r.iter_groups(false))
.map(|(l, r)| Some(apply_operator(l?.as_ref(), r?.as_ref(), self.op)))
.map(|opt_res| opt_res.transpose())
.collect::<PolarsResult<ListChunked>>()?
.with_name(&name);

ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
ac_l.with_agg_state(AggState::AggregatedList(ca.into_series()));
Expand Down
9 changes: 7 additions & 2 deletions crates/polars-expr/src/expressions/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,15 @@ impl PhysicalExpr for FilterExpr {

let (ac_s, ac_predicate) = POOL.install(|| rayon::join(ac_s_f, ac_predicate_f));
let (mut ac_s, mut ac_predicate) = (ac_s?, ac_predicate?);
// Check if the groups are still equal, otherwise aggregate.
// TODO! create a special group iters that don't materialize
if ac_s.groups.as_ref() as *const _ != ac_predicate.groups.as_ref() as *const _ {
let _ = ac_s.aggregated();
let _ = ac_predicate.aggregated();
}

if ac_predicate.is_aggregated() || ac_s.is_aggregated() {
// SAFETY: unstable series never lives longer than the iterator.
let preds = unsafe { ac_predicate.iter_groups(false) };
let preds = ac_predicate.iter_groups(false);
let s = ac_s.aggregated();
let ca = s.list()?;
let out = if ca.is_empty() {
Expand Down
26 changes: 12 additions & 14 deletions crates/polars-expr/src/expressions/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,21 +253,19 @@ impl GatherExpr {
ac.series().name(),
)?;

unsafe {
let iter = ac.iter_groups(false).zip(idx.iter_groups(false));
for (s, idx) in iter {
match (s, idx) {
(Some(s), Some(idx)) => {
let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?;
let out = s.as_ref().take(&idx)?;
builder.append_series(&out)?;
},
_ => builder.append_null(),
};
}
let out = builder.finish().into_series();
ac.with_agg_state(AggState::AggregatedList(out));
let iter = ac.iter_groups(false).zip(idx.iter_groups(false));
for (s, idx) in iter {
match (s, idx) {
(Some(s), Some(idx)) => {
let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?;
let out = s.as_ref().take(&idx)?;
builder.append_series(&out)?;
},
_ => builder.append_null(),
};
}
let out = builder.finish().into_series();
ac.with_agg_state(AggState::AggregatedList(out));
Ok(ac)
}
}
5 changes: 1 addition & 4 deletions crates/polars-expr/src/expressions/group_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ use polars_core::series::amortized_iter::AmortSeries;
use super::*;

impl<'a> AggregationContext<'a> {
/// # Safety
/// The lifetime of [AmortSeries] is bound to the iterator. Keeping it alive
/// longer than the iterator is UB.
pub(super) unsafe fn iter_groups(
pub(super) fn iter_groups(
&mut self,
keep_names: bool,
) -> Box<dyn Iterator<Item = Option<AmortSeries>> + '_> {
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ impl<'a> AggregationContext<'a> {
self.groups();
let rows = self.groups.len();
let s = s.new_from_index(0, rows);
s.reshape_list(&[rows as i64, -1]).unwrap()
let out = s.reshape_list(&[rows as i64, -1]).unwrap();
self.state = AggState::AggregatedList(out.clone());
out
},
}
}
Expand Down
37 changes: 17 additions & 20 deletions crates/polars-expr/src/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,23 @@ fn finish_as_iters<'a>(
mut ac_falsy: AggregationContext<'a>,
mut ac_mask: AggregationContext<'a>,
) -> PolarsResult<AggregationContext<'a>> {
// SAFETY: unstable series never lives longer than the iterator.
let ca = unsafe {
ac_truthy
.iter_groups(false)
.zip(ac_falsy.iter_groups(false))
.zip(ac_mask.iter_groups(false))
.map(|((truthy, falsy), mask)| {
match (truthy, falsy, mask) {
(Some(truthy), Some(falsy), Some(mask)) => Some(
truthy
.as_ref()
.zip_with(mask.as_ref().bool()?, falsy.as_ref()),
),
_ => None,
}
.transpose()
})
.collect::<PolarsResult<ListChunked>>()?
.with_name(ac_truthy.series().name())
};
let ca = ac_truthy
.iter_groups(false)
.zip(ac_falsy.iter_groups(false))
.zip(ac_mask.iter_groups(false))
.map(|((truthy, falsy), mask)| {
match (truthy, falsy, mask) {
(Some(truthy), Some(falsy), Some(mask)) => Some(
truthy
.as_ref()
.zip_with(mask.as_ref().bool()?, falsy.as_ref()),
),
_ => None,
}
.transpose()
})
.collect::<PolarsResult<ListChunked>>()?
.with_name(ac_truthy.series().name());

// Aggregation leaves only a single chunk.
let arr = ca.downcast_iter().next().unwrap();
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/operations/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,16 @@ def test_filter_group_aware_17030() -> None:
(group_count > 2) & (group_cum_count > 1) & (group_cum_count < group_count)
)
assert df.filter(filter_expr)["foo"].to_list() == ["1", "2"]


def test_invalid_filter_18295() -> None:
codes = ["a"] * 5 + ["b"] * 5
values = list(range(-2, 3)) + list(range(2, -3, -1))
df = pl.DataFrame({"code": codes, "value": values})
with pytest.raises(pl.exceptions.ShapeError):
df.group_by("code").agg(
pl.col("value")
.ewm_mean(span=2, ignore_nulls=True)
.tail(3)
.filter(pl.col("value") > 0),
).sort("code")