From 7bca6922f5d0be44bee4b460e01a0d7ca489368e Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 12 Jan 2025 16:40:11 +0100 Subject: [PATCH] feat: Improve `GroupsProxy/GroupsPosition` to be sliceable and cheaply cloneable (#20673) --- .../src/chunked_array/object/builder.rs | 2 +- .../src/chunked_array/object/extension/mod.rs | 4 +- .../src/chunked_array/ops/unique/mod.rs | 15 +- crates/polars-core/src/frame/column/mod.rs | 38 +-- .../frame/group_by/aggregations/agg_list.rs | 30 +- .../frame/group_by/aggregations/boolean.rs | 26 +- .../frame/group_by/aggregations/dispatch.rs | 48 ++- .../src/frame/group_by/aggregations/mod.rs | 90 +++--- .../src/frame/group_by/aggregations/string.rs | 16 +- crates/polars-core/src/frame/group_by/expr.rs | 2 +- .../polars-core/src/frame/group_by/hashing.rs | 16 +- .../src/frame/group_by/into_groups.rs | 48 ++- crates/polars-core/src/frame/group_by/mod.rs | 34 +- .../polars-core/src/frame/group_by/perfect.rs | 8 +- .../frame/group_by/{proxy.rs => position.rs} | 290 ++++++++++-------- crates/polars-core/src/frame/mod.rs | 14 +- .../src/series/implementations/array.rs | 6 +- .../src/series/implementations/binary.rs | 10 +- .../series/implementations/binary_offset.rs | 4 +- .../src/series/implementations/boolean.rs | 22 +- .../src/series/implementations/categorical.rs | 4 +- .../src/series/implementations/date.rs | 8 +- .../src/series/implementations/datetime.rs | 8 +- .../src/series/implementations/decimal.rs | 10 +- .../src/series/implementations/duration.rs | 14 +- .../src/series/implementations/floats.rs | 22 +- .../src/series/implementations/list.rs | 6 +- .../src/series/implementations/mod.rs | 22 +- .../src/series/implementations/null.rs | 8 +- .../src/series/implementations/object.rs | 6 +- .../src/series/implementations/string.rs | 10 +- .../src/series/implementations/struct_.rs | 4 +- .../src/series/implementations/time.rs | 8 +- crates/polars-core/src/series/series_trait.rs | 20 +- .../src/expressions/aggregation.rs | 20 +- crates/polars-expr/src/expressions/alias.rs | 6 +- crates/polars-expr/src/expressions/apply.rs | 8 +- crates/polars-expr/src/expressions/binary.rs | 6 +- crates/polars-expr/src/expressions/cast.rs | 6 +- crates/polars-expr/src/expressions/column.rs | 6 +- crates/polars-expr/src/expressions/count.rs | 6 +- crates/polars-expr/src/expressions/filter.rs | 17 +- crates/polars-expr/src/expressions/gather.rs | 16 +- crates/polars-expr/src/expressions/literal.rs | 6 +- crates/polars-expr/src/expressions/mod.rs | 66 ++-- crates/polars-expr/src/expressions/rolling.rs | 2 +- crates/polars-expr/src/expressions/slice.rs | 47 +-- crates/polars-expr/src/expressions/sort.rs | 12 +- crates/polars-expr/src/expressions/sortby.rs | 14 +- crates/polars-expr/src/expressions/ternary.rs | 6 +- crates/polars-expr/src/expressions/window.rs | 37 +-- .../polars-expr/src/state/execution_state.rs | 4 +- crates/polars-io/src/partition.rs | 6 +- crates/polars-lazy/src/dsl/list.rs | 13 +- crates/polars-lazy/src/frame/pivot.rs | 2 +- crates/polars-lazy/src/tests/aggregations.rs | 2 +- .../src/executors/group_by.rs | 4 +- .../src/executors/group_by_dynamic.rs | 2 +- .../src/executors/group_by_partitioned.rs | 4 +- .../src/executors/group_by_rolling.rs | 8 +- crates/polars-ops/src/chunked_array/mode.rs | 8 +- .../nan_propagating_aggregate.rs | 16 +- .../polars-ops/src/frame/pivot/positioning.rs | 4 +- .../src/executors/sinks/sort/ooc.rs | 4 +- crates/polars-time/src/group_by/dynamic.rs | 61 ++-- 65 files changed, 677 insertions(+), 615 deletions(-) rename crates/polars-core/src/frame/group_by/{proxy.rs => position.rs} (71%) diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index c8fe2bb4a2ff..ef0db377a1a2 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -167,7 +167,7 @@ pub(crate) fn object_series_to_arrow_array(s: &Series) -> ArrayRef { // SAFETY: 0..len is in bounds let list_s = unsafe { - s.agg_list(&GroupsProxy::Slice { + s.agg_list(&GroupsType::Slice { groups: vec![[0, s.len() as IdxSize]], rolling: false, }) diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs index d1428424969f..e83e4b05e1eb 100644 --- a/crates/polars-core/src/chunked_array/object/extension/mod.rs +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -224,7 +224,7 @@ mod test { let ca = ObjectChunked::new(PlSmallStr::EMPTY, values); let groups = - GroupsProxy::Idx(vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into()); + GroupsType::Idx(vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into()); let out = unsafe { ca.agg_list(&groups) }; assert!(matches!(out.dtype(), DataType::List(_))); assert_eq!(out.len(), groups.len()); @@ -248,7 +248,7 @@ mod test { let ca = ObjectChunked::new(PlSmallStr::EMPTY, values); let groups = vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into(); - let out = unsafe { ca.agg_list(&GroupsProxy::Idx(groups)) }; + let out = unsafe { ca.agg_list(&GroupsType::Idx(groups)) }; let a = out.explode().unwrap(); let ca_foo = a.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs index f8ae3d78cfc7..11b06af190e6 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -1,4 +1,5 @@ use std::hash::Hash; +use std::ops::Deref; use arrow::bitmap::MutableBitmap; use polars_compute::unique::BooleanUniqueKernelState; @@ -25,21 +26,21 @@ fn finish_is_unique_helper( } pub(crate) fn is_unique_helper( - groups: GroupsProxy, + groups: &GroupPositions, len: IdxSize, unique_val: bool, duplicated_val: bool, ) -> BooleanChunked { debug_assert_ne!(unique_val, duplicated_val); - let idx = match groups { - GroupsProxy::Idx(groups) => groups - .into_iter() + let idx = match groups.deref() { + GroupsType::Idx(groups) => groups + .iter() .filter_map(|(first, g)| if g.len() == 1 { Some(first) } else { None }) .collect::>(), - GroupsProxy::Slice { groups, .. } => groups - .into_iter() - .filter_map(|[first, len]| if len == 1 { Some(first) } else { None }) + GroupsType::Slice { groups, .. } => groups + .iter() + .filter_map(|[first, len]| if *len == 1 { Some(*first) } else { None }) .collect(), }; finish_is_unique_helper(idx, len, unique_val, duplicated_val) diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 032b748a74c8..56d91073536b 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -608,8 +608,8 @@ impl Column { #[cfg(any(feature = "algorithm_group_by", feature = "bitwise"))] fn agg_with_unit_scalar( &self, - groups: &GroupsProxy, - series_agg: impl Fn(&Series, &GroupsProxy) -> Series, + groups: &GroupsType, + series_agg: impl Fn(&Series, &GroupsType) -> Series, ) -> Column { match self { Column::Series(s) => series_agg(s, groups).into_column(), @@ -625,7 +625,7 @@ impl Column { // 2. whether this aggregation is even defined let series_aggregation = series_agg( &s.as_single_value_series(), - &GroupsProxy::Slice { + &GroupsType::Slice { // @NOTE: this group is always valid since s is non-empty. groups: vec![[0, 1]], rolling: false, @@ -682,7 +682,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_min(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_min(&self, groups: &GroupsType) -> Self { self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_min(g) }) } @@ -690,7 +690,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_max(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_max(&self, groups: &GroupsType) -> Self { self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_max(g) }) } @@ -698,7 +698,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Self { self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_mean(g) }) } @@ -706,7 +706,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_sum(&self, groups: &GroupsType) -> Self { // @scalar-opt unsafe { self.as_materialized_series().agg_sum(groups) }.into() } @@ -715,7 +715,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_first(&self, groups: &GroupsType) -> Self { self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_first(g) }) } @@ -723,7 +723,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_last(&self, groups: &GroupsType) -> Self { self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_last(g) }) } @@ -731,7 +731,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_n_unique(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Self { // @scalar-opt unsafe { self.as_materialized_series().agg_n_unique(groups) }.into() } @@ -742,7 +742,7 @@ impl Column { #[cfg(feature = "algorithm_group_by")] pub unsafe fn agg_quantile( &self, - groups: &GroupsProxy, + groups: &GroupsType, quantile: f64, method: QuantileMethod, ) -> Self { @@ -758,7 +758,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_median(&self, groups: &GroupsType) -> Self { self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_median(g) }) } @@ -766,7 +766,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Self { + pub unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Self { // @scalar-opt unsafe { self.as_materialized_series().agg_var(groups, ddof) }.into() } @@ -775,7 +775,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Self { + pub unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Self { // @scalar-opt unsafe { self.as_materialized_series().agg_std(groups, ddof) }.into() } @@ -784,7 +784,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub unsafe fn agg_list(&self, groups: &GroupsProxy) -> Self { + pub unsafe fn agg_list(&self, groups: &GroupsType) -> Self { // @scalar-opt unsafe { self.as_materialized_series().agg_list(groups) }.into() } @@ -793,7 +793,7 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - pub fn agg_valid_count(&self, groups: &GroupsProxy) -> Self { + pub fn agg_valid_count(&self, groups: &GroupsType) -> Self { // @partition-opt // @scalar-opt unsafe { self.as_materialized_series().agg_valid_count(groups) }.into() @@ -803,21 +803,21 @@ impl Column { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "bitwise")] - pub fn agg_and(&self, groups: &GroupsProxy) -> Self { + pub fn agg_and(&self, groups: &GroupsType) -> Self { self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_and(g) }) } /// # Safety /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "bitwise")] - pub fn agg_or(&self, groups: &GroupsProxy) -> Self { + pub fn agg_or(&self, groups: &GroupsType) -> Self { self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_or(g) }) } /// # Safety /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "bitwise")] - pub fn agg_xor(&self, groups: &GroupsProxy) -> Self { + pub fn agg_xor(&self, groups: &GroupsType) -> Self { // @partition-opt // @scalar-opt unsafe { self.as_materialized_series().agg_xor(groups) }.into() diff --git a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs index 3e71953c5753..993db687e7ea 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs @@ -8,7 +8,7 @@ pub trait AggList { /// # Safety /// /// groups should be in bounds - unsafe fn agg_list(&self, _groups: &GroupsProxy) -> Series; + unsafe fn agg_list(&self, _groups: &GroupsType) -> Series; } impl AggList for ChunkedArray @@ -16,11 +16,11 @@ where T: PolarsNumericType, ChunkedArray: IntoSeries, { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { let ca = self.rechunk(); match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let mut can_fast_explode = true; let arr = ca.downcast_iter().next().unwrap(); @@ -92,7 +92,7 @@ where } ca.into() }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let mut can_fast_explode = true; let arr = ca.downcast_iter().next().unwrap(); let values = arr.values(); @@ -159,16 +159,16 @@ where } impl AggList for NullChunked { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len()); for idx in groups.all().iter() { builder.append_with_len(idx.len()); } builder.finish().into_series() }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len()); for [_, len] in groups { builder.append_with_len(*len as usize); @@ -180,39 +180,39 @@ impl AggList for NullChunked { } impl AggList for BooleanChunked { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { agg_list_by_gather_and_offsets(self, groups) } } impl AggList for StringChunked { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { agg_list_by_gather_and_offsets(self, groups) } } impl AggList for BinaryChunked { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { agg_list_by_gather_and_offsets(self, groups) } } impl AggList for ListChunked { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { agg_list_by_gather_and_offsets(self, groups) } } #[cfg(feature = "dtype-array")] impl AggList for ArrayChunked { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { agg_list_by_gather_and_offsets(self, groups) } } #[cfg(feature = "object")] impl AggList for ObjectChunked { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { let mut can_fast_explode = true; let mut offsets = Vec::::with_capacity(groups.len() + 1); let mut length_so_far = 0i64; @@ -279,7 +279,7 @@ impl AggList for ObjectChunked { #[cfg(feature = "dtype-struct")] impl AggList for StructChunked { - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { let ca = self.clone(); let (gather, offsets, can_fast_explode) = groups.prepare_list_agg(self.len()); @@ -308,7 +308,7 @@ impl AggList for StructChunked { unsafe fn agg_list_by_gather_and_offsets( ca: &ChunkedArray, - groups: &GroupsProxy, + groups: &GroupsType, ) -> Series where ChunkedArray: ChunkTakeUnchecked, diff --git a/crates/polars-core/src/frame/group_by/aggregations/boolean.rs b/crates/polars-core/src/frame/group_by/aggregations/boolean.rs index 36cd8e9a8d41..b4c0b585050f 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/boolean.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/boolean.rs @@ -20,7 +20,7 @@ where #[cfg(feature = "bitwise")] unsafe fn bitwise_agg( ca: &BooleanChunked, - groups: &GroupsProxy, + groups: &GroupsType, f: fn(&BooleanChunked) -> Option, ) -> Series { // Prevent a rechunk for every individual group. @@ -31,7 +31,7 @@ unsafe fn bitwise_agg( }; match groups { - GroupsProxy::Idx(groups) => _agg_helper_idx_bool::<_>(groups, |(_, idx)| { + GroupsType::Idx(groups) => _agg_helper_idx_bool::<_>(groups, |(_, idx)| { debug_assert!(idx.len() <= s.len()); if idx.is_empty() { None @@ -40,7 +40,7 @@ unsafe fn bitwise_agg( f(&take) } }), - GroupsProxy::Slice { groups, .. } => _agg_helper_slice_bool::<_>(groups, |[first, len]| { + GroupsType::Slice { groups, .. } => _agg_helper_slice_bool::<_>(groups, |[first, len]| { debug_assert!(len <= s.len() as IdxSize); if len == 0 { None @@ -54,21 +54,21 @@ unsafe fn bitwise_agg( #[cfg(feature = "bitwise")] impl BooleanChunked { - pub(crate) unsafe fn agg_and(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) } - pub(crate) unsafe fn agg_or(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) } - pub(crate) unsafe fn agg_xor(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) } } impl BooleanChunked { - pub(crate) unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series { // faster paths match (self.is_sorted_flag(), self.null_count()) { (IsSorted::Ascending, 0) => { @@ -83,7 +83,7 @@ impl BooleanChunked { let arr = ca_self.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; match groups { - GroupsProxy::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| { + GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| { debug_assert!(idx.len() <= self.len()); if idx.is_empty() { None @@ -95,7 +95,7 @@ impl BooleanChunked { take_min_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize) } }), - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => _agg_helper_slice_bool(groups_slice, |[first, len]| { @@ -111,7 +111,7 @@ impl BooleanChunked { }), } } - pub(crate) unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series { // faster paths match (self.is_sorted_flag(), self.null_count()) { (IsSorted::Ascending, 0) => { @@ -127,7 +127,7 @@ impl BooleanChunked { let arr = ca_self.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; match groups { - GroupsProxy::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| { + GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| { debug_assert!(idx.len() <= self.len()); if idx.is_empty() { None @@ -139,7 +139,7 @@ impl BooleanChunked { take_max_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize) } }), - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => _agg_helper_slice_bool(groups_slice, |[first, len]| { @@ -155,7 +155,7 @@ impl BooleanChunked { }), } } - pub(crate) unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing) .unwrap() .agg_sum(groups) diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index a5e5ee199e67..8c2f1cd4afae 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -15,7 +15,7 @@ impl Series { } #[doc(hidden)] - pub unsafe fn agg_valid_count(&self, groups: &GroupsProxy) -> Series { + pub unsafe fn agg_valid_count(&self, groups: &GroupsType) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 && self.null_count() > 0 { self.rechunk() @@ -24,7 +24,7 @@ impl Series { }; match groups { - GroupsProxy::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { + GroupsType::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { debug_assert!(idx.len() <= s.len()); if idx.is_empty() { None @@ -35,7 +35,7 @@ impl Series { Some((take.len() - take.null_count()) as IdxSize) } }), - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { _agg_helper_slice::(groups, |[first, len]| { debug_assert!(len <= s.len() as IdxSize); if len == 0 { @@ -52,7 +52,7 @@ impl Series { } #[doc(hidden)] - pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Series { + pub unsafe fn agg_first(&self, groups: &GroupsType) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { self.rechunk() @@ -61,7 +61,7 @@ impl Series { }; let mut out = match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let indices = groups .iter() .map( @@ -77,7 +77,7 @@ impl Series { // SAFETY: groups are always in bounds. s.take_unchecked(&indices) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let indices = groups .iter() .map(|&[first, len]| if len == 0 { None } else { Some(first) }) @@ -93,7 +93,7 @@ impl Series { } #[doc(hidden)] - pub unsafe fn agg_n_unique(&self, groups: &GroupsProxy) -> Series { + pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { self.rechunk() @@ -102,18 +102,16 @@ impl Series { }; match groups { - GroupsProxy::Idx(groups) => { - agg_helper_idx_on_all_no_null::(groups, |idx| { - debug_assert!(idx.len() <= s.len()); - if idx.is_empty() { - 0 - } else { - let take = s.take_slice_unchecked(idx); - take.n_unique().unwrap() as IdxSize - } - }) - }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Idx(groups) => agg_helper_idx_on_all_no_null::(groups, |idx| { + debug_assert!(idx.len() <= s.len()); + if idx.is_empty() { + 0 + } else { + let take = s.take_slice_unchecked(idx); + take.n_unique().unwrap() as IdxSize + } + }), + GroupsType::Slice { groups, .. } => { _agg_helper_slice_no_null::(groups, |[first, len]| { debug_assert!(len <= s.len() as IdxSize); if len == 0 { @@ -128,7 +126,7 @@ impl Series { } #[doc(hidden)] - pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series { + pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { self.rechunk() @@ -180,7 +178,7 @@ impl Series { } #[doc(hidden)] - pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { + pub unsafe fn agg_median(&self, groups: &GroupsType) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { self.rechunk() @@ -236,7 +234,7 @@ impl Series { #[doc(hidden)] pub unsafe fn agg_quantile( &self, - groups: &GroupsProxy, + groups: &GroupsType, quantile: f64, method: QuantileMethod, ) -> Series { @@ -268,7 +266,7 @@ impl Series { } #[doc(hidden)] - pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Series { + pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { self.rechunk() @@ -277,7 +275,7 @@ impl Series { }; let out = match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let indices = groups .all() .iter() @@ -291,7 +289,7 @@ impl Series { .collect_ca(PlSmallStr::EMPTY); s.take_unchecked(&indices) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let indices = groups .iter() .map(|&[first, len]| { diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 5abdd67b6d5f..43f02ebb0dc2 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -333,7 +333,7 @@ impl QuantileDispatcher for Float64Chunked { unsafe fn agg_quantile_generic( ca: &ChunkedArray, - groups: &GroupsProxy, + groups: &GroupsType, quantile: f64, method: QuantileMethod, ) -> Series @@ -349,7 +349,7 @@ where return Series::full_null(ca.name().clone(), groups.len(), ca.dtype()); } match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca = ca.rechunk(); agg_helper_idx_on_all::(groups, |idx| { debug_assert!(idx.len() <= ca.len()); @@ -361,7 +361,7 @@ where take._quantile(quantile, method).unwrap_unchecked() }) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { if _use_rolling_kernels(groups, ca.chunks()) { // this cast is a no-op for floats let s = ca @@ -416,7 +416,7 @@ where } } -unsafe fn agg_median_generic(ca: &ChunkedArray, groups: &GroupsProxy) -> Series +unsafe fn agg_median_generic(ca: &ChunkedArray, groups: &GroupsType) -> Series where T: PolarsNumericType, ChunkedArray: QuantileDispatcher, @@ -425,7 +425,7 @@ where ::Native: num_traits::Float, { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca = ca.rechunk(); agg_helper_idx_on_all::(groups, |idx| { debug_assert!(idx.len() <= ca.len()); @@ -436,7 +436,7 @@ where take._median() }) }, - GroupsProxy::Slice { .. } => { + GroupsType::Slice { .. } => { agg_quantile_generic::(ca, groups, 0.5, QuantileMethod::Linear) }, } @@ -448,7 +448,7 @@ where #[cfg(feature = "bitwise")] unsafe fn bitwise_agg( ca: &ChunkedArray, - groups: &GroupsProxy, + groups: &GroupsType, f: fn(&ChunkedArray) -> Option, ) -> Series where @@ -463,7 +463,7 @@ where }; match groups { - GroupsProxy::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { + GroupsType::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { debug_assert!(idx.len() <= s.len()); if idx.is_empty() { None @@ -472,7 +472,7 @@ where f(&take) } }), - GroupsProxy::Slice { groups, .. } => _agg_helper_slice::(groups, |[first, len]| { + GroupsType::Slice { groups, .. } => _agg_helper_slice::(groups, |[first, len]| { debug_assert!(len <= s.len() as IdxSize); if len == 0 { None @@ -494,21 +494,21 @@ where /// # Safety /// /// No bounds checks on `groups`. - pub(crate) unsafe fn agg_and(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series { unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) } } /// # Safety /// /// No bounds checks on `groups`. - pub(crate) unsafe fn agg_or(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series { unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) } } /// # Safety /// /// No bounds checks on `groups`. - pub(crate) unsafe fn agg_xor(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) } } } @@ -526,7 +526,7 @@ where + TakeExtremum, ChunkedArray: IntoSeries + ChunkAgg, { - pub(crate) unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series { // faster paths match (self.is_sorted_flag(), self.null_count()) { (IsSorted::Ascending, 0) => { @@ -538,7 +538,7 @@ where _ => {}, } match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca = self.rechunk(); let arr = ca.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -559,7 +559,7 @@ where } }) }, - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => { @@ -599,7 +599,7 @@ where } } - pub(crate) unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series { // faster paths match (self.is_sorted_flag(), self.null_count()) { (IsSorted::Ascending, 0) => { @@ -612,7 +612,7 @@ where } match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca = self.rechunk(); let arr = ca.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -633,7 +633,7 @@ where } }) }, - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => { @@ -673,9 +673,9 @@ where } } - pub(crate) unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca = self.rechunk(); let arr = ca.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -694,7 +694,7 @@ where } }) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { if _use_rolling_kernels(groups, self.chunks()) { let arr = self.downcast_iter().next().unwrap(); let values = arr.values().as_slice(); @@ -743,9 +743,9 @@ where + ChunkAgg, T::Native: Pow, { - pub(crate) unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca = self.rechunk(); let arr = ca.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -786,7 +786,7 @@ where out.map(|flt| NumCast::from(flt).unwrap()) }) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { if _use_rolling_kernels(groups, self.chunks()) { let arr = self.downcast_iter().next().unwrap(); let values = arr.values().as_slice(); @@ -823,13 +823,13 @@ where } } - pub(crate) unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series + pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series where ::Native: num_traits::Float, { let ca = &self.0.rechunk(); match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -846,7 +846,7 @@ where out.map(|flt| NumCast::from(flt).unwrap()) }) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { if _use_rolling_kernels(groups, self.chunks()) { let arr = self.downcast_iter().next().unwrap(); let values = arr.values().as_slice(); @@ -889,13 +889,13 @@ where }, } } - pub(crate) unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series + pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series where ::Native: num_traits::Float, { let ca = &self.0.rechunk(); match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let arr = ca.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; agg_helper_idx_on_all::(groups, |idx| { @@ -911,7 +911,7 @@ where out.map(|flt| NumCast::from(flt.sqrt()).unwrap()) }) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { if _use_rolling_kernels(groups, self.chunks()) { let arr = ca.downcast_iter().next().unwrap(); let values = arr.values().as_slice(); @@ -962,26 +962,26 @@ where impl Float32Chunked { pub(crate) unsafe fn agg_quantile( &self, - groups: &GroupsProxy, + groups: &GroupsType, quantile: f64, method: QuantileMethod, ) -> Series { agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method) } - pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series { agg_median_generic::<_, Float32Type>(self, groups) } } impl Float64Chunked { pub(crate) unsafe fn agg_quantile( &self, - groups: &GroupsProxy, + groups: &GroupsType, quantile: f64, method: QuantileMethod, ) -> Series { agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } - pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series { agg_median_generic::<_, Float64Type>(self, groups) } } @@ -992,9 +992,9 @@ where ChunkedArray: IntoSeries + ChunkAgg + ChunkVar, T::Native: NumericNative + Ord, { - pub(crate) unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca = self.rechunk(); let arr = ca.downcast_get(0).unwrap(); _agg_helper_idx::(groups, |(first, idx)| { @@ -1041,7 +1041,7 @@ where } }) }, - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => { @@ -1067,9 +1067,9 @@ where } } - pub(crate) unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { + pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca_self = self.rechunk(); let arr = ca_self.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -1085,7 +1085,7 @@ where } }) }, - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => { @@ -1116,9 +1116,9 @@ where }, } } - pub(crate) unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { + pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca_self = self.rechunk(); let arr = ca_self.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -1135,7 +1135,7 @@ where out.map(|v| v.sqrt()) }) }, - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => { @@ -1169,13 +1169,13 @@ where pub(crate) unsafe fn agg_quantile( &self, - groups: &GroupsProxy, + groups: &GroupsType, quantile: f64, method: QuantileMethod, ) -> Series { agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } - pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series { agg_median_generic::<_, Float64Type>(self, groups) } } diff --git a/crates/polars-core/src/frame/group_by/aggregations/string.rs b/crates/polars-core/src/frame/group_by/aggregations/string.rs index d16e3aa950f3..dffb2f99975a 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/string.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/string.rs @@ -18,7 +18,7 @@ where impl BinaryChunked { #[allow(clippy::needless_lifetimes)] - pub(crate) unsafe fn agg_min<'a>(&'a self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_min<'a>(&'a self, groups: &GroupsType) -> Series { // faster paths match (&self.is_sorted_flag(), &self.null_count()) { (IsSorted::Ascending, 0) => { @@ -31,7 +31,7 @@ impl BinaryChunked { } match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca_self = self.rechunk(); let arr = ca_self.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -57,7 +57,7 @@ impl BinaryChunked { } }) }, - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => _agg_helper_slice_bin(groups_slice, |[first, len]| { @@ -80,7 +80,7 @@ impl BinaryChunked { } #[allow(clippy::needless_lifetimes)] - pub(crate) unsafe fn agg_max<'a>(&'a self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_max<'a>(&'a self, groups: &GroupsType) -> Series { // faster paths match (self.is_sorted_flag(), self.null_count()) { (IsSorted::Ascending, 0) => { @@ -93,7 +93,7 @@ impl BinaryChunked { } match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca_self = self.rechunk(); let arr = ca_self.downcast_iter().next().unwrap(); let no_nulls = arr.null_count() == 0; @@ -119,7 +119,7 @@ impl BinaryChunked { } }) }, - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => _agg_helper_slice_bin(groups_slice, |[first, len]| { @@ -144,13 +144,13 @@ impl BinaryChunked { impl StringChunked { #[allow(clippy::needless_lifetimes)] - pub(crate) unsafe fn agg_min<'a>(&'a self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_min<'a>(&'a self, groups: &GroupsType) -> Series { let out = self.as_binary().agg_min(groups); out.binary().unwrap().to_string_unchecked().into_series() } #[allow(clippy::needless_lifetimes)] - pub(crate) unsafe fn agg_max<'a>(&'a self, groups: &GroupsProxy) -> Series { + pub(crate) unsafe fn agg_max<'a>(&'a self, groups: &GroupsType) -> Series { let out = self.as_binary().agg_max(groups); out.binary().unwrap().to_string_unchecked().into_series() } diff --git a/crates/polars-core/src/frame/group_by/expr.rs b/crates/polars-core/src/frame/group_by/expr.rs index f35a04a5664f..7348fba2f7c2 100644 --- a/crates/polars-core/src/frame/group_by/expr.rs +++ b/crates/polars-core/src/frame/group_by/expr.rs @@ -2,7 +2,7 @@ use crate::prelude::*; pub trait PhysicalAggExpr { #[allow(clippy::ptr_arg)] - fn evaluate(&self, df: &DataFrame, groups: &GroupsProxy) -> PolarsResult; + fn evaluate(&self, df: &DataFrame, groups: &GroupPositions) -> PolarsResult; fn root_name(&self) -> PolarsResult<&PlSmallStr>; } diff --git a/crates/polars-core/src/frame/group_by/hashing.rs b/crates/polars-core/src/frame/group_by/hashing.rs index 0af7b5159c8a..e2c0da03f0e7 100644 --- a/crates/polars-core/src/frame/group_by/hashing.rs +++ b/crates/polars-core/src/frame/group_by/hashing.rs @@ -23,7 +23,7 @@ fn get_init_size() -> usize { } } -fn finish_group_order(mut out: Vec>, sorted: bool) -> GroupsProxy { +fn finish_group_order(mut out: Vec>, sorted: bool) -> GroupsType { if sorted { // we can just take the first value, no need to flatten let mut out = if out.len() == 1 { @@ -60,19 +60,19 @@ fn finish_group_order(mut out: Vec>, sorted: bool) -> GroupsProxy { out.sort_unstable_by_key(|g| g.0); let mut idx = GroupsIdx::from_iter(out); idx.sorted = true; - GroupsProxy::Idx(idx) + GroupsType::Idx(idx) } else { // we can just take the first value, no need to flatten if out.len() == 1 { - GroupsProxy::Idx(GroupsIdx::from(out.pop().unwrap())) + GroupsType::Idx(GroupsIdx::from(out.pop().unwrap())) } else { // flattens - GroupsProxy::Idx(GroupsIdx::from(out)) + GroupsType::Idx(GroupsIdx::from(out)) } } } -pub(crate) fn group_by(keys: impl Iterator, sorted: bool) -> GroupsProxy +pub(crate) fn group_by(keys: impl Iterator, sorted: bool) -> GroupsType where K: TotalHash + TotalEq, { @@ -107,7 +107,7 @@ where } (first, groups) = hash_tbl.into_values().unzip(); } - GroupsProxy::Idx(GroupsIdx::new(first, groups, sorted)) + GroupsType::Idx(GroupsIdx::new(first, groups, sorted)) } // giving the slice info to the compiler is much @@ -117,7 +117,7 @@ pub(crate) fn group_by_threaded_slice( keys: Vec, n_partitions: usize, sorted: bool, -) -> GroupsProxy +) -> GroupsType where T: ToTotalOrd, ::TotalOrdItem: Send + Sync + Copy + DirtyHash, @@ -170,7 +170,7 @@ pub(crate) fn group_by_threaded_iter( keys: &[I], n_partitions: usize, sorted: bool, -) -> GroupsProxy +) -> GroupsType where I: IntoIterator + Send + Sync + Clone, I::IntoIter: ExactSizeIterator, diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index 9e4a05df9764..af0a406643db 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -9,11 +9,11 @@ use crate::series::BitRepr; use crate::utils::flatten::flatten_par; /// Used to create the tuples for a group_by operation. -pub trait IntoGroupsProxy { +pub trait IntoGroupsType { /// Create the tuples need for a group_by operation. /// * The first value in the tuple is the first index of the group. /// * The second value in the tuple is the indexes of the groups including the first value. - fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { + fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { unimplemented!() } } @@ -23,7 +23,7 @@ fn group_multithreaded(ca: &ChunkedArray) -> bool { ca.len() > 1000 && POOL.current_num_threads() > 1 } -fn num_groups_proxy(ca: &ChunkedArray, multithreaded: bool, sorted: bool) -> GroupsProxy +fn num_groups_proxy(ca: &ChunkedArray, multithreaded: bool, sorted: bool) -> GroupsType where T: PolarsNumericType, T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, @@ -127,22 +127,22 @@ where } #[cfg(all(feature = "dtype-categorical", feature = "performant"))] -impl IntoGroupsProxy for CategoricalChunked { - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { +impl IntoGroupsType for CategoricalChunked { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { Ok(self.group_tuples_perfect(multithreaded, sorted)) } } -impl IntoGroupsProxy for ChunkedArray +impl IntoGroupsType for ChunkedArray where T: PolarsNumericType, T::Native: NumCast, { - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { // sorted path if self.is_sorted_ascending_flag() || self.is_sorted_descending_flag() { // don't have to pass `sorted` arg, GroupSlice is always sorted. - return Ok(GroupsProxy::Slice { + return Ok(GroupsType::Slice { groups: self.rechunk().create_groups_from_sorted(multithreaded), rolling: false, }); @@ -237,8 +237,8 @@ where Ok(out) } } -impl IntoGroupsProxy for BooleanChunked { - fn group_tuples(&self, mut multithreaded: bool, sorted: bool) -> PolarsResult { +impl IntoGroupsType for BooleanChunked { + fn group_tuples(&self, mut multithreaded: bool, sorted: bool) -> PolarsResult { multithreaded &= POOL.current_num_threads() > 1; #[cfg(feature = "performant")] @@ -260,20 +260,20 @@ impl IntoGroupsProxy for BooleanChunked { } } -impl IntoGroupsProxy for StringChunked { +impl IntoGroupsType for StringChunked { #[allow(clippy::needless_lifetimes)] - fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult { self.as_binary().group_tuples(multithreaded, sorted) } } -impl IntoGroupsProxy for BinaryChunked { +impl IntoGroupsType for BinaryChunked { #[allow(clippy::needless_lifetimes)] fn group_tuples<'a>( &'a self, mut multithreaded: bool, sorted: bool, - ) -> PolarsResult { + ) -> PolarsResult { multithreaded &= POOL.current_num_threads() > 1; let bh = self.to_bytes_hashes(multithreaded, Default::default()); @@ -289,13 +289,13 @@ impl IntoGroupsProxy for BinaryChunked { } } -impl IntoGroupsProxy for BinaryOffsetChunked { +impl IntoGroupsType for BinaryOffsetChunked { #[allow(clippy::needless_lifetimes)] fn group_tuples<'a>( &'a self, mut multithreaded: bool, sorted: bool, - ) -> PolarsResult { + ) -> PolarsResult { multithreaded &= POOL.current_num_threads() > 1; let bh = self.to_bytes_hashes(multithreaded, Default::default()); @@ -311,14 +311,14 @@ impl IntoGroupsProxy for BinaryOffsetChunked { } } -impl IntoGroupsProxy for ListChunked { +impl IntoGroupsType for ListChunked { #[allow(clippy::needless_lifetimes)] #[allow(unused_variables)] fn group_tuples<'a>( &'a self, mut multithreaded: bool, sorted: bool, - ) -> PolarsResult { + ) -> PolarsResult { multithreaded &= POOL.current_num_threads() > 1; let by = &[self.clone().into_column()]; let ca = if multithreaded { @@ -332,24 +332,20 @@ impl IntoGroupsProxy for ListChunked { } #[cfg(feature = "dtype-array")] -impl IntoGroupsProxy for ArrayChunked { +impl IntoGroupsType for ArrayChunked { #[allow(clippy::needless_lifetimes)] #[allow(unused_variables)] - fn group_tuples<'a>( - &'a self, - _multithreaded: bool, - _sorted: bool, - ) -> PolarsResult { + fn group_tuples<'a>(&'a self, _multithreaded: bool, _sorted: bool) -> PolarsResult { todo!("grouping FixedSizeList not yet supported") } } #[cfg(feature = "object")] -impl IntoGroupsProxy for ObjectChunked +impl IntoGroupsType for ObjectChunked where T: PolarsObject, { - fn group_tuples(&self, _multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, _multithreaded: bool, sorted: bool) -> PolarsResult { Ok(group_by(self.into_iter(), sorted)) } } diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index f5d0b7dc4959..41f135a0f309 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -16,10 +16,10 @@ pub mod expr; pub(crate) mod hashing; mod into_groups; mod perfect; -mod proxy; +mod position; pub use into_groups::*; -pub use proxy::*; +pub use position::*; use crate::chunked_array::ops::row_encode::{ encode_rows_unordered, encode_rows_vertical_par_unordered, @@ -84,7 +84,7 @@ impl DataFrame { } else { vec![[0, self.height() as IdxSize]] }; - Ok(GroupsProxy::Slice { + Ok(GroupsType::Slice { groups, rolling: false, }) @@ -98,7 +98,7 @@ impl DataFrame { rows.group_tuples(multithreaded, sorted) } }; - Ok(GroupBy::new(self, by, groups?, None)) + Ok(GroupBy::new(self, by, groups?.into_sliceable(), None)) } /// Group DataFrame using a Series column. @@ -184,20 +184,20 @@ impl DataFrame { /// ``` /// #[derive(Debug, Clone)] -pub struct GroupBy<'df> { - pub df: &'df DataFrame, +pub struct GroupBy<'a> { + pub df: &'a DataFrame, pub(crate) selected_keys: Vec, // [first idx, [other idx]] - groups: GroupsProxy, + groups: GroupPositions, // columns selected for aggregation pub(crate) selected_agg: Option>, } -impl<'df> GroupBy<'df> { +impl<'a> GroupBy<'a> { pub fn new( - df: &'df DataFrame, + df: &'a DataFrame, by: Vec, - groups: GroupsProxy, + groups: GroupPositions, selected_agg: Option>, ) -> Self { GroupBy { @@ -223,7 +223,7 @@ impl<'df> GroupBy<'df> { /// The Vec returned contains: /// (first_idx, [`Vec`]) /// Where second value in the tuple is a vector with all matching indexes. - pub fn get_groups(&self) -> &GroupsProxy { + pub fn get_groups(&self) -> &GroupPositions { &self.groups } @@ -235,15 +235,15 @@ impl<'df> GroupBy<'df> { /// # Safety /// Groups should always be in bounds of the `DataFrame` hold by this [`GroupBy`]. /// If you mutate it, you must hold that invariant. - pub unsafe fn get_groups_mut(&mut self) -> &mut GroupsProxy { + pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions { &mut self.groups } - pub fn take_groups(self) -> GroupsProxy { + pub fn take_groups(self) -> GroupPositions { self.groups } - pub fn take_groups_mut(&mut self) -> GroupsProxy { + pub fn take_groups_mut(&mut self) -> GroupPositions { std::mem::take(&mut self.groups) } @@ -264,7 +264,7 @@ impl<'df> GroupBy<'df> { .map(Column::as_materialized_series) .map(|s| { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { // SAFETY: groups are always in bounds. let mut out = unsafe { s.take_slice_unchecked(groups.first()) }; if groups.sorted { @@ -272,7 +272,7 @@ impl<'df> GroupBy<'df> { }; out }, - GroupsProxy::Slice { groups, rolling } => { + GroupsType::Slice { groups, rolling } => { if *rolling && !groups.is_empty() { // Groups can be sliced. let offset = groups[0][0]; @@ -846,7 +846,7 @@ impl<'df> GroupBy<'df> { match slice { None => self, Some((offset, length)) => { - self.groups = (*self.groups.slice(offset, length)).clone(); + self.groups = (self.groups.slice(offset, length)).clone(); self.selected_keys = self.keys_sliced(slice); self }, diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs index a020aa4fb37e..a61004ed6b99 100644 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -27,7 +27,7 @@ where num_groups: usize, mut multithreaded: bool, group_capacity: usize, - ) -> GroupsProxy { + ) -> GroupsType { multithreaded &= POOL.current_num_threads() > 1; // The latest index will be used for the null sentinel. let len = if self.null_count() > 0 { @@ -152,7 +152,7 @@ where // NOTE! we set sorted here! // this happens to be true for `fast_unique` categoricals - GroupsProxy::Idx(GroupsIdx::new(first, groups, true)) + GroupsType::Idx(GroupsIdx::new(first, groups, true)) } } @@ -160,10 +160,10 @@ where // Special implementation so that cats can be processed in a single pass impl CategoricalChunked { // Use the indexes as perfect groups - pub fn group_tuples_perfect(&self, multithreaded: bool, sorted: bool) -> GroupsProxy { + pub fn group_tuples_perfect(&self, multithreaded: bool, sorted: bool) -> GroupsType { let rev_map = self.get_rev_map(); if self.is_empty() { - return GroupsProxy::Idx(GroupsIdx::new(vec![], vec![], true)); + return GroupsType::Idx(GroupsIdx::new(vec![], vec![], true)); } let cats = self.physical(); diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/position.rs similarity index 71% rename from crates/polars-core/src/frame/group_by/proxy.rs rename to crates/polars-core/src/frame/group_by/position.rs index 63b1a8022108..58734cd50b74 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/position.rs @@ -1,5 +1,5 @@ use std::mem::ManuallyDrop; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use arrow::offset::OffsetsBuffer; use polars_utils::idx_vec::IdxVec; @@ -239,7 +239,7 @@ impl IntoParallelIterator for GroupsIdx { pub type GroupsSlice = Vec<[IdxSize; 2]>; #[derive(Debug, Clone, PartialEq, Eq)] -pub enum GroupsProxy { +pub enum GroupsType { Idx(GroupsIdx), /// Slice is always sorted in ascending order. Slice { @@ -250,17 +250,17 @@ pub enum GroupsProxy { }, } -impl Default for GroupsProxy { +impl Default for GroupsType { fn default() -> Self { - GroupsProxy::Idx(GroupsIdx::default()) + GroupsType::Idx(GroupsIdx::default()) } } -impl GroupsProxy { +impl GroupsType { pub fn into_idx(self) -> GroupsIdx { match self { - GroupsProxy::Idx(groups) => groups, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Idx(groups) => groups, + GroupsType::Slice { groups, .. } => { polars_warn!("Had to reallocate groups, missed an optimization opportunity. Please open an issue."); groups .iter() @@ -276,7 +276,7 @@ impl GroupsProxy { ) -> (Option, OffsetsBuffer, bool) { let mut can_fast_explode = true; match self { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let mut list_offset = Vec::with_capacity(self.len() + 1); let mut gather_offsets = Vec::with_capacity(total_len); @@ -298,7 +298,7 @@ impl GroupsProxy { ) } }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let mut list_offset = Vec::with_capacity(self.len() + 1); let mut gather_offsets = Vec::with_capacity(total_len); let mut len_so_far = 0i64; @@ -325,18 +325,18 @@ impl GroupsProxy { } } - pub fn iter(&self) -> GroupsProxyIter { - GroupsProxyIter::new(self) + pub fn iter(&self) -> GroupsTypeIter { + GroupsTypeIter::new(self) } pub fn sort(&mut self) { match self { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { if !groups.is_sorted_flag() { groups.sort() } }, - GroupsProxy::Slice { .. } => { + GroupsType::Slice { .. } => { // invariant of the type }, } @@ -344,15 +344,15 @@ impl GroupsProxy { pub(crate) fn is_sorted_flag(&self) -> bool { match self { - GroupsProxy::Idx(groups) => groups.is_sorted_flag(), - GroupsProxy::Slice { .. } => true, + GroupsType::Idx(groups) => groups.is_sorted_flag(), + GroupsType::Slice { .. } => true, } } pub fn take_group_firsts(self) -> Vec { match self { - GroupsProxy::Idx(mut groups) => std::mem::take(&mut groups.first), - GroupsProxy::Slice { groups, .. } => { + GroupsType::Idx(mut groups) => std::mem::take(&mut groups.first), + GroupsType::Slice { groups, .. } => { groups.into_iter().map(|[first, _len]| first).collect() }, } @@ -363,20 +363,20 @@ impl GroupsProxy { /// all groups have members. pub unsafe fn take_group_lasts(self) -> Vec { match self { - GroupsProxy::Idx(groups) => groups + GroupsType::Idx(groups) => groups .all .iter() .map(|idx| *idx.get_unchecked(idx.len() - 1)) .collect(), - GroupsProxy::Slice { groups, .. } => groups + GroupsType::Slice { groups, .. } => groups .into_iter() .map(|[first, len]| first + len - 1) .collect(), } } - pub fn par_iter(&self) -> GroupsProxyParIter { - GroupsProxyParIter::new(self) + pub fn par_iter(&self) -> GroupsTypeParIter { + GroupsTypeParIter::new(self) } /// Get a reference to the `GroupsIdx`. @@ -386,8 +386,8 @@ impl GroupsProxy { /// panics if the groups are a slice. pub fn unwrap_idx(&self) -> &GroupsIdx { match self { - GroupsProxy::Idx(groups) => groups, - GroupsProxy::Slice { .. } => panic!("groups are slices not index"), + GroupsType::Idx(groups) => groups, + GroupsType::Slice { .. } => panic!("groups are slices not index"), } } @@ -398,19 +398,19 @@ impl GroupsProxy { /// panics if the groups are an idx. pub fn unwrap_slice(&self) -> &GroupsSlice { match self { - GroupsProxy::Slice { groups, .. } => groups, - GroupsProxy::Idx(_) => panic!("groups are index not slices"), + GroupsType::Slice { groups, .. } => groups, + GroupsType::Idx(_) => panic!("groups are index not slices"), } } pub fn get(&self, index: usize) -> GroupsIndicator { match self { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let first = groups.first[index]; let all = &groups.all[index]; GroupsIndicator::Idx((first, all)) }, - GroupsProxy::Slice { groups, .. } => GroupsIndicator::Slice(groups[index]), + GroupsType::Slice { groups, .. } => GroupsIndicator::Slice(groups[index]), } } @@ -421,15 +421,15 @@ impl GroupsProxy { /// panics if the groups are a slice. pub fn idx_mut(&mut self) -> &mut GroupsIdx { match self { - GroupsProxy::Idx(groups) => groups, - GroupsProxy::Slice { .. } => panic!("groups are slices not index"), + GroupsType::Idx(groups) => groups, + GroupsType::Slice { .. } => panic!("groups are slices not index"), } } pub fn len(&self) -> usize { match self { - GroupsProxy::Idx(groups) => groups.len(), - GroupsProxy::Slice { groups, .. } => groups.len(), + GroupsType::Idx(groups) => groups.len(), + GroupsType::Slice { groups, .. } => groups.len(), } } @@ -439,14 +439,14 @@ impl GroupsProxy { pub fn group_count(&self) -> IdxCa { match self { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let ca: NoNull = groups .iter() .map(|(_first, idx)| idx.len() as IdxSize) .collect_trusted(); ca.into_inner() }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let ca: NoNull = groups.iter().map(|[_first, len]| *len).collect_trusted(); ca.into_inner() }, @@ -454,14 +454,14 @@ impl GroupsProxy { } pub fn as_list_chunked(&self) -> ListChunked { match self { - GroupsProxy::Idx(groups) => groups + GroupsType::Idx(groups) => groups .iter() .map(|(_first, idx)| { let ca: NoNull = idx.iter().map(|&v| v as IdxSize).collect(); ca.into_inner().into_series() }) .collect_trusted(), - GroupsProxy::Slice { groups, .. } => groups + GroupsType::Slice { groups, .. } => groups .iter() .map(|&[first, len]| { let ca: NoNull = (first..first + len).collect_trusted(); @@ -471,73 +471,15 @@ impl GroupsProxy { } } - pub fn unroll(self) -> GroupsProxy { - match self { - GroupsProxy::Idx(_) => self, - GroupsProxy::Slice { rolling: false, .. } => self, - GroupsProxy::Slice { mut groups, .. } => { - let mut offset = 0 as IdxSize; - for g in groups.iter_mut() { - g[0] = offset; - offset += g[1]; - } - GroupsProxy::Slice { - groups, - rolling: false, - } - }, - } - } - - pub fn slice(&self, offset: i64, len: usize) -> SlicedGroups<'_> { - // SAFETY: - // we create new `Vec`s from the sliced groups. But we wrap them in ManuallyDrop - // so that we never call drop on them. - // These groups lifetimes are bounded to the `self`. This must remain valid - // for the scope of the aggregation. - let sliced = match self { - GroupsProxy::Idx(groups) => { - let first = unsafe { - let first = slice_slice(groups.first(), offset, len); - let ptr = first.as_ptr() as *mut _; - Vec::from_raw_parts(ptr, first.len(), first.len()) - }; - - let all = unsafe { - let all = slice_slice(groups.all(), offset, len); - let ptr = all.as_ptr() as *mut _; - Vec::from_raw_parts(ptr, all.len(), all.len()) - }; - ManuallyDrop::new(GroupsProxy::Idx(GroupsIdx::new( - first, - all, - groups.is_sorted_flag(), - ))) - }, - GroupsProxy::Slice { groups, rolling } => { - let groups = unsafe { - let groups = slice_slice(groups, offset, len); - let ptr = groups.as_ptr() as *mut _; - Vec::from_raw_parts(ptr, groups.len(), groups.len()) - }; - - ManuallyDrop::new(GroupsProxy::Slice { - groups, - rolling: *rolling, - }) - }, - }; - - SlicedGroups { - sliced, - borrowed: self, - } + pub fn into_sliceable(self) -> GroupPositions { + let len = self.len(); + slice_groups(Arc::new(self), 0, len) } } -impl From for GroupsProxy { +impl From for GroupsType { fn from(groups: GroupsIdx) -> Self { - GroupsProxy::Idx(groups) + GroupsType::Idx(groups) } } @@ -564,21 +506,21 @@ impl GroupsIndicator<'_> { } } -pub struct GroupsProxyIter<'a> { - vals: &'a GroupsProxy, +pub struct GroupsTypeIter<'a> { + vals: &'a GroupsType, len: usize, idx: usize, } -impl<'a> GroupsProxyIter<'a> { - fn new(vals: &'a GroupsProxy) -> Self { +impl<'a> GroupsTypeIter<'a> { + fn new(vals: &'a GroupsType) -> Self { let len = vals.len(); let idx = 0; - GroupsProxyIter { vals, len, idx } + GroupsTypeIter { vals, len, idx } } } -impl<'a> Iterator for GroupsProxyIter<'a> { +impl<'a> Iterator for GroupsTypeIter<'a> { type Item = GroupsIndicator<'a>; fn nth(&mut self, n: usize) -> Option { @@ -593,11 +535,11 @@ impl<'a> Iterator for GroupsProxyIter<'a> { let out = unsafe { match self.vals { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let item = groups.get_unchecked(self.idx); Some(GroupsIndicator::Idx(item)) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { Some(GroupsIndicator::Slice(*groups.get_unchecked(self.idx))) }, } @@ -607,19 +549,19 @@ impl<'a> Iterator for GroupsProxyIter<'a> { } } -pub struct GroupsProxyParIter<'a> { - vals: &'a GroupsProxy, +pub struct GroupsTypeParIter<'a> { + vals: &'a GroupsType, len: usize, } -impl<'a> GroupsProxyParIter<'a> { - fn new(vals: &'a GroupsProxy) -> Self { +impl<'a> GroupsTypeParIter<'a> { + fn new(vals: &'a GroupsType) -> Self { let len = vals.len(); - GroupsProxyParIter { vals, len } + GroupsTypeParIter { vals, len } } } -impl<'a> ParallelIterator for GroupsProxyParIter<'a> { +impl<'a> ParallelIterator for GroupsTypeParIter<'a> { type Item = GroupsIndicator<'a>; fn drive_unindexed(self, consumer: C) -> C::Result @@ -630,8 +572,8 @@ impl<'a> ParallelIterator for GroupsProxyParIter<'a> { .into_par_iter() .map(|i| unsafe { match self.vals { - GroupsProxy::Idx(groups) => GroupsIndicator::Idx(groups.get_unchecked(i)), - GroupsProxy::Slice { groups, .. } => { + GroupsType::Idx(groups) => GroupsIndicator::Idx(groups.get_unchecked(i)), + GroupsType::Slice { groups, .. } => { GroupsIndicator::Slice(*groups.get_unchecked(i)) }, } @@ -640,17 +582,123 @@ impl<'a> ParallelIterator for GroupsProxyParIter<'a> { } } -pub struct SlicedGroups<'a> { - sliced: ManuallyDrop, - #[allow(dead_code)] - // we need the lifetime to ensure the slice remains valid - borrowed: &'a GroupsProxy, +#[derive(Clone, Debug)] +pub struct GroupPositions { + sliced: ManuallyDrop, + // Unsliced buffer + original: Arc, + offset: i64, + len: usize, } -impl Deref for SlicedGroups<'_> { - type Target = GroupsProxy; +impl PartialEq for GroupPositions { + fn eq(&self, other: &Self) -> bool { + self.offset == other.offset && self.len == other.len && self.sliced == other.sliced + } +} + +impl AsRef for GroupPositions { + fn as_ref(&self) -> &GroupsType { + self.sliced.deref() + } +} + +impl Deref for GroupPositions { + type Target = GroupsType; fn deref(&self) -> &Self::Target { self.sliced.deref() } } + +impl Default for GroupPositions { + fn default() -> Self { + GroupsType::default().into_sliceable() + } +} + +impl GroupPositions { + pub fn slice(&self, offset: i64, len: usize) -> Self { + let offset = self.offset + offset; + assert!(len <= self.len); + slice_groups(self.original.clone(), offset, len) + } + + pub fn sort(&mut self) { + if !self.as_ref().is_sorted_flag() { + let original = Arc::make_mut(&mut self.original); + original.sort(); + + self.sliced = slice_groups_inner(original, self.offset, self.len); + } + } + + pub fn unroll(mut self) -> GroupPositions { + match self.sliced.deref_mut() { + GroupsType::Idx(_) => self, + GroupsType::Slice { rolling: false, .. } => self, + GroupsType::Slice { + groups, rolling, .. + } => { + let mut offset = 0 as IdxSize; + for g in groups.iter_mut() { + g[0] = offset; + offset += g[1]; + } + *rolling = false; + self + }, + } + } +} + +fn slice_groups_inner(g: &GroupsType, offset: i64, len: usize) -> ManuallyDrop { + // SAFETY: + // we create new `Vec`s from the sliced groups. But we wrap them in ManuallyDrop + // so that we never call drop on them. + // These groups lifetimes are bounded to the `g`. This must remain valid + // for the scope of the aggregation. + match g { + GroupsType::Idx(groups) => { + let first = unsafe { + let first = slice_slice(groups.first(), offset, len); + let ptr = first.as_ptr() as *mut _; + Vec::from_raw_parts(ptr, first.len(), first.len()) + }; + + let all = unsafe { + let all = slice_slice(groups.all(), offset, len); + let ptr = all.as_ptr() as *mut _; + Vec::from_raw_parts(ptr, all.len(), all.len()) + }; + ManuallyDrop::new(GroupsType::Idx(GroupsIdx::new( + first, + all, + groups.is_sorted_flag(), + ))) + }, + GroupsType::Slice { groups, rolling } => { + let groups = unsafe { + let groups = slice_slice(groups, offset, len); + let ptr = groups.as_ptr() as *mut _; + Vec::from_raw_parts(ptr, groups.len(), groups.len()) + }; + + ManuallyDrop::new(GroupsType::Slice { + groups, + rolling: *rolling, + }) + }, + } +} + +fn slice_groups(g: Arc, offset: i64, len: usize) -> GroupPositions { + let sliced = slice_groups_inner(g.as_ref(), offset, len); + + GroupPositions { + sliced, + original: g, + offset, + len, + } +} diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 1a057ddef800..aa1d5ebf18dc 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -3039,7 +3039,7 @@ impl DataFrame { #[cfg(feature = "algorithm_group_by")] pub fn is_unique(&self) -> PolarsResult { let gb = self.group_by(self.get_column_names_owned())?; - let groups = gb.take_groups(); + let groups = gb.get_groups(); Ok(is_unique_helper( groups, self.height() as IdxSize, @@ -3064,7 +3064,7 @@ impl DataFrame { #[cfg(feature = "algorithm_group_by")] pub fn is_duplicated(&self) -> PolarsResult { let gb = self.group_by(self.get_column_names_owned())?; - let groups = gb.take_groups(); + let groups = gb.get_groups(); Ok(is_unique_helper( groups, self.height() as IdxSize, @@ -3174,8 +3174,8 @@ impl DataFrame { // don't parallelize this // there is a lot of parallelization in take and this may easily SO POOL.install(|| { - match groups { - GroupsProxy::Idx(idx) => { + match groups.as_ref() { + GroupsType::Idx(idx) => { // Rechunk as the gather may rechunk for every group #17562. let mut df = df.clone(); df.as_single_chunk_par(); @@ -3184,14 +3184,14 @@ impl DataFrame { .map(|(_, group)| { // groups are in bounds unsafe { - df._take_unchecked_slice_sorted(&group, false, IsSorted::Ascending) + df._take_unchecked_slice_sorted(group, false, IsSorted::Ascending) } }) .collect()) }, - GroupsProxy::Slice { groups, .. } => Ok(groups + GroupsType::Slice { groups, .. } => Ok(groups .into_par_iter() - .map(|[first, len]| df.slice(first as i64, len as usize)) + .map(|[first, len]| df.slice(*first as i64, *len as usize)) .collect()), } }) diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index 156076a4c295..f67dc7f301c8 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -41,13 +41,13 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } fn add_to(&self, rhs: &Series) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index 52bdd857edb1..b37756a2c087 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -52,17 +52,17 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0.agg_min(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0.agg_max(groups) } @@ -82,8 +82,8 @@ impl private::PrivateSeries for SeriesWrap { NumOpsDispatch::remainder(&self.0, rhs) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } fn arg_sort_multiple( diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs index 6cb2a4b3e86c..4976240f776f 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -48,8 +48,8 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } fn arg_sort_multiple( diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index ab92d66bbf40..58a3da28e9d2 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -51,33 +51,33 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0.agg_min(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0.agg_max(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { self.0.agg_sum(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_std(&self, groups: &GroupsProxy, _ddof: u8) -> Series { + unsafe fn agg_std(&self, groups: &GroupsType, _ddof: u8) -> Series { self.0 .cast_with_options(&DataType::Float64, CastOptions::Overflowing) .unwrap() .agg_std(groups, _ddof) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_var(&self, groups: &GroupsProxy, _ddof: u8) -> Series { + unsafe fn agg_var(&self, groups: &GroupsType, _ddof: u8) -> Series { self.0 .cast_with_options(&DataType::Float64, CastOptions::Overflowing) .unwrap() @@ -85,21 +85,21 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "bitwise")] - unsafe fn agg_and(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_and(&self, groups: &GroupsType) -> Series { self.0.agg_and(groups) } #[cfg(feature = "bitwise")] - unsafe fn agg_or(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_or(&self, groups: &GroupsType) -> Series { self.0.agg_or(groups) } #[cfg(feature = "bitwise")] - unsafe fn agg_xor(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { self.0.agg_xor(groups) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } fn arg_sort_multiple( diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index fc9dc0eb5760..1b2a8a77b49f 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -98,7 +98,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect let list = self.0.physical().agg_list(groups); let mut list = list.list().unwrap().clone(); @@ -107,7 +107,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { #[cfg(feature = "performant")] { Ok(self.0.group_tuples_perfect(multithreaded, sorted)) diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs index e5e319c3fc16..021a4a7e18cf 100644 --- a/crates/polars-core/src/series/implementations/date.rs +++ b/crates/polars-core/src/series/implementations/date.rs @@ -69,17 +69,17 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0.agg_min(groups).into_date().into_series() } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0.agg_max(groups).into_date().into_series() } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 .agg_list(groups) @@ -133,7 +133,7 @@ impl private::PrivateSeries for SeriesWrap { polars_bail!(opq = rem, self.0.dtype(), rhs.dtype()); } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index c547be64b504..ee4d9022782a 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -63,7 +63,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0 .agg_min(groups) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) @@ -71,14 +71,14 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0 .agg_max(groups) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 .agg_list(groups) @@ -131,7 +131,7 @@ impl private::PrivateSeries for SeriesWrap { polars_bail!(opq = rem, self.dtype(), rhs.dtype()); } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 0494da7bffd2..6e477ccf6c3f 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -140,22 +140,22 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { self.agg_helper(|ca| ca.agg_sum(groups)) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.agg_helper(|ca| ca.agg_min(groups)) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.agg_helper(|ca| ca.agg_max(groups)) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.agg_helper(|ca| ca.agg_list(groups)) } @@ -176,7 +176,7 @@ impl private::PrivateSeries for SeriesWrap { ((&self.0) / rhs).map(|ca| ca.into_series()) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } fn arg_sort_multiple( diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index c7563221cfa3..51426f1b94e6 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -70,7 +70,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0 .agg_min(groups) .into_duration(self.0.time_unit()) @@ -78,7 +78,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0 .agg_max(groups) .into_duration(self.0.time_unit()) @@ -86,7 +86,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { self.0 .agg_sum(groups) .into_duration(self.0.time_unit()) @@ -94,7 +94,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { + unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { self.0 .agg_std(groups, ddof) // cast f64 back to physical type @@ -105,7 +105,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { + unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { self.0 .agg_var(groups, ddof) // cast f64 back to physical type @@ -116,7 +116,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 .agg_list(groups) @@ -245,7 +245,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series()) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index e787c158d5e2..9ccbb1d8d958 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -67,45 +67,45 @@ macro_rules! impl_dyn_series { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0.agg_min(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0.agg_max(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { self.0.agg_sum(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { + unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { self.agg_std(groups, ddof) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { + unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { self.agg_var(groups, ddof) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } #[cfg(feature = "bitwise")] - unsafe fn agg_and(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_and(&self, groups: &GroupsType) -> Series { self.0.agg_and(groups) } #[cfg(feature = "bitwise")] - unsafe fn agg_or(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_or(&self, groups: &GroupsType) -> Series { self.0.agg_or(groups) } #[cfg(feature = "bitwise")] - unsafe fn agg_xor(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { self.0.agg_xor(groups) } @@ -125,8 +125,8 @@ macro_rules! impl_dyn_series { NumOpsDispatch::remainder(&self.0, rhs) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } fn arg_sort_multiple( diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index bfee61814fcc..74d61d9d91a8 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -32,13 +32,13 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } fn into_total_eq_inner<'a>(&'a self) -> Box { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 4214f645f381..9df0e7695127 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -130,17 +130,17 @@ macro_rules! impl_dyn_series { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0.agg_min(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0.agg_max(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { use DataType::*; match self.dtype() { Int8 | UInt8 | Int16 | UInt16 => self @@ -152,30 +152,30 @@ macro_rules! impl_dyn_series { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { + unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { self.0.agg_std(groups, ddof) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { + unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { self.0.agg_var(groups, ddof) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } #[cfg(feature = "bitwise")] - unsafe fn agg_and(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_and(&self, groups: &GroupsType) -> Series { self.0.agg_and(groups) } #[cfg(feature = "bitwise")] - unsafe fn agg_or(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_or(&self, groups: &GroupsType) -> Series { self.0.agg_or(groups) } #[cfg(feature = "bitwise")] - unsafe fn agg_xor(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { self.0.agg_xor(groups) } @@ -195,8 +195,8 @@ macro_rules! impl_dyn_series { NumOpsDispatch::remainder(&self.0, rhs) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } fn arg_sort_multiple( diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index d844e24aa8c1..d54cfc806066 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -106,11 +106,11 @@ impl PrivateSeries for NullChunked { } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { + fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { Ok(if self.is_empty() { - GroupsProxy::default() + GroupsType::default() } else { - GroupsProxy::Slice { + GroupsType::Slice { groups: vec![[0, self.length]], rolling: false, } @@ -118,7 +118,7 @@ impl PrivateSeries for NullChunked { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { AggList::agg_list(self, groups) } diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index b70ef3f074b6..bc753f9b06f4 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -47,7 +47,7 @@ where fn _get_flags(&self) -> StatisticsFlags { self.0.get_flags() } - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } @@ -73,8 +73,8 @@ where } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index c98337af075d..1dffefc679b1 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -51,17 +51,17 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0.agg_min(groups) } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0.agg_max(groups) } @@ -81,8 +81,8 @@ impl private::PrivateSeries for SeriesWrap { NumOpsDispatch::remainder(&self.0, rhs) } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) } fn arg_sort_multiple( diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 969601c338c6..d741747bc32d 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -43,7 +43,7 @@ impl PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { let ca = self.0.get_row_encoded(Default::default())?; ca.group_tuples(multithreaded, sorted) } @@ -63,7 +63,7 @@ impl PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { self.0.agg_list(groups) } diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs index 247f48091f42..d0f3a7e0571a 100644 --- a/crates/polars-core/src/series/implementations/time.rs +++ b/crates/polars-core/src/series/implementations/time.rs @@ -69,17 +69,17 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0.agg_min(groups).into_time().into_series() } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0.agg_max(groups).into_time().into_series() } #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 .agg_list(groups) @@ -115,7 +115,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index af89119feb18..c1b8d3f97764 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -96,41 +96,41 @@ pub(crate) mod private { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } /// # Safety /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } /// # Safety /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_std(&self, groups: &GroupsProxy, _ddof: u8) -> Series { + unsafe fn agg_std(&self, groups: &GroupsType, _ddof: u8) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } /// # Safety /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_var(&self, groups: &GroupsProxy, _ddof: u8) -> Series { + unsafe fn agg_var(&self, groups: &GroupsType, _ddof: u8) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } /// # Safety /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } @@ -138,7 +138,7 @@ pub(crate) mod private { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "bitwise")] - unsafe fn agg_and(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_and(&self, groups: &GroupsType) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } @@ -146,7 +146,7 @@ pub(crate) mod private { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "bitwise")] - unsafe fn agg_or(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_or(&self, groups: &GroupsType) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } @@ -154,7 +154,7 @@ pub(crate) mod private { /// /// Does no bounds checks, groups must be correct. #[cfg(feature = "bitwise")] - unsafe fn agg_xor(&self, groups: &GroupsProxy) -> Series { + unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } @@ -174,7 +174,7 @@ pub(crate) mod private { polars_bail!(opq = remainder, self._dtype()); } #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { + fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { polars_bail!(opq = group_tuples, self._dtype()); } #[cfg(feature = "zip_with")] diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index e1e22e1a63c2..665c843d6e63 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -157,7 +157,7 @@ impl PhysicalExpr for AggregationExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let mut ac = self.input.evaluate_on_groups(df, groups, state)?; @@ -287,8 +287,8 @@ impl PhysicalExpr for AggregationExpr { let out: IdxCa = if matches!(s.dtype(), &DataType::Null) { IdxCa::full(s.name().clone(), 0, groups.len()) } else { - match groups.as_ref() { - GroupsProxy::Idx(idx) => { + match groups.as_ref().as_ref() { + GroupsType::Idx(idx) => { let s = s.rechunk(); // @scalar-opt // @partition-opt @@ -307,7 +307,7 @@ impl PhysicalExpr for AggregationExpr { }) .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { // Slice and use computed null count groups .iter() @@ -459,7 +459,7 @@ impl PartitionedAggregation for AggregationExpr { fn evaluate_partitioned( &self, df: &DataFrame, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { let expr = self.input.as_partitioned_aggregator().unwrap(); @@ -547,7 +547,7 @@ impl PartitionedAggregation for AggregationExpr { fn finalize( &self, partitioned: Column, - groups: &GroupsProxy, + groups: &GroupPositions, _state: &ExecutionState, ) -> PolarsResult { match self.agg_type.groupby { @@ -605,8 +605,8 @@ impl PartitionedAggregation for AggregationExpr { Ok(()) }; - match groups { - GroupsProxy::Idx(groups) => { + match groups.as_ref() { + GroupsType::Idx(groups) => { for (_, idx) in groups { let ca = unsafe { // SAFETY: @@ -616,7 +616,7 @@ impl PartitionedAggregation for AggregationExpr { process_group(ca)?; } }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { for [first, len] in groups { let len = *len as usize; let ca = ca.slice(*first as i64, len); @@ -712,7 +712,7 @@ impl PhysicalExpr for AggQuantileExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let mut ac = self.input.evaluate_on_groups(df, groups, state)?; diff --git a/crates/polars-expr/src/expressions/alias.rs b/crates/polars-expr/src/expressions/alias.rs index 410ca00448a4..263f8a85c9c3 100644 --- a/crates/polars-expr/src/expressions/alias.rs +++ b/crates/polars-expr/src/expressions/alias.rs @@ -44,7 +44,7 @@ impl PhysicalExpr for AliasExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?; @@ -88,7 +88,7 @@ impl PartitionedAggregation for AliasExpr { fn evaluate_partitioned( &self, df: &DataFrame, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { let agg = self.physical_expr.as_partitioned_aggregator().unwrap(); @@ -99,7 +99,7 @@ impl PartitionedAggregation for AliasExpr { fn finalize( &self, partitioned: Column, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { let agg = self.physical_expr.as_partitioned_aggregator().unwrap(); diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index c3242547dfca..7b95a19c18b9 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -75,7 +75,7 @@ impl ApplyExpr { fn prepare_multiple_inputs<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult>> { let f = |e: &Arc| e.evaluate_on_groups(df, groups, state); @@ -362,7 +362,7 @@ impl PhysicalExpr for ApplyExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { polars_ensure!( @@ -699,7 +699,7 @@ impl PartitionedAggregation for ApplyExpr { fn evaluate_partitioned( &self, df: &DataFrame, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { let a = self.inputs[0].as_partitioned_aggregator().unwrap(); @@ -716,7 +716,7 @@ impl PartitionedAggregation for ApplyExpr { fn finalize( &self, partitioned: Column, - _groups: &GroupsProxy, + _groups: &GroupPositions, _state: &ExecutionState, ) -> PolarsResult { Ok(partitioned) diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index 44c0912fbeb2..694afd0bbaff 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -224,7 +224,7 @@ impl PhysicalExpr for BinaryExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let (result_a, result_b) = POOL.install(|| { @@ -526,7 +526,7 @@ impl PartitionedAggregation for BinaryExpr { fn evaluate_partitioned( &self, df: &DataFrame, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { let left = self.left.as_partitioned_aggregator().unwrap(); @@ -539,7 +539,7 @@ impl PartitionedAggregation for BinaryExpr { fn finalize( &self, partitioned: Column, - _groups: &GroupsProxy, + _groups: &GroupPositions, _state: &ExecutionState, ) -> PolarsResult { Ok(partitioned) diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs index 623854d35b11..ea02cc0cb621 100644 --- a/crates/polars-expr/src/expressions/cast.rs +++ b/crates/polars-expr/src/expressions/cast.rs @@ -46,7 +46,7 @@ impl PhysicalExpr for CastExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let mut ac = self.input.evaluate_on_groups(df, groups, state)?; @@ -111,7 +111,7 @@ impl PartitionedAggregation for CastExpr { fn evaluate_partitioned( &self, df: &DataFrame, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { let e = self.input.as_partitioned_aggregator().unwrap(); @@ -121,7 +121,7 @@ impl PartitionedAggregation for CastExpr { fn finalize( &self, partitioned: Column, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { let agg = self.input.as_partitioned_aggregator().unwrap(); diff --git a/crates/polars-expr/src/expressions/column.rs b/crates/polars-expr/src/expressions/column.rs index 0cdb3e54bbfd..4e31b9a557cd 100644 --- a/crates/polars-expr/src/expressions/column.rs +++ b/crates/polars-expr/src/expressions/column.rs @@ -168,7 +168,7 @@ impl PhysicalExpr for ColumnExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let c = self.evaluate(df, state)?; @@ -211,7 +211,7 @@ impl PartitionedAggregation for ColumnExpr { fn evaluate_partitioned( &self, df: &DataFrame, - _groups: &GroupsProxy, + _groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { self.evaluate(df, state) @@ -220,7 +220,7 @@ impl PartitionedAggregation for ColumnExpr { fn finalize( &self, partitioned: Column, - _groups: &GroupsProxy, + _groups: &GroupPositions, _state: &ExecutionState, ) -> PolarsResult { Ok(partitioned) diff --git a/crates/polars-expr/src/expressions/count.rs b/crates/polars-expr/src/expressions/count.rs index 118334126ecf..94b73a576713 100644 --- a/crates/polars-expr/src/expressions/count.rs +++ b/crates/polars-expr/src/expressions/count.rs @@ -28,7 +28,7 @@ impl PhysicalExpr for CountExpr { fn evaluate_on_groups<'a>( &self, _df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, _state: &ExecutionState, ) -> PolarsResult> { let ca = groups.group_count().with_name(PlSmallStr::from_static(LEN)); @@ -56,7 +56,7 @@ impl PartitionedAggregation for CountExpr { fn evaluate_partitioned( &self, df: &DataFrame, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { self.evaluate_on_groups(df, groups, state) @@ -68,7 +68,7 @@ impl PartitionedAggregation for CountExpr { fn finalize( &self, partitioned: Column, - groups: &GroupsProxy, + groups: &GroupPositions, _state: &ExecutionState, ) -> PolarsResult { // SAFETY: groups are in bounds. diff --git a/crates/polars-expr/src/expressions/filter.rs b/crates/polars-expr/src/expressions/filter.rs index 240e5a83be62..42f4e6a8bf99 100644 --- a/crates/polars-expr/src/expressions/filter.rs +++ b/crates/polars-expr/src/expressions/filter.rs @@ -38,7 +38,7 @@ impl PhysicalExpr for FilterExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let ac_s_f = || self.input.evaluate_on_groups(df, groups, state); @@ -89,7 +89,7 @@ impl PhysicalExpr for FilterExpr { // All values false - create empty groups. let groups = if !predicate.any() { let groups = groups.iter().map(|gi| [gi.first(), 0]).collect::>(); - GroupsProxy::Slice { + GroupsType::Slice { groups, rolling: false, } @@ -99,8 +99,8 @@ impl PhysicalExpr for FilterExpr { let predicate = predicate.rechunk(); let predicate = predicate.downcast_iter().next().unwrap(); POOL.install(|| { - match groups.as_ref() { - GroupsProxy::Idx(groups) => { + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { let groups = groups .par_iter() .map(|(first, idx)| unsafe { @@ -118,9 +118,9 @@ impl PhysicalExpr for FilterExpr { }) .collect(); - GroupsProxy::Idx(groups) + GroupsType::Idx(groups) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let groups = groups .par_iter() .map(|&[first, len]| unsafe { @@ -135,13 +135,14 @@ impl PhysicalExpr for FilterExpr { (*idx.first().unwrap_or(&first), idx) }) .collect(); - GroupsProxy::Idx(groups) + GroupsType::Idx(groups) }, } }) }; - ac_s.with_groups(groups).set_original_len(false); + ac_s.with_groups(groups.into_sliceable()) + .set_original_len(false); Ok(ac_s) } } diff --git a/crates/polars-expr/src/expressions/gather.rs b/crates/polars-expr/src/expressions/gather.rs index e38b27aaeacc..16daef800b8e 100644 --- a/crates/polars-expr/src/expressions/gather.rs +++ b/crates/polars-expr/src/expressions/gather.rs @@ -28,7 +28,7 @@ impl PhysicalExpr for GatherExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?; @@ -129,8 +129,8 @@ impl GatherExpr { let groups = ac.groups(); // Determine the gather indices. - let idx: IdxCa = match groups.as_ref() { - GroupsProxy::Idx(groups) => { + let idx: IdxCa = match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { if groups.all().iter().zip(idx).any(|(g, idx)| match idx { None => false, Some(idx) => idx >= g.len() as IdxSize, @@ -149,7 +149,7 @@ impl GatherExpr { }) .collect_trusted() }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { if groups.iter().zip(idx).any(|(g, idx)| match idx { None => false, Some(idx) => idx >= g[1], @@ -207,8 +207,8 @@ impl GatherExpr { let groups = ac.groups(); // We offset the groups first by idx. - let idx: NoNull = match groups.as_ref() { - GroupsProxy::Idx(groups) => { + let idx: NoNull = match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { if groups.all().iter().any(|g| idx >= g.len() as IdxSize) { self.oob_err()?; } @@ -221,7 +221,7 @@ impl GatherExpr { }) .collect_trusted() }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { if groups.iter().any(|g| idx >= g[1]) { self.oob_err()?; } @@ -251,7 +251,7 @@ impl GatherExpr { &self, mut ac: AggregationContext<'b>, mut idx: AggregationContext<'b>, - groups: &'b GroupsProxy, + groups: &'b GroupsType, ) -> PolarsResult> { let mut builder = get_list_builder( &ac.dtype(), diff --git a/crates/polars-expr/src/expressions/literal.rs b/crates/polars-expr/src/expressions/literal.rs index 8b152803bb64..f560ffdd1d8f 100644 --- a/crates/polars-expr/src/expressions/literal.rs +++ b/crates/polars-expr/src/expressions/literal.rs @@ -138,7 +138,7 @@ impl PhysicalExpr for LiteralExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let s = self.evaluate(df, state)?; @@ -168,7 +168,7 @@ impl PartitionedAggregation for LiteralExpr { fn evaluate_partitioned( &self, df: &DataFrame, - _groups: &GroupsProxy, + _groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { self.evaluate(df, state) @@ -177,7 +177,7 @@ impl PartitionedAggregation for LiteralExpr { fn finalize( &self, partitioned: Column, - _groups: &GroupsProxy, + _groups: &GroupPositions, _state: &ExecutionState, ) -> PolarsResult { Ok(partitioned) diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index c309991990ee..4fc9dab8c043 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -96,7 +96,7 @@ pub struct AggregationContext<'a> { /// 2. flat (still needs the grouptuples to aggregate) state: AggState, /// group tuples for AggState - groups: Cow<'a, GroupsProxy>, + groups: Cow<'a, GroupPositions>, /// if the group tuples are already used in a level above /// and the series is exploded, the group tuples are sorted /// e.g. the exploded Series is grouped per group. @@ -105,7 +105,7 @@ pub struct AggregationContext<'a> { /// into a sorted groups. We do this lazily, so that this work only is /// done when the groups are needed update_groups: UpdateGroups, - /// This is true when the Series and GroupsProxy still have all + /// This is true when the Series and Groups still have all /// their original values. Not the case when filtered original_len: bool, } @@ -119,7 +119,7 @@ impl<'a> AggregationContext<'a> { AggState::NotAggregated(s) => s.dtype().clone(), } } - pub(crate) fn groups(&mut self) -> &Cow<'a, GroupsProxy> { + pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> { match self.update_groups { UpdateGroups::No => {}, UpdateGroups::WithGroupsLen => { @@ -129,8 +129,8 @@ impl<'a> AggregationContext<'a> { // match the exploded Series let mut offset = 0 as IdxSize; - match self.groups.as_ref() { - GroupsProxy::Idx(groups) => { + match self.groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { let groups = groups .iter() .map(|g| { @@ -141,13 +141,16 @@ impl<'a> AggregationContext<'a> { out }) .collect(); - self.groups = Cow::Owned(GroupsProxy::Slice { - groups, - rolling: false, - }) + self.groups = Cow::Owned( + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable(), + ) }, // sliced groups are already in correct order - GroupsProxy::Slice { .. } => {}, + GroupsType::Slice { .. } => {}, } self.update_groups = UpdateGroups::No; }, @@ -192,7 +195,7 @@ impl<'a> AggregationContext<'a> { /// the columns dtype) fn new( column: Column, - groups: Cow<'a, GroupsProxy>, + groups: Cow<'a, GroupPositions>, aggregated: bool, ) -> AggregationContext<'a> { let series = match (aggregated, column.dtype()) { @@ -220,7 +223,10 @@ impl<'a> AggregationContext<'a> { self.state = agg_state; } - fn from_agg_state(agg_state: AggState, groups: Cow<'a, GroupsProxy>) -> AggregationContext<'a> { + fn from_agg_state( + agg_state: AggState, + groups: Cow<'a, GroupPositions>, + ) -> AggregationContext<'a> { Self { state: agg_state, groups, @@ -230,7 +236,7 @@ impl<'a> AggregationContext<'a> { } } - fn from_literal(lit: Column, groups: Cow<'a, GroupsProxy>) -> AggregationContext<'a> { + fn from_literal(lit: Column, groups: Cow<'a, GroupPositions>) -> AggregationContext<'a> { Self { state: AggState::Literal(lit), groups, @@ -276,10 +282,13 @@ impl<'a> AggregationContext<'a> { out }) .collect_trusted(); - self.groups = Cow::Owned(GroupsProxy::Slice { - groups, - rolling: false, - }); + self.groups = Cow::Owned( + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable(), + ); }, _ => { let groups = { @@ -300,10 +309,13 @@ impl<'a> AggregationContext<'a> { }) .collect_trusted() }; - self.groups = Cow::Owned(GroupsProxy::Slice { - groups, - rolling: false, - }); + self.groups = Cow::Owned( + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable(), + ); }, } self.update_groups = UpdateGroups::No; @@ -370,7 +382,7 @@ impl<'a> AggregationContext<'a> { } /// Update the group tuples - pub(crate) fn with_groups(&mut self, groups: GroupsProxy) -> &mut Self { + pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self { if let AggState::AggregatedList(_) = self.agg_state() { // In case of new groups, a series always needs to be flattened self.with_values(self.flat_naive().into_owned(), false, None) @@ -452,7 +464,7 @@ impl<'a> AggregationContext<'a> { } } - pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupsProxy>) { + pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) { let _ = self.groups(); let groups = self.groups; match self.state { @@ -504,7 +516,7 @@ impl<'a> AggregationContext<'a> { { // panic so we find cases where we accidentally explode overlapping groups // we don't want this as this can create a lot of data - if let GroupsProxy::Slice { rolling: true, .. } = self.groups.as_ref() { + if let GroupsType::Slice { rolling: true, .. } = self.groups.as_ref().as_ref() { panic!("implementation error, polars should not hit this branch for overlapping groups") } } @@ -578,7 +590,7 @@ pub trait PhysicalExpr: Send + Sync { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult>; @@ -681,7 +693,7 @@ pub trait PartitionedAggregation: Send + Sync + PhysicalExpr { fn evaluate_partitioned( &self, df: &DataFrame, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult; @@ -690,7 +702,7 @@ pub trait PartitionedAggregation: Send + Sync + PhysicalExpr { fn finalize( &self, partitioned: Column, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult; } diff --git a/crates/polars-expr/src/expressions/rolling.rs b/crates/polars-expr/src/expressions/rolling.rs index 2ec32069a30f..1b61318fac90 100644 --- a/crates/polars-expr/src/expressions/rolling.rs +++ b/crates/polars-expr/src/expressions/rolling.rs @@ -53,7 +53,7 @@ impl PhysicalExpr for RollingExpr { fn evaluate_on_groups<'a>( &self, _df: &DataFrame, - _groups: &'a GroupsProxy, + _groups: &'a GroupPositions, _state: &ExecutionState, ) -> PolarsResult> { polars_bail!(InvalidOperation: "rolling expression not allowed in aggregation"); diff --git a/crates/polars-expr/src/expressions/slice.rs b/crates/polars-expr/src/expressions/slice.rs index 62df859460f8..f39fbade9f23 100644 --- a/crates/polars-expr/src/expressions/slice.rs +++ b/crates/polars-expr/src/expressions/slice.rs @@ -43,7 +43,7 @@ fn extract_args(offset: &Column, length: &Column, expr: &Expr) -> PolarsResult<( Ok((extract_offset(offset, expr)?, extract_length(length, expr)?)) } -fn check_argument(arg: &Column, groups: &GroupsProxy, name: &str, expr: &Expr) -> PolarsResult<()> { +fn check_argument(arg: &Column, groups: &GroupsType, name: &str, expr: &Expr) -> PolarsResult<()> { polars_ensure!( !matches!(arg.dtype(), DataType::List(_)), expr = expr, ComputeError: "invalid slice argument: cannot use an array as {} argument", name, @@ -100,7 +100,7 @@ impl PhysicalExpr for SliceExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let mut results = POOL.install(|| { @@ -130,20 +130,20 @@ impl PhysicalExpr for SliceExpr { } let groups = ac.groups(); - match groups.as_ref() { - GroupsProxy::Idx(groups) => { + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { let groups = groups .iter() .map(|(first, idx)| slice_groups_idx(offset, length, first, idx)) .collect(); - GroupsProxy::Idx(groups) + GroupsType::Idx(groups) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let groups = groups .iter() .map(|&[first, len]| slice_groups_slice(offset, length, first, len)) .collect_trusted(); - GroupsProxy::Slice { + GroupsType::Slice { groups, rolling: false, } @@ -159,8 +159,8 @@ impl PhysicalExpr for SliceExpr { let length = length.cast(&IDX_DTYPE)?; let length = length.idx().unwrap(); - match groups.as_ref() { - GroupsProxy::Idx(groups) => { + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { let groups = groups .iter() .zip(length.into_no_null_iter()) @@ -168,9 +168,9 @@ impl PhysicalExpr for SliceExpr { slice_groups_idx(offset, length as usize, first, idx) }) .collect(); - GroupsProxy::Idx(groups) + GroupsType::Idx(groups) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let groups = groups .iter() .zip(length.into_no_null_iter()) @@ -178,7 +178,7 @@ impl PhysicalExpr for SliceExpr { slice_groups_slice(offset, length as usize, first, len) }) .collect_trusted(); - GroupsProxy::Slice { + GroupsType::Slice { groups, rolling: false, } @@ -194,8 +194,8 @@ impl PhysicalExpr for SliceExpr { let offset = offset.cast(&DataType::Int64)?; let offset = offset.i64().unwrap(); - match groups.as_ref() { - GroupsProxy::Idx(groups) => { + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { let groups = groups .iter() .zip(offset.into_no_null_iter()) @@ -203,9 +203,9 @@ impl PhysicalExpr for SliceExpr { slice_groups_idx(offset, length, first, idx) }) .collect(); - GroupsProxy::Idx(groups) + GroupsType::Idx(groups) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let groups = groups .iter() .zip(offset.into_no_null_iter()) @@ -213,7 +213,7 @@ impl PhysicalExpr for SliceExpr { slice_groups_slice(offset, length, first, len) }) .collect_trusted(); - GroupsProxy::Slice { + GroupsType::Slice { groups, rolling: false, } @@ -233,8 +233,8 @@ impl PhysicalExpr for SliceExpr { let length = length.cast(&IDX_DTYPE)?; let length = length.idx().unwrap(); - match groups.as_ref() { - GroupsProxy::Idx(groups) => { + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { let groups = groups .iter() .zip(offset.into_no_null_iter()) @@ -243,9 +243,9 @@ impl PhysicalExpr for SliceExpr { slice_groups_idx(offset, length as usize, first, idx) }) .collect(); - GroupsProxy::Idx(groups) + GroupsType::Idx(groups) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let groups = groups .iter() .zip(offset.into_no_null_iter()) @@ -254,7 +254,7 @@ impl PhysicalExpr for SliceExpr { slice_groups_slice(offset, length as usize, first, len) }) .collect_trusted(); - GroupsProxy::Slice { + GroupsType::Slice { groups, rolling: false, } @@ -263,7 +263,8 @@ impl PhysicalExpr for SliceExpr { }, }; - ac.with_groups(groups).set_original_len(false); + ac.with_groups(groups.into_sliceable()) + .set_original_len(false); Ok(ac) } diff --git a/crates/polars-expr/src/expressions/sort.rs b/crates/polars-expr/src/expressions/sort.rs index 746978b760a9..ae4f8d13eab4 100644 --- a/crates/polars-expr/src/expressions/sort.rs +++ b/crates/polars-expr/src/expressions/sort.rs @@ -56,7 +56,7 @@ impl PhysicalExpr for SortExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?; @@ -72,8 +72,8 @@ impl PhysicalExpr for SortExpr { let mut sort_options = self.options; sort_options.multithreaded = false; let groups = POOL.install(|| { - match ac.groups().as_ref() { - GroupsProxy::Idx(groups) => { + match ac.groups().as_ref().as_ref() { + GroupsType::Idx(groups) => { groups .par_iter() .map(|(first, idx)| { @@ -86,7 +86,7 @@ impl PhysicalExpr for SortExpr { }) .collect() }, - GroupsProxy::Slice { groups, .. } => groups + GroupsType::Slice { groups, .. } => groups .par_iter() .map(|&[first, len]| { let group = series.slice(first as i64, len as usize); @@ -97,8 +97,8 @@ impl PhysicalExpr for SortExpr { .collect(), } }); - let groups = GroupsProxy::Idx(groups); - ac.with_groups(groups); + let groups = GroupsType::Idx(groups); + ac.with_groups(groups.into_sliceable()); }, } diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs index b3df9ee2316a..fc3e2b4c00cf 100644 --- a/crates/polars-expr/src/expressions/sortby.rs +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -46,7 +46,7 @@ fn prepare_bool_vec(values: &[bool], by_len: usize) -> Vec { static ERR_MSG: &str = "expressions in 'sort_by' produced a different number of groups"; -fn check_groups(a: &GroupsProxy, b: &GroupsProxy) -> PolarsResult<()> { +fn check_groups(a: &GroupsType, b: &GroupsType) -> PolarsResult<()> { polars_ensure!(a.iter().zip(b.iter()).all(|(a, b)| { a.len() == b.len() }), ComputeError: ERR_MSG); @@ -54,10 +54,10 @@ fn check_groups(a: &GroupsProxy, b: &GroupsProxy) -> PolarsResult<()> { } pub(super) fn update_groups_sort_by( - groups: &GroupsProxy, + groups: &GroupsType, sort_by_s: &Series, options: &SortOptions, -) -> PolarsResult { +) -> PolarsResult { // Will trigger a gather for every group, so rechunk before. let sort_by_s = sort_by_s.rechunk(); let groups = POOL.install(|| { @@ -67,7 +67,7 @@ pub(super) fn update_groups_sort_by( .collect::>() })?; - Ok(GroupsProxy::Idx(groups)) + Ok(GroupsType::Idx(groups)) } fn sort_by_groups_single_by( @@ -293,7 +293,7 @@ impl PhysicalExpr for SortByExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?; @@ -386,7 +386,7 @@ impl PhysicalExpr for SortByExpr { }) .collect::>() }); - GroupsProxy::Idx(groups?) + GroupsType::Idx(groups?) }; // If the rhs is already aggregated once, it is reordered by the @@ -396,7 +396,7 @@ impl PhysicalExpr for SortByExpr { ac_in.with_values(s.explode().unwrap(), false, None)?; } - ac_in.with_groups(groups); + ac_in.with_groups(groups.into_sliceable()); Ok(ac_in) } diff --git a/crates/polars-expr/src/expressions/ternary.rs b/crates/polars-expr/src/expressions/ternary.rs index e7ec666eda50..c96360006319 100644 --- a/crates/polars-expr/src/expressions/ternary.rs +++ b/crates/polars-expr/src/expressions/ternary.rs @@ -107,7 +107,7 @@ impl PhysicalExpr for TernaryExpr { fn evaluate_on_groups<'a>( &self, df: &DataFrame, - groups: &'a GroupsProxy, + groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { let op_mask = || self.predicate.evaluate_on_groups(df, groups, state); @@ -343,7 +343,7 @@ impl PartitionedAggregation for TernaryExpr { fn evaluate_partitioned( &self, df: &DataFrame, - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult { let truthy = self.truthy.as_partitioned_aggregator().unwrap(); @@ -361,7 +361,7 @@ impl PartitionedAggregation for TernaryExpr { fn finalize( &self, partitioned: Column, - _groups: &GroupsProxy, + _groups: &GroupPositions, _state: &ExecutionState, ) -> PolarsResult { Ok(partitioned) diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index 4e61224c6dcd..2c804d11260e 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -62,13 +62,13 @@ impl WindowExpr { // groups are not changed, we can map by doing a standard arg_sort. if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) { let mut iter = 0..flattened.len() as IdxSize; - match ac.groups().as_ref() { - GroupsProxy::Idx(groups) => { + match ac.groups().as_ref().as_ref() { + GroupsType::Idx(groups) => { for g in groups.all() { idx_mapping.extend(g.iter().copied().zip(&mut iter)); } }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { for &[first, len] in groups { idx_mapping.extend((first..first + len).zip(&mut iter)); } @@ -79,13 +79,13 @@ impl WindowExpr { // and sort by the old indexes else { let mut original_idx = Vec::with_capacity(out_column.len()); - match gb.get_groups() { - GroupsProxy::Idx(groups) => { + match gb.get_groups().as_ref() { + GroupsType::Idx(groups) => { for g in groups.all() { original_idx.extend_from_slice(g) } }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { for &[first, len] in groups { original_idx.extend(first..first + len) } @@ -94,13 +94,13 @@ impl WindowExpr { let mut original_idx_iter = original_idx.iter().copied(); - match ac.groups().as_ref() { - GroupsProxy::Idx(groups) => { + match ac.groups().as_ref().as_ref() { + GroupsType::Idx(groups) => { for g in groups.all() { idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter)); } }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { for &[first, len] in groups { idx_mapping.extend((first..first + len).zip(&mut original_idx_iter)); } @@ -333,7 +333,7 @@ impl WindowExpr { // no explicit aggregations, map over the groups //`(col("x").sum() * col("y")).over("groups")` (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => { - if let GroupsProxy::Slice { .. } = gb.get_groups() { + if let GroupsType::Slice { .. } = gb.get_groups().as_ref() { // Result can be directly exploded if the input was sorted. Ok(MapStrategy::Explode) } else { @@ -444,9 +444,10 @@ impl PhysicalExpr for WindowExpr { let order_by = order_by.evaluate(df, state)?; polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height()); groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)? + .into_sliceable() } - let out: PolarsResult = Ok(groups); + let out: PolarsResult = Ok(groups); out }; @@ -654,7 +655,7 @@ impl PhysicalExpr for WindowExpr { fn evaluate_on_groups<'a>( &self, _df: &DataFrame, - _groups: &'a GroupsProxy, + _groups: &'a GroupPositions, _state: &ExecutionState, ) -> PolarsResult> { polars_bail!(InvalidOperation: "window expression not allowed in aggregation"); @@ -690,7 +691,7 @@ fn cache_gb(gb: GroupBy, state: &ExecutionState, cache_key: &str) { /// Simple reducing aggregation can be set by the groups fn set_by_groups( s: &Column, - groups: &GroupsProxy, + groups: &GroupsType, len: usize, update_groups: bool, ) -> Option { @@ -714,7 +715,7 @@ fn set_by_groups( } } -fn set_numeric(ca: &ChunkedArray, groups: &GroupsProxy, len: usize) -> Series +fn set_numeric(ca: &ChunkedArray, groups: &GroupsType, len: usize) -> Series where T: PolarsNumericType, ChunkedArray: IntoSeries, @@ -728,7 +729,7 @@ where if ca.null_count() == 0 { let ca = ca.rechunk(); match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let agg_vals = ca.cont_slice().expect("rechunked"); POOL.install(|| { agg_vals @@ -743,7 +744,7 @@ where }) }) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let agg_vals = ca.cont_slice().expect("rechunked"); POOL.install(|| { agg_vals @@ -776,7 +777,7 @@ where let offsets = _split_offsets(ca.len(), n_threads); match groups { - GroupsProxy::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| { + GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| { let offset = *offset; let offset_len = *offset_len; let ca = ca.slice(offset as i64, offset_len); @@ -803,7 +804,7 @@ where } }) }), - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { offsets.par_iter().for_each(|(offset, offset_len)| { let offset = *offset; let offset_len = *offset_len; diff --git a/crates/polars-expr/src/state/execution_state.rs b/crates/polars-expr/src/state/execution_state.rs index 07c571e26653..3695cebcdfaa 100644 --- a/crates/polars-expr/src/state/execution_state.rs +++ b/crates/polars-expr/src/state/execution_state.rs @@ -11,7 +11,7 @@ use polars_ops::prelude::ChunkJoinOptIds; use super::NodeTimer; pub type JoinTuplesCache = Arc>>; -pub type GroupsProxyCache = Arc>>; +pub type GroupsTypeCache = Arc>>; bitflags! { #[repr(transparent)] @@ -63,7 +63,7 @@ pub struct ExecutionState { df_cache: Arc>>, pub schema_cache: RwLock>, /// Used by Window Expression to prevent redundant grouping - pub group_tuples: GroupsProxyCache, + pub group_tuples: GroupsTypeCache, /// Used by Window Expression to prevent redundant joins pub join_tuples: JoinTuplesCache, // every join/union split gets an increment to distinguish between schema state diff --git a/crates/polars-io/src/partition.rs b/crates/polars-io/src/partition.rs index bf4af7ca9818..824b38c7f9ef 100644 --- a/crates/polars-io/src/partition.rs +++ b/crates/polars-io/src/partition.rs @@ -160,8 +160,8 @@ pub fn write_partitioned_dataset( } }; - POOL.install(|| match groups { - GroupsProxy::Idx(idx) => idx + POOL.install(|| match groups.as_ref() { + GroupsType::Idx(idx) => idx .all() .chunks(MAX_OPEN_FILES) .map(|chunk| { @@ -179,7 +179,7 @@ pub fn write_partitioned_dataset( ) }) .collect::>>(), - GroupsProxy::Slice { groups, .. } => groups + GroupsType::Slice { groups, .. } => groups .chunks(MAX_OPEN_FILES) .map(|chunk| { chunk diff --git a/crates/polars-lazy/src/dsl/list.rs b/crates/polars-lazy/src/dsl/list.rs index cd8e66e9bfd0..09b99bf681ec 100644 --- a/crates/polars-lazy/src/dsl/list.rs +++ b/crates/polars-lazy/src/dsl/list.rs @@ -21,7 +21,7 @@ impl IntoListNameSpace for ListNameSpace { } } -fn offsets_to_groups(offsets: &[i64]) -> Option { +fn offsets_to_groups(offsets: &[i64]) -> Option { let mut start = offsets[0]; let end = *offsets.last().unwrap(); if IdxSize::try_from(end - start).is_err() { @@ -37,10 +37,13 @@ fn offsets_to_groups(offsets: &[i64]) -> Option { [offset, len] }) .collect(); - Some(GroupsProxy::Slice { - groups, - rolling: false, - }) + Some( + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable(), + ) } fn run_per_sublist( diff --git a/crates/polars-lazy/src/frame/pivot.rs b/crates/polars-lazy/src/frame/pivot.rs index 70eed4d8f58c..6225c1667bbb 100644 --- a/crates/polars-lazy/src/frame/pivot.rs +++ b/crates/polars-lazy/src/frame/pivot.rs @@ -18,7 +18,7 @@ use crate::prelude::*; struct PivotExpr(Expr); impl PhysicalAggExpr for PivotExpr { - fn evaluate(&self, df: &DataFrame, groups: &GroupsProxy) -> PolarsResult { + fn evaluate(&self, df: &DataFrame, groups: &GroupPositions) -> PolarsResult { let state = ExecutionState::new(); let dtype = df.get_columns()[0].dtype(); let phys_expr = prepare_expression_for_context( diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index 2ab337ef51e9..7180efd5b587 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -9,7 +9,7 @@ fn test_agg_list_type() -> PolarsResult<()> { let s = Series::new("foo".into(), &[1, 2, 3]); let s = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?; - let l = unsafe { s.agg_list(&GroupsProxy::Idx(vec![(0, unitvec![0, 1, 2])].into())) }; + let l = unsafe { s.agg_list(&GroupsType::Idx(vec![(0, unitvec![0, 1, 2])].into())) }; let result = match l.dtype() { DataType::List(inner) => { diff --git a/crates/polars-mem-engine/src/executors/group_by.rs b/crates/polars-mem-engine/src/executors/group_by.rs index 09dcae659fee..fec04ea5f3b9 100644 --- a/crates/polars-mem-engine/src/executors/group_by.rs +++ b/crates/polars-mem-engine/src/executors/group_by.rs @@ -5,7 +5,7 @@ use super::*; pub(super) fn evaluate_aggs( df: &DataFrame, aggs: &[Arc], - groups: &GroupsProxy, + groups: &GroupPositions, state: &ExecutionState, ) -> PolarsResult> { POOL.install(|| { @@ -78,7 +78,7 @@ pub(super) fn group_by_helper( if let Some((offset, len)) = slice { sliced_groups = Some(groups.slice(offset, len)); - groups = sliced_groups.as_deref().unwrap(); + groups = sliced_groups.as_ref().unwrap(); } let (mut columns, agg_columns) = POOL.install(|| { diff --git a/crates/polars-mem-engine/src/executors/group_by_dynamic.rs b/crates/polars-mem-engine/src/executors/group_by_dynamic.rs index 2a3fe16c08cd..5dfa8328962a 100644 --- a/crates/polars-mem-engine/src/executors/group_by_dynamic.rs +++ b/crates/polars-mem-engine/src/executors/group_by_dynamic.rs @@ -47,7 +47,7 @@ impl GroupByDynamicExec { if let Some((offset, len)) = self.slice { sliced_groups = Some(groups.slice(offset, len)); - groups = sliced_groups.as_deref().unwrap(); + groups = sliced_groups.as_ref().unwrap(); time_key = time_key.slice(offset, len); diff --git a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs index ad5d647fdb33..d8416a503b34 100644 --- a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs +++ b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs @@ -129,7 +129,7 @@ fn estimate_unique_count(keys: &[Column], mut sample_size: usize) -> PolarsResul sample_size = set_size; } - let finish = |groups: &GroupsProxy| { + let finish = |groups: &GroupsType| { let u = groups.len() as f64; let ui = if groups.len() == sample_size { u @@ -316,7 +316,7 @@ impl PartitionGroupByExec { if let Some((offset, len)) = self.slice { sliced_groups = Some(groups.slice(offset, len)); - groups = sliced_groups.as_deref().unwrap(); + groups = sliced_groups.as_ref().unwrap(); } let get_columns = || gb.keys_sliced(self.slice); diff --git a/crates/polars-mem-engine/src/executors/group_by_rolling.rs b/crates/polars-mem-engine/src/executors/group_by_rolling.rs index 5d9068f13de7..d68bddeb7386 100644 --- a/crates/polars-mem-engine/src/executors/group_by_rolling.rs +++ b/crates/polars-mem-engine/src/executors/group_by_rolling.rs @@ -13,9 +13,9 @@ pub(crate) struct GroupByRollingExec { } #[cfg(feature = "dynamic_group_by")] -unsafe fn update_keys(keys: &mut [Column], groups: &GroupsProxy) { +unsafe fn update_keys(keys: &mut [Column], groups: &GroupsType) { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let first = groups.first(); // we don't use agg_first here, because the group // can be empty, but we still want to know the first value @@ -24,7 +24,7 @@ unsafe fn update_keys(keys: &mut [Column], groups: &GroupsProxy) { *key = key.take_slice_unchecked(first); } }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { for key in keys.iter_mut() { let indices = groups .iter() @@ -70,7 +70,7 @@ impl GroupByRollingExec { if let Some((offset, len)) = self.slice { sliced_groups = Some(groups.slice(offset, len)); - groups = sliced_groups.as_deref().unwrap(); + groups = sliced_groups.as_ref().unwrap(); time_key = time_key.slice(offset, len); } diff --git a/crates/polars-ops/src/chunked_array/mode.rs b/crates/polars-ops/src/chunked_array/mode.rs index 87b5fc5e5134..0ff9b54e7ff8 100644 --- a/crates/polars-ops/src/chunked_array/mode.rs +++ b/crates/polars-ops/src/chunked_array/mode.rs @@ -3,7 +3,7 @@ use polars_core::{with_match_physical_integer_polars_type, POOL}; fn mode_primitive(ca: &ChunkedArray) -> PolarsResult> where - ChunkedArray: IntoGroupsProxy + ChunkTake<[IdxSize]>, + ChunkedArray: IntoGroupsType + ChunkTake<[IdxSize]>, { if ca.is_empty() { return Ok(ca.clone()); @@ -29,9 +29,9 @@ fn mode_64(ca: &Float64Chunked) -> PolarsResult { Ok(ca) } -fn mode_indices(groups: GroupsProxy) -> Vec { +fn mode_indices(groups: GroupsType) -> Vec { match groups { - GroupsProxy::Idx(groups) => { + GroupsType::Idx(groups) => { let Some(max_len) = groups.iter().map(|g| g.1.len()).max() else { return Vec::new(); }; @@ -41,7 +41,7 @@ fn mode_indices(groups: GroupsProxy) -> Vec { .map(|g| g.0) .collect() }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let Some(max_len) = groups.iter().map(|g| g[1]).max() else { return Vec::new(); }; diff --git a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs index ec1d8b2c9d4f..19c49492cd68 100644 --- a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs +++ b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs @@ -58,13 +58,13 @@ pub fn nan_max_s(s: &Series, name: PlSmallStr) -> Series { } } -unsafe fn group_nan_max(ca: &ChunkedArray, groups: &GroupsProxy) -> Series +unsafe fn group_nan_max(ca: &ChunkedArray, groups: &GroupsType) -> Series where T: PolarsFloatType, ChunkedArray: IntoSeries, { match groups { - GroupsProxy::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { + GroupsType::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { debug_assert!(idx.len() <= ca.len()); if idx.is_empty() { None @@ -89,7 +89,7 @@ where } } }), - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => { @@ -127,13 +127,13 @@ where } } -unsafe fn group_nan_min(ca: &ChunkedArray, groups: &GroupsProxy) -> Series +unsafe fn group_nan_min(ca: &ChunkedArray, groups: &GroupsType) -> Series where T: PolarsFloatType, ChunkedArray: IntoSeries, { match groups { - GroupsProxy::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { + GroupsType::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { debug_assert!(idx.len() <= ca.len()); if idx.is_empty() { None @@ -158,7 +158,7 @@ where } } }), - GroupsProxy::Slice { + GroupsType::Slice { groups: groups_slice, .. } => { @@ -198,7 +198,7 @@ where /// # Safety /// `groups` must be in bounds. -pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsProxy) -> Series { +pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsType) -> Series { match s.dtype() { DataType::Float32 => { let ca = s.f32().unwrap(); @@ -214,7 +214,7 @@ pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsProxy) -> Series { /// # Safety /// `groups` must be in bounds. -pub unsafe fn group_agg_nan_max_s(s: &Series, groups: &GroupsProxy) -> Series { +pub unsafe fn group_agg_nan_max_s(s: &Series, groups: &GroupsType) -> Series { match s.dtype() { DataType::Float32 => { let ca = s.f32().unwrap(); diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs index a81f6625c057..c3a105ba0439 100644 --- a/crates/polars-ops/src/frame/pivot/positioning.rs +++ b/crates/polars-ops/src/frame/pivot/positioning.rs @@ -230,7 +230,7 @@ where pub(super) fn compute_col_idx( pivot_df: &DataFrame, column: &str, - groups: &GroupsProxy, + groups: &GroupsType, ) -> PolarsResult<(Vec, Column)> { let column_s = pivot_df.column(column)?; let column_agg = unsafe { column_s.agg_first(groups) }; @@ -401,7 +401,7 @@ fn compute_row_index_struct( pub(super) fn compute_row_idx( pivot_df: &DataFrame, index: &[PlSmallStr], - groups: &GroupsProxy, + groups: &GroupsType, count: usize, ) -> PolarsResult<(Vec, usize, Option>)> { let (row_locations, n_rows, row_index) = if index.len() == 1 { diff --git a/crates/polars-pipe/src/executors/sinks/sort/ooc.rs b/crates/polars-pipe/src/executors/sinks/sort/ooc.rs index 48a2f944b87d..1d86da5cffb9 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/ooc.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/ooc.rs @@ -243,7 +243,7 @@ fn partition_df( let partitions = partitions.idx().unwrap().clone(); let out = match groups { - GroupsProxy::Idx(idx) => { + GroupsType::Idx(idx) => { let iter = idx.into_iter().map(move |(_, group)| { // groups are in bounds and sorted unsafe { @@ -252,7 +252,7 @@ fn partition_df( }); Box::new(iter) as DfIter }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let iter = groups .into_iter() .map(move |[first, len]| df.slice(first as i64, len as usize)); diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 2ee6b6a0a793..8a60ba753463 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -86,13 +86,13 @@ pub trait PolarsTemporalGroupby { &self, group_by: Vec, options: &RollingGroupOptions, - ) -> PolarsResult<(Column, Vec, GroupsProxy)>; + ) -> PolarsResult<(Column, Vec, GroupPositions)>; fn group_by_dynamic( &self, group_by: Vec, options: &DynamicGroupOptions, - ) -> PolarsResult<(Column, Vec, GroupsProxy)>; + ) -> PolarsResult<(Column, Vec, GroupPositions)>; } impl PolarsTemporalGroupby for DataFrame { @@ -100,7 +100,7 @@ impl PolarsTemporalGroupby for DataFrame { &self, group_by: Vec, options: &RollingGroupOptions, - ) -> PolarsResult<(Column, Vec, GroupsProxy)> { + ) -> PolarsResult<(Column, Vec, GroupPositions)> { Wrap(self).rolling(group_by, options) } @@ -108,7 +108,7 @@ impl PolarsTemporalGroupby for DataFrame { &self, group_by: Vec, options: &DynamicGroupOptions, - ) -> PolarsResult<(Column, Vec, GroupsProxy)> { + ) -> PolarsResult<(Column, Vec, GroupPositions)> { Wrap(self).group_by_dynamic(group_by, options) } } @@ -118,7 +118,7 @@ impl Wrap<&DataFrame> { &self, group_by: Vec, options: &RollingGroupOptions, - ) -> PolarsResult<(Column, Vec, GroupsProxy)> { + ) -> PolarsResult<(Column, Vec, GroupPositions)> { polars_ensure!( !options.period.is_zero() && !options.period.negative, ComputeError: @@ -192,7 +192,7 @@ impl Wrap<&DataFrame> { &self, group_by: Vec, options: &DynamicGroupOptions, - ) -> PolarsResult<(Column, Vec, GroupsProxy)> { + ) -> PolarsResult<(Column, Vec, GroupPositions)> { let time = self.0.column(&options.index_column)?.rechunk(); if group_by.is_empty() { // If by is given, the column must be sorted in the 'by' arg, which we can not check now @@ -266,10 +266,10 @@ impl Wrap<&DataFrame> { options: &DynamicGroupOptions, tu: TimeUnit, time_type: &DataType, - ) -> PolarsResult<(Column, Vec, GroupsProxy)> { + ) -> PolarsResult<(Column, Vec, GroupPositions)> { polars_ensure!(!options.every.negative, ComputeError: "'every' argument must be positive"); if dt.is_empty() { - return dt.cast(time_type).map(|s| (s, by, GroupsProxy::default())); + return dt.cast(time_type).map(|s| (s, by, Default::default())); } // A requirement for the index so we can set this such that downstream code has this info. @@ -322,7 +322,7 @@ impl Wrap<&DataFrame> { options.start_by, ); update_bounds(lower, upper); - PolarsResult::Ok(GroupsProxy::Slice { + PolarsResult::Ok(GroupsType::Slice { groups, rolling: false, }) @@ -334,8 +334,8 @@ impl Wrap<&DataFrame> { // Include boundaries cannot be parallel (easily). if include_lower_bound | include_upper_bound { - POOL.install(|| match groups { - GroupsProxy::Idx(groups) => { + POOL.install(|| match groups.as_ref() { + GroupsType::Idx(groups) => { let ir = groups .par_iter() .map(|base_g| { @@ -370,9 +370,9 @@ impl Wrap<&DataFrame> { }); // then parallelize the flatten in the `from` impl - Ok(GroupsProxy::Idx(GroupsIdx::from(groups))) + Ok(GroupsType::Idx(GroupsIdx::from(groups))) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let mut ir = groups .par_iter() .map(|base_g| { @@ -404,15 +404,15 @@ impl Wrap<&DataFrame> { let mut groups = Vec::with_capacity(capacity); ir.iter().for_each(|(_, _, g)| groups.extend_from_slice(g)); - Ok(GroupsProxy::Slice { + Ok(GroupsType::Slice { groups, rolling: false, }) }, }) } else { - POOL.install(|| match groups { - GroupsProxy::Idx(groups) => { + POOL.install(|| match groups.as_ref() { + GroupsType::Idx(groups) => { let groupsidx = groups .par_iter() .map(|base_g| { @@ -435,9 +435,9 @@ impl Wrap<&DataFrame> { Ok(update_subgroups_idx(&sub_groups, base_g)) }) .collect::>>()?; - Ok(GroupsProxy::Idx(GroupsIdx::from(groupsidx))) + Ok(GroupsType::Idx(GroupsIdx::from(groupsidx))) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let groups = groups .par_iter() .map(|base_g| { @@ -460,7 +460,7 @@ impl Wrap<&DataFrame> { let groups = flatten_par(&groups); - Ok(GroupsProxy::Slice { + Ok(GroupsType::Slice { groups, rolling: false, }) @@ -509,7 +509,7 @@ impl Wrap<&DataFrame> { dt.into_datetime(tu, None) .into_column() .cast(time_type) - .map(|s| (s, by, groups)) + .map(|s| (s, by, groups.into_sliceable())) } /// Returns: time_keys, keys, groupsproxy @@ -521,7 +521,7 @@ impl Wrap<&DataFrame> { tu: TimeUnit, tz: Option, time_type: &DataType, - ) -> PolarsResult<(Column, Vec, GroupsProxy)> { + ) -> PolarsResult<(Column, Vec, GroupPositions)> { let mut dt = dt.rechunk(); let groups = if group_by.is_empty() { @@ -531,7 +531,7 @@ impl Wrap<&DataFrame> { let dt = dt.datetime().unwrap(); let vals = dt.downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); - PolarsResult::Ok(GroupsProxy::Slice { + PolarsResult::Ok(GroupsType::Slice { groups: group_by_values( options.period, options.offset, @@ -556,8 +556,8 @@ impl Wrap<&DataFrame> { // continue determining the rolling indexes. - POOL.install(|| match groups { - GroupsProxy::Idx(groups) => { + POOL.install(|| match groups.as_ref() { + GroupsType::Idx(groups) => { let idx = groups .par_iter() .map(|base_g| { @@ -580,9 +580,9 @@ impl Wrap<&DataFrame> { }) .collect::>>()?; - Ok(GroupsProxy::Idx(GroupsIdx::from(idx))) + Ok(GroupsType::Idx(GroupsIdx::from(idx))) }, - GroupsProxy::Slice { groups, .. } => { + GroupsType::Slice { groups, .. } => { let slice_groups = groups .par_iter() .map(|base_g| { @@ -602,7 +602,7 @@ impl Wrap<&DataFrame> { .collect::>>()?; let slice_groups = flatten_par(&slice_groups); - Ok(GroupsProxy::Slice { + Ok(GroupsType::Slice { groups: slice_groups, rolling: false, }) @@ -612,7 +612,7 @@ impl Wrap<&DataFrame> { let dt = dt.cast(time_type).unwrap(); - Ok((dt, group_by, groups)) + Ok((dt, group_by, groups.into_sliceable())) } } @@ -906,7 +906,7 @@ mod test { .into_column(); assert_eq!(&upper, &range); - let expected = GroupsProxy::Idx( + let expected = GroupsType::Idx( vec![ (0 as IdxSize, unitvec![0 as IdxSize, 1, 2]), (2, unitvec![2]), @@ -916,7 +916,8 @@ mod test { (4, unitvec![4]), ] .into(), - ); + ) + .into_sliceable(); assert_eq!(expected, groups); Ok(()) }