feat: Improve GroupsProxy/GroupsPosition to be sliceable and cheaply cloneable (#20673)
ritchie46 authored Jan 12, 2025
1 parent c82fdd4 commit 7bca692
Showing 65 changed files with 677 additions and 615 deletions.
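Most of the hunks below are a mechanical rename of `GroupsProxy` to `GroupsType`, with call sites now receiving groups through a `GroupPositions` handle (see the `is_unique_helper` hunk) that can be sliced and cloned cheaply. For orientation, the sketch below is a minimal stand-in for the two group layouts these hunks match on; `IdxSize`, the `IdxVec` alias, and the exact field set are assumptions inferred from the diff, not the real Polars definitions.

```rust
// Minimal stand-ins for the two group layouts matched throughout this diff.
// Illustrative only; the real types live in polars-core.
type IdxSize = u32;
type IdxVec = Vec<IdxSize>;

#[allow(dead_code)]
enum GroupsType {
    /// Each group stores its first row index plus all member row indices.
    Idx(Vec<(IdxSize, IdxVec)>),
    /// Each group is a contiguous run of rows, stored as `[first, len]`.
    Slice {
        groups: Vec<[IdxSize; 2]>,
        rolling: bool,
    },
}

impl GroupsType {
    fn len(&self) -> usize {
        match self {
            GroupsType::Idx(groups) => groups.len(),
            GroupsType::Slice { groups, .. } => groups.len(),
        }
    }
}

fn main() {
    // The same grouping (rows 0..4 in one group) expressed both ways.
    let by_index = GroupsType::Idx(vec![(0, vec![0, 1, 2, 3])]);
    let by_slice = GroupsType::Slice { groups: vec![[0, 4]], rolling: false };
    assert_eq!(by_index.len(), by_slice.len());
}
```

The `Slice` form is what keeps whole-column operations cheap: a single `[0, len]` entry covers the full series, as the `builder.rs` hunk directly below shows.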
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/object/builder.rs
@@ -167,7 +167,7 @@ pub(crate) fn object_series_to_arrow_array(s: &Series) -> ArrayRef {

// SAFETY: 0..len is in bounds
let list_s = unsafe {
s.agg_list(&GroupsProxy::Slice {
s.agg_list(&GroupsType::Slice {
groups: vec![[0, s.len() as IdxSize]],
rolling: false,
})
4 changes: 2 additions & 2 deletions crates/polars-core/src/chunked_array/object/extension/mod.rs
@@ -224,7 +224,7 @@ mod test {
let ca = ObjectChunked::new(PlSmallStr::EMPTY, values);

let groups =
GroupsProxy::Idx(vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into());
GroupsType::Idx(vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into());
let out = unsafe { ca.agg_list(&groups) };
assert!(matches!(out.dtype(), DataType::List(_)));
assert_eq!(out.len(), groups.len());
@@ -248,7 +248,7 @@ mod test {
let ca = ObjectChunked::new(PlSmallStr::EMPTY, values);

let groups = vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into();
let out = unsafe { ca.agg_list(&GroupsProxy::Idx(groups)) };
let out = unsafe { ca.agg_list(&GroupsType::Idx(groups)) };
let a = out.explode().unwrap();

let ca_foo = a.as_any().downcast_ref::<ObjectChunked<Foo>>().unwrap();
15 changes: 8 additions & 7 deletions crates/polars-core/src/chunked_array/ops/unique/mod.rs
@@ -1,4 +1,5 @@
use std::hash::Hash;
use std::ops::Deref;

use arrow::bitmap::MutableBitmap;
use polars_compute::unique::BooleanUniqueKernelState;
@@ -25,21 +26,21 @@ fn finish_is_unique_helper(
}

pub(crate) fn is_unique_helper(
groups: GroupsProxy,
groups: &GroupPositions,
len: IdxSize,
unique_val: bool,
duplicated_val: bool,
) -> BooleanChunked {
debug_assert_ne!(unique_val, duplicated_val);

let idx = match groups {
GroupsProxy::Idx(groups) => groups
.into_iter()
let idx = match groups.deref() {
GroupsType::Idx(groups) => groups
.iter()
.filter_map(|(first, g)| if g.len() == 1 { Some(first) } else { None })
.collect::<Vec<_>>(),
GroupsProxy::Slice { groups, .. } => groups
.into_iter()
.filter_map(|[first, len]| if len == 1 { Some(first) } else { None })
GroupsType::Slice { groups, .. } => groups
.iter()
.filter_map(|[first, len]| if *len == 1 { Some(*first) } else { None })
.collect(),
};
finish_is_unique_helper(idx, len, unique_val, duplicated_val)
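The hunk above now borrows the groups through `GroupPositions` (hence the added `Deref` import and `groups.deref()`) instead of consuming a `GroupsProxy` by value. The underlying idea is unchanged: a row is unique exactly when it is the sole member of its group. Below is a hedged, self-contained sketch of that logic using the simplified stand-in types from the earlier sketch; the real helper builds a `BooleanChunked` via `finish_is_unique_helper`.

```rust
// Hedged sketch of the singleton-group idea behind `is_unique_helper`,
// over simplified stand-in types rather than the real Polars machinery.
type IdxSize = u32;

#[allow(dead_code)]
enum GroupsType {
    Idx(Vec<(IdxSize, Vec<IdxSize>)>),
    Slice { groups: Vec<[IdxSize; 2]>, rolling: bool },
}

/// A row is "unique" when it is the only member of its group.
fn is_unique_mask(groups: &GroupsType, len: usize) -> Vec<bool> {
    let singletons: Vec<IdxSize> = match groups {
        GroupsType::Idx(groups) => groups
            .iter()
            .filter_map(|(first, g)| if g.len() == 1 { Some(*first) } else { None })
            .collect(),
        GroupsType::Slice { groups, .. } => groups
            .iter()
            .filter_map(|[first, len]| if *len == 1 { Some(*first) } else { None })
            .collect(),
    };
    let mut mask = vec![false; len];
    for idx in singletons {
        mask[idx as usize] = true;
    }
    mask
}

fn main() {
    // Rows 0 and 1 share a group; rows 2 and 3 are singletons.
    let groups = GroupsType::Idx(vec![(0, vec![0, 1]), (2, vec![2]), (3, vec![3])]);
    assert_eq!(is_unique_mask(&groups, 4), vec![false, false, true, true]);
}
```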
38 changes: 19 additions & 19 deletions crates/polars-core/src/frame/column/mod.rs
@@ -608,8 +608,8 @@ impl Column {
#[cfg(any(feature = "algorithm_group_by", feature = "bitwise"))]
fn agg_with_unit_scalar(
&self,
groups: &GroupsProxy,
series_agg: impl Fn(&Series, &GroupsProxy) -> Series,
groups: &GroupsType,
series_agg: impl Fn(&Series, &GroupsType) -> Series,
) -> Column {
match self {
Column::Series(s) => series_agg(s, groups).into_column(),
@@ -625,7 +625,7 @@ impl Column {
// 2. whether this aggregation is even defined
let series_aggregation = series_agg(
&s.as_single_value_series(),
&GroupsProxy::Slice {
&GroupsType::Slice {
// @NOTE: this group is always valid since s is non-empty.
groups: vec![[0, 1]],
rolling: false,
@@ -682,31 +682,31 @@ impl Column {
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_min(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_min(&self, groups: &GroupsType) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_min(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_max(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_max(&self, groups: &GroupsType) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_max(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_mean(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_sum(&self, groups: &GroupsType) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_sum(groups) }.into()
}
@@ -715,23 +715,23 @@ impl Column {
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_first(&self, groups: &GroupsType) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_first(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_last(&self, groups: &GroupsType) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_last(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_n_unique(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_n_unique(groups) }.into()
}
@@ -742,7 +742,7 @@
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_quantile(
&self,
groups: &GroupsProxy,
groups: &GroupsType,
quantile: f64,
method: QuantileMethod,
) -> Self {
@@ -758,15 +758,15 @@
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_median(&self, groups: &GroupsType) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_median(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Self {
pub unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_var(groups, ddof) }.into()
}
@@ -775,7 +775,7 @@
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Self {
pub unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_std(groups, ddof) }.into()
}
@@ -784,7 +784,7 @@
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_list(&self, groups: &GroupsProxy) -> Self {
pub unsafe fn agg_list(&self, groups: &GroupsType) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_list(groups) }.into()
}
@@ -793,7 +793,7 @@
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub fn agg_valid_count(&self, groups: &GroupsProxy) -> Self {
pub fn agg_valid_count(&self, groups: &GroupsType) -> Self {
// @partition-opt
// @scalar-opt
unsafe { self.as_materialized_series().agg_valid_count(groups) }.into()
@@ -803,21 +803,21 @@
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_and(&self, groups: &GroupsProxy) -> Self {
pub fn agg_and(&self, groups: &GroupsType) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_and(g) })
}
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_or(&self, groups: &GroupsProxy) -> Self {
pub fn agg_or(&self, groups: &GroupsType) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_or(g) })
}
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_xor(&self, groups: &GroupsProxy) -> Self {
pub fn agg_xor(&self, groups: &GroupsType) -> Self {
// @partition-opt
// @scalar-opt
unsafe { self.as_materialized_series().agg_xor(groups) }.into()
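Several of the `Column::agg_*` methods in this file route through `agg_with_unit_scalar`: a full `Series` column aggregates per group as usual, while a scalar (broadcast) column aggregates its single-value series once over a unit `Slice` group `[0, 1]`. The sketch below illustrates the idea with simplified stand-in types and a min aggregation; the names and the broadcast-to-every-group step are assumptions about the intent, not the Polars implementation.

```rust
// Hedged sketch of the `agg_with_unit_scalar` fast path: a scalar column is
// aggregated once over a unit group and the result reused per group.
type IdxSize = u32;

#[allow(dead_code)]
enum GroupsType {
    Idx(Vec<(IdxSize, Vec<IdxSize>)>),
    Slice { groups: Vec<[IdxSize; 2]>, rolling: bool },
}

impl GroupsType {
    fn len(&self) -> usize {
        match self {
            GroupsType::Idx(g) => g.len(),
            GroupsType::Slice { groups, .. } => groups.len(),
        }
    }
}

enum Column {
    Full(Vec<f64>),
    /// One logical value broadcast over the column's length.
    Scalar(f64, usize),
}

/// Per-group minimum over a plain buffer of values.
fn min_per_group(values: &[f64], groups: &GroupsType) -> Vec<f64> {
    match groups {
        GroupsType::Idx(groups) => groups
            .iter()
            .map(|(_, idx)| {
                idx.iter().map(|i| values[*i as usize]).fold(f64::INFINITY, f64::min)
            })
            .collect(),
        GroupsType::Slice { groups, .. } => groups
            .iter()
            .map(|[first, len]| {
                values[*first as usize..(*first + *len) as usize]
                    .iter()
                    .copied()
                    .fold(f64::INFINITY, f64::min)
            })
            .collect(),
    }
}

fn agg_min(col: &Column, groups: &GroupsType) -> Vec<f64> {
    match col {
        Column::Full(values) => min_per_group(values, groups),
        Column::Scalar(v, _) => {
            // Aggregate the single value over a unit [0, 1] group, then
            // reuse that one result for every group.
            let unit = GroupsType::Slice { groups: vec![[0, 1]], rolling: false };
            let one = min_per_group(&[*v], &unit)[0];
            vec![one; groups.len()]
        },
    }
}

fn main() {
    let groups = GroupsType::Idx(vec![(0, vec![0, 1]), (2, vec![2, 3])]);
    assert_eq!(agg_min(&Column::Full(vec![3.0, 1.0, 4.0, 2.0]), &groups), vec![1.0, 2.0]);
    assert_eq!(agg_min(&Column::Scalar(5.0, 4), &groups), vec![5.0, 5.0]);
}
```

Methods that keep the `// @scalar-opt` marker (`agg_sum`, `agg_var`, `agg_std`, `agg_list`, and others) still materialize the series first; for `agg_sum` in particular a scalar's result depends on group length, so the unit-group shortcut above would not apply as-is.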
30 changes: 15 additions & 15 deletions crates/polars-core/src/frame/group_by/aggregations/agg_list.rs
@@ -8,19 +8,19 @@ pub trait AggList {
/// # Safety
///
/// groups should be in bounds
unsafe fn agg_list(&self, _groups: &GroupsProxy) -> Series;
unsafe fn agg_list(&self, _groups: &GroupsType) -> Series;
}

impl<T> AggList for ChunkedArray<T>
where
T: PolarsNumericType,
ChunkedArray<T>: IntoSeries,
{
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
let ca = self.rechunk();

match groups {
GroupsProxy::Idx(groups) => {
GroupsType::Idx(groups) => {
let mut can_fast_explode = true;

let arr = ca.downcast_iter().next().unwrap();
@@ -92,7 +92,7 @@ where
}
ca.into()
},
GroupsProxy::Slice { groups, .. } => {
GroupsType::Slice { groups, .. } => {
let mut can_fast_explode = true;
let arr = ca.downcast_iter().next().unwrap();
let values = arr.values();
@@ -159,16 +159,16 @@ where
}

impl AggList for NullChunked {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
match groups {
GroupsProxy::Idx(groups) => {
GroupsType::Idx(groups) => {
let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len());
for idx in groups.all().iter() {
builder.append_with_len(idx.len());
}
builder.finish().into_series()
},
GroupsProxy::Slice { groups, .. } => {
GroupsType::Slice { groups, .. } => {
let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len());
for [_, len] in groups {
builder.append_with_len(*len as usize);
@@ -180,39 +180,39 @@ impl AggList for NullChunked {
}

impl AggList for BooleanChunked {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
agg_list_by_gather_and_offsets(self, groups)
}
}

impl AggList for StringChunked {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
agg_list_by_gather_and_offsets(self, groups)
}
}

impl AggList for BinaryChunked {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
agg_list_by_gather_and_offsets(self, groups)
}
}

impl AggList for ListChunked {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
agg_list_by_gather_and_offsets(self, groups)
}
}

#[cfg(feature = "dtype-array")]
impl AggList for ArrayChunked {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
agg_list_by_gather_and_offsets(self, groups)
}
}

#[cfg(feature = "object")]
impl<T: PolarsObject> AggList for ObjectChunked<T> {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
let mut can_fast_explode = true;
let mut offsets = Vec::<i64>::with_capacity(groups.len() + 1);
let mut length_so_far = 0i64;
@@ -279,7 +279,7 @@

#[cfg(feature = "dtype-struct")]
impl AggList for StructChunked {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
let ca = self.clone();
let (gather, offsets, can_fast_explode) = groups.prepare_list_agg(self.len());

Expand Down Expand Up @@ -308,7 +308,7 @@ impl AggList for StructChunked {

unsafe fn agg_list_by_gather_and_offsets<T: PolarsDataType>(
ca: &ChunkedArray<T>,
groups: &GroupsProxy,
groups: &GroupsType,
) -> Series
where
ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
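The `ObjectChunked` and `StructChunked` implementations above, together with `agg_list_by_gather_and_offsets`, reduce list aggregation to the same two ingredients: a gather list concatenating every group's row indices, and an offsets array marking where each group's sublist starts and ends, plus a `can_fast_explode` flag that is cleared when a group is empty. The sketch below reconstructs that preparation step over the simplified stand-in types; the real `prepare_list_agg` on the groups type differs in signature and detail.

```rust
// Hedged sketch of the gather-and-offsets preparation used for list
// aggregation in this file; the types and the function name are stand-ins.
type IdxSize = u32;

#[allow(dead_code)]
enum GroupsType {
    Idx(Vec<(IdxSize, Vec<IdxSize>)>),
    Slice { groups: Vec<[IdxSize; 2]>, rolling: bool },
}

/// Returns (gather indices, list offsets, can_fast_explode).
fn prepare_list_agg(groups: &GroupsType) -> (Vec<IdxSize>, Vec<i64>, bool) {
    let mut gather = Vec::new();
    let mut offsets: Vec<i64> = vec![0];
    let mut length_so_far = 0i64;
    let mut can_fast_explode = true;
    match groups {
        GroupsType::Idx(groups) => {
            for (_, idx) in groups {
                if idx.is_empty() {
                    // An empty group produces an empty sublist, which blocks
                    // the fast-explode optimization.
                    can_fast_explode = false;
                }
                gather.extend_from_slice(idx);
                length_so_far += idx.len() as i64;
                offsets.push(length_so_far);
            }
        },
        GroupsType::Slice { groups, .. } => {
            for [first, len] in groups {
                if *len == 0 {
                    can_fast_explode = false;
                }
                gather.extend(*first..*first + *len);
                length_so_far += *len as i64;
                offsets.push(length_so_far);
            }
        },
    }
    (gather, offsets, can_fast_explode)
}

fn main() {
    let groups = GroupsType::Slice { groups: vec![[0, 2], [2, 1], [3, 1]], rolling: false };
    let (gather, offsets, fast) = prepare_list_agg(&groups);
    assert_eq!(gather, vec![0, 1, 2, 3]);
    assert_eq!(offsets, vec![0, 2, 3, 4]);
    assert!(fast);
}
```

The offsets then delimit each group's sublist inside the flattened, group-ordered values, which is what the offsets bookkeeping and list builders in the hunks above write out as Arrow buffers.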