From 56bbd572dc483e9cbc6d15cda0a71f9482820629 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 20 Feb 2023 19:34:41 +0100 Subject: [PATCH] perf(rust, python): optimize arr.sum for list array with inner nulls (#7053) --- .../src/chunked_array/list/namespace.rs | 74 ++++++++++++++++--- 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/polars/polars-ops/src/chunked_array/list/namespace.rs b/polars/polars-ops/src/chunked_array/list/namespace.rs index b742bcf7cb5b..bbd81b2958f8 100644 --- a/polars/polars-ops/src/chunked_array/list/namespace.rs +++ b/polars/polars-ops/src/chunked_array/list/namespace.rs @@ -126,24 +126,80 @@ pub trait ListNameSpaceImpl: AsList { } fn lst_sum(&self) -> Series { - fn inner(ca: &ListChunked) -> Series { - ca.apply_amortized(|s| s.as_ref().sum_as_series()) - .explode() - .unwrap() - .into_series() + fn inner(ca: &ListChunked, inner_dtype: &DataType) -> Series { + use DataType::*; + // TODO: add fast path for smaller ints? + let mut out = match inner_dtype { + Boolean => { + let out: IdxCa = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + UInt32 => { + let out: UInt32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + UInt64 => { + let out: UInt64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + Int32 => { + let out: Int32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + Int64 => { + let out: Int64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + Float32 => { + let out: Float32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + Float64 => { + let out: Float64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum())) + .collect(); + out.into_series() + } + // slowest sum_as_series path + _ => ca + .apply_amortized(|s| s.as_ref().sum_as_series()) + .explode() + .unwrap() + .into_series(), + }; + out.rename(ca.name()); + out } let ca = self.as_list(); if has_inner_nulls(ca) { - return inner(ca); + return inner(ca, &ca.inner_dtype()); }; - use DataType::*; match ca.inner_dtype() { - Boolean => count_boolean_bits(ca).into_series(), + DataType::Boolean => count_boolean_bits(ca).into_series(), dt if dt.is_numeric() => sum_list_numerical(ca, &dt), - _ => inner(ca), + dt => inner(ca, &dt), } }