From 9a05fc7458b12d687aed58441b615a4fe205508c Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 11 Jun 2024 08:41:37 +0200 Subject: [PATCH] perf: Use `split_at` in `split` --- crates/polars-core/src/utils/mod.rs | 36 +++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index fd38c4bad099..854e60ba4d7d 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -199,19 +199,24 @@ impl Container for Series { } fn split_impl(container: &C, target: usize, chunk_size: usize) -> Vec { - let total_len = container.len(); + if target == 1 { + return vec![container.clone()]; + } let mut out = Vec::with_capacity(target); + let chunk_size = chunk_size as i64; - for i in 0..target { - let offset = i * chunk_size; - let len = if i == (target - 1) { - total_len.saturating_sub(offset) - } else { - chunk_size - }; - let container = container.slice((i * chunk_size) as i64, len); - out.push(container); + // First split: cut `container` at `chunk_size`, keep the first piece and carry the rest forward as `remainder`. + let (chunk, mut remainder) = container.split_at(chunk_size); + out.push(chunk); + + // Take the rest of the splits of exactly chunk size, but skip the last remainder as we won't split that. Together with the first split this makes `target - 1` cuts; `target >= 2` is assumed here (`target == 1` returned early, `target == 0` would underflow `target - 1`). + for _ in 1..target - 1 { + let (a, b) = remainder.split_at(chunk_size); + out.push(a); + remainder = b } + // The last piece is whatever is left after those cuts and is pushed unsplit, so it can be larger than `chunk_size`. NOTE(review): it is not obviously below `2 * chunk_size` - e.g. 10 elements with `target = 4` and `chunk_size = 2` would leave 4; confirm against `split_at`.
+ out.push(remainder); out } @@ -223,6 +228,7 @@ pub fn split(container: &C, target: usize) -> Vec { } let chunk_size = std::cmp::max(total_len / target, 1); + if container.n_chunks() == target && container .chunk_lengths() @@ -1156,6 +1162,16 @@ pub fn coalesce_nulls_series(a: &Series, b: &Series) -> (Series, Series) { mod test { use super::*; + #[test] + fn test_split() { + let ca: Int32Chunked = (0..10).collect_ca("a"); + + let out = split(&ca, 3); + assert_eq!(out[0].len(), 3); + assert_eq!(out[1].len(), 3); + assert_eq!(out[2].len(), 4); + } + #[test] fn test_align_chunks() { let a = Int32Chunked::new("", &[1, 2, 3, 4]);