diff --git a/polars/polars-ops/src/chunked_array/list/namespace.rs b/polars/polars-ops/src/chunked_array/list/namespace.rs index b742bcf7cb5b..bbd81b2958f8 100644 --- a/polars/polars-ops/src/chunked_array/list/namespace.rs +++ b/polars/polars-ops/src/chunked_array/list/namespace.rs @@ -126,24 +126,80 @@ pub trait ListNameSpaceImpl: AsList { } fn lst_sum(&self) -> Series { - fn inner(ca: &ListChunked) -> Series { - ca.apply_amortized(|s| s.as_ref().sum_as_series()) - .explode() - .unwrap() - .into_series() + fn inner(ca: &ListChunked, inner_dtype: &DataType) -> Series { + use DataType::*; + // TODO: add fast path for smaller ints? + let mut out = match inner_dtype { + Boolean => { + let out: IdxCa = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + UInt32 => { + let out: UInt32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + UInt64 => { + let out: UInt64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + Int32 => { + let out: Int32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + Int64 => { + let out: Int64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + Float32 => { + let out: Float32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + Float64 => { + let out: Float64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + // slowest sum_as_series path + _ => ca + .apply_amortized(|s| s.as_ref().sum_as_series()) + .explode() + .unwrap() + .into_series(), + }; + out.rename(ca.name()); + out } let ca = self.as_list(); if has_inner_nulls(ca) { - return inner(ca); + return inner(ca, &ca.inner_dtype()); }; - use DataType::*; match ca.inner_dtype() { - Boolean => count_boolean_bits(ca).into_series(), + DataType::Boolean => count_boolean_bits(ca).into_series(), dt if dt.is_numeric() => sum_list_numerical(ca, &dt), - _ => inner(ca), + dt => inner(ca, &dt), } }