diff --git a/polars/polars-ops/src/chunked_array/list/mod.rs b/polars/polars-ops/src/chunked_array/list/mod.rs index 295fb82200e4..3c58fe4ee88b 100644 --- a/polars/polars-ops/src/chunked_array/list/mod.rs +++ b/polars/polars-ops/src/chunked_array/list/mod.rs @@ -4,7 +4,7 @@ mod count; #[cfg(feature = "hash")] pub(crate) mod hash; mod namespace; -mod sum; +mod sum_mean; #[cfg(feature = "list_to_struct")] mod to_struct; diff --git a/polars/polars-ops/src/chunked_array/list/namespace.rs b/polars/polars-ops/src/chunked_array/list/namespace.rs index eac25f89ea70..52bd08f91c80 100644 --- a/polars/polars-ops/src/chunked_array/list/namespace.rs +++ b/polars/polars-ops/src/chunked_array/list/namespace.rs @@ -13,7 +13,7 @@ use polars_core::series::ops::NullBehavior; use polars_core::utils::{try_get_supertype, CustomIterTools}; use super::*; -use crate::prelude::list::sum::sum_list_numerical; +use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical}; use crate::series::ArgAgg; fn has_inner_nulls(ca: &ListChunked) -> bool { @@ -154,11 +154,39 @@ pub trait ListNameSpaceImpl: AsList { } } - fn lst_mean(&self) -> Float64Chunked { + fn lst_mean(&self) -> Series { + fn inner(ca: &ListChunked) -> Series { + let mut out: Float64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().mean())) + .collect(); + + out.rename(ca.name()); + out.into_series() + } + use DataType::*; + let ca = self.as_list(); - ca.amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().mean())) - .collect() + + if has_inner_nulls(ca) { + return match ca.inner_dtype() { + Float32 => { + let mut out: Float32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as f32))) + .collect(); + + out.rename(ca.name()); + out.into_series() + } + _ => inner(ca), + }; + }; + + match ca.inner_dtype() { + dt if dt.is_numeric() => mean_list_numerical(ca, &dt), + _ => inner(ca), + } } #[must_use] diff --git a/polars/polars-ops/src/chunked_array/list/sum.rs b/polars/polars-ops/src/chunked_array/list/sum.rs deleted file mode 100644 index 6a2d544018a5..000000000000 --- a/polars/polars-ops/src/chunked_array/list/sum.rs +++ /dev/null @@ -1,83 +0,0 @@ -use arrow::array::{Array, PrimitiveArray}; -use arrow::bitmap::Bitmap; -use arrow::types::NativeType; -use polars_arrow::utils::CustomIterTools; -use polars_core::datatypes::ListChunked; -use polars_core::export::num::{NumCast, ToPrimitive}; -use polars_utils::unwrap::UnwrapUncheckedRelease; - -use super::*; - -fn sum_slice(values: &[T]) -> S -where - T: NativeType + ToPrimitive, - S: NumCast + std::iter::Sum, -{ - values - .iter() - .copied() - .map(|t| unsafe { - let s: S = NumCast::from(t).unwrap_unchecked_release(); - s - }) - .sum() -} - -fn sum_between_offsets(values: &[T], offset: &[i64]) -> Vec -where - T: NativeType + ToPrimitive, - S: NumCast + std::iter::Sum, -{ - let mut running_offset = offset[0]; - - (offset[1..]) - .iter() - .map(|end| { - let current_offset = running_offset; - running_offset = *end; - - let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) }; - sum_slice(slice) - }) - .collect_trusted() -} - -fn dispatch(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef -where - T: NativeType + ToPrimitive, - S: NativeType + NumCast + std::iter::Sum, -{ - let values = arr.as_any().downcast_ref::>().unwrap(); - let values = values.values().as_slice(); - Box::new(PrimitiveArray::from_data_default( - sum_between_offsets::<_, S>(values, offsets).into(), - validity.cloned(), - )) as ArrayRef -} - -pub(super) fn sum_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { - use DataType::*; - let chunks = ca - .downcast_iter() - .map(|arr| { - let offsets = arr.offsets().as_slice(); - let values = arr.values().as_ref(); - - match inner_type { - Int8 => dispatch::(values, offsets, arr.validity()), - Int16 => dispatch::(values, offsets, arr.validity()), - Int32 => dispatch::(values, offsets, arr.validity()), - Int64 => dispatch::(values, offsets, arr.validity()), - UInt8 => dispatch::(values, offsets, arr.validity()), - UInt16 => dispatch::(values, offsets, arr.validity()), - UInt32 => dispatch::(values, offsets, arr.validity()), - UInt64 => dispatch::(values, offsets, arr.validity()), - Float32 => dispatch::(values, offsets, arr.validity()), - Float64 => dispatch::(values, offsets, arr.validity()), - _ => unimplemented!(), - } - }) - .collect::>(); - - Series::try_from((ca.name(), chunks)).unwrap() -} diff --git a/polars/polars-ops/src/chunked_array/list/sum_mean.rs b/polars/polars-ops/src/chunked_array/list/sum_mean.rs new file mode 100644 index 000000000000..fee8b97a3249 --- /dev/null +++ b/polars/polars-ops/src/chunked_array/list/sum_mean.rs @@ -0,0 +1,146 @@ +use std::ops::Div; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::types::NativeType; +use polars_arrow::utils::CustomIterTools; +use polars_core::datatypes::ListChunked; +use polars_core::export::num::{NumCast, ToPrimitive}; +use polars_utils::unwrap::UnwrapUncheckedRelease; + +use super::*; + +fn sum_slice(values: &[T]) -> S +where + T: NativeType + ToPrimitive, + S: NumCast + std::iter::Sum, +{ + values + .iter() + .copied() + .map(|t| unsafe { + let s: S = NumCast::from(t).unwrap_unchecked_release(); + s + }) + .sum() +} + +fn sum_between_offsets(values: &[T], offset: &[i64]) -> Vec +where + T: NativeType + ToPrimitive, + S: NumCast + std::iter::Sum, +{ + let mut running_offset = offset[0]; + + (offset[1..]) + .iter() + .map(|end| { + let current_offset = running_offset; + running_offset = *end; + + let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) }; + sum_slice(slice) + }) + .collect_trusted() +} + +fn dispatch_sum(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType + ToPrimitive, + S: NativeType + NumCast + std::iter::Sum, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + Box::new(PrimitiveArray::from_data_default( + sum_between_offsets::<_, S>(values, offsets).into(), + validity.cloned(), + )) as ArrayRef +} + +pub(super) fn sum_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let offsets = arr.offsets().as_slice(); + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_sum::(values, offsets, arr.validity()), + Int16 => dispatch_sum::(values, offsets, arr.validity()), + Int32 => dispatch_sum::(values, offsets, arr.validity()), + Int64 => dispatch_sum::(values, offsets, arr.validity()), + UInt8 => dispatch_sum::(values, offsets, arr.validity()), + UInt16 => dispatch_sum::(values, offsets, arr.validity()), + UInt32 => dispatch_sum::(values, offsets, arr.validity()), + UInt64 => dispatch_sum::(values, offsets, arr.validity()), + Float32 => dispatch_sum::(values, offsets, arr.validity()), + Float64 => dispatch_sum::(values, offsets, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name(), chunks)).unwrap() +} + +fn mean_between_offsets(values: &[T], offset: &[i64]) -> Vec +where + T: NativeType + ToPrimitive, + S: NumCast + std::iter::Sum + Div, +{ + let mut running_offset = offset[0]; + + (offset[1..]) + .iter() + .map(|end| { + let current_offset = running_offset; + running_offset = *end; + + let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) }; + unsafe { + sum_slice::<_, S>(slice) / NumCast::from(slice.len()).unwrap_unchecked_release() + } + }) + .collect_trusted() +} + +fn dispatch_mean(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType + ToPrimitive, + S: NativeType + NumCast + std::iter::Sum + Div, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + Box::new(PrimitiveArray::from_data_default( + mean_between_offsets::<_, S>(values, offsets).into(), + validity.cloned(), + )) as ArrayRef +} + +pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let offsets = arr.offsets().as_slice(); + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_mean::(values, offsets, arr.validity()), + Int16 => dispatch_mean::(values, offsets, arr.validity()), + Int32 => dispatch_mean::(values, offsets, arr.validity()), + Int64 => dispatch_mean::(values, offsets, arr.validity()), + UInt8 => dispatch_mean::(values, offsets, arr.validity()), + UInt16 => dispatch_mean::(values, offsets, arr.validity()), + UInt32 => dispatch_mean::(values, offsets, arr.validity()), + UInt64 => dispatch_mean::(values, offsets, arr.validity()), + Float32 => dispatch_mean::(values, offsets, arr.validity()), + Float64 => dispatch_mean::(values, offsets, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name(), chunks)).unwrap() +} diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 1faf1f71064e..bdbe14ed929d 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -405,3 +405,13 @@ def test_list_sum_and_dtypes() -> None: assert pl.DataFrame( {"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]} ).select(pl.col("a").arr.sum()).to_dict(False) == {"a": [1, 6, 10, 15, None]} + + +def test_list_mean() -> None: + assert pl.DataFrame({"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}).select( + pl.col("a").arr.mean() + ).to_dict(False) == {"a": [1.0, 2.0, 2.5, 3.0]} + + assert pl.DataFrame({"a": [[1], [1, 2, 3], [1, 2, 3, 4], None]}).select( + pl.col("a").arr.mean() + ).to_dict(False) == {"a": [1.0, 2.0, 2.5, None]}