Skip to content

Commit

Permalink
perf(rust, python): optimize arr.sum for list array with inner nulls (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Feb 20, 2023
1 parent 8ff6c9d commit 56bbd57
Showing 1 changed file with 65 additions and 9 deletions.
74 changes: 65 additions & 9 deletions polars/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,80 @@ pub trait ListNameSpaceImpl: AsList {
}

fn lst_sum(&self) -> Series {
fn inner(ca: &ListChunked) -> Series {
ca.apply_amortized(|s| s.as_ref().sum_as_series())
.explode()
.unwrap()
.into_series()
fn inner(ca: &ListChunked, inner_dtype: &DataType) -> Series {
use DataType::*;
// TODO: add fast path for smaller ints?
let mut out = match inner_dtype {
Boolean => {
let out: IdxCa = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().sum()))
.collect();
out.into_series()
}
UInt32 => {
let out: UInt32Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().sum()))
.collect();
out.into_series()
}
UInt64 => {
let out: UInt64Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().sum()))
.collect();
out.into_series()
}
Int32 => {
let out: Int32Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().sum()))
.collect();
out.into_series()
}
Int64 => {
let out: Int64Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().sum()))
.collect();
out.into_series()
}
Float32 => {
let out: Float32Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().sum()))
.collect();
out.into_series()
}
Float64 => {
let out: Float64Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().sum()))
.collect();
out.into_series()
}
// slowest sum_as_series path
_ => ca
.apply_amortized(|s| s.as_ref().sum_as_series())
.explode()
.unwrap()
.into_series(),
};
out.rename(ca.name());
out
}

let ca = self.as_list();

if has_inner_nulls(ca) {
return inner(ca);
return inner(ca, &ca.inner_dtype());
};

use DataType::*;
match ca.inner_dtype() {
Boolean => count_boolean_bits(ca).into_series(),
DataType::Boolean => count_boolean_bits(ca).into_series(),
dt if dt.is_numeric() => sum_list_numerical(ca, &dt),
_ => inner(ca),
dt => inner(ca, &dt),
}
}

Expand Down

0 comments on commit 56bbd57

Please sign in to comment.