From f14570abaa7c6d07c3d5ea064a63f6b1ee0fdb63 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 22 May 2022 19:02:47 +0200 Subject: [PATCH] improve `rolling_var` performance (#3470) * improve rolling_var performance no_nulls * improve performance rolling_variance with nulls --- .github/workflows/test-windows-python.yaml | 3 +- .../polars-arrow/src/kernels/rolling/mod.rs | 1 - .../src/kernels/rolling/no_nulls/mean.rs | 2 +- .../src/kernels/rolling/no_nulls/mod.rs | 105 +-------- .../src/kernels/rolling/no_nulls/variance.rs | 206 ++++++++++++++++++ .../src/kernels/rolling/nulls/mean.rs | 11 +- .../src/kernels/rolling/nulls/mod.rs | 200 ++++------------- .../src/kernels/rolling/nulls/sum.rs | 8 +- .../src/kernels/rolling/nulls/variance.rs | 203 +++++++++++++++++ .../src/chunked_array/ops/rolling_window.rs | 2 +- py-polars/tests/test_struct.py | 17 +- 11 files changed, 475 insertions(+), 283 deletions(-) create mode 100644 polars/polars-arrow/src/kernels/rolling/no_nulls/variance.rs create mode 100644 polars/polars-arrow/src/kernels/rolling/nulls/variance.rs diff --git a/.github/workflows/test-windows-python.yaml b/.github/workflows/test-windows-python.yaml index 8534ad2ef8f3..49439be93c49 100644 --- a/.github/workflows/test-windows-python.yaml +++ b/.github/workflows/test-windows-python.yaml @@ -15,7 +15,7 @@ jobs: with: toolchain: nightly-2022-05-20 override: true - components: rustfmt, clippy + components: rustfmt - name: Set up Python uses: actions/setup-python@v3 with: @@ -29,7 +29,6 @@ jobs: run: | export RUSTFLAGS="-C debuginfo=0" cd py-polars && rustup override set nightly-2022-05-01 && make build-and-test-no-venv - cargo clippy # test if we can import polars without any requirements - name: Import polars run: | diff --git a/polars/polars-arrow/src/kernels/rolling/mod.rs b/polars/polars-arrow/src/kernels/rolling/mod.rs index 3f00b9204018..80c04585c4ae 100644 --- a/polars/polars-arrow/src/kernels/rolling/mod.rs +++ 
b/polars/polars-arrow/src/kernels/rolling/mod.rs @@ -6,7 +6,6 @@ use crate::data_types::IsFloat; use crate::prelude::QuantileInterpolOptions; use crate::utils::CustomIterTools; use arrow::array::{ArrayRef, PrimitiveArray}; -use arrow::bitmap::utils::{count_zeros, get_bit_unchecked}; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::types::NativeType; use num::ToPrimitive; diff --git a/polars/polars-arrow/src/kernels/rolling/no_nulls/mean.rs b/polars/polars-arrow/src/kernels/rolling/no_nulls/mean.rs index faf76a85f094..8194fd0c3013 100644 --- a/polars/polars-arrow/src/kernels/rolling/no_nulls/mean.rs +++ b/polars/polars-arrow/src/kernels/rolling/no_nulls/mean.rs @@ -2,7 +2,7 @@ use super::sum::SumWindow; use super::*; use no_nulls::{rolling_apply_agg_window, RollingAggWindow}; -struct MeanWindow<'a, T> { +pub(super) struct MeanWindow<'a, T> { sum: SumWindow<'a, T>, } diff --git a/polars/polars-arrow/src/kernels/rolling/no_nulls/mod.rs b/polars/polars-arrow/src/kernels/rolling/no_nulls/mod.rs index 41bfb4da8e49..87d1c57cbeda 100644 --- a/polars/polars-arrow/src/kernels/rolling/no_nulls/mod.rs +++ b/polars/polars-arrow/src/kernels/rolling/no_nulls/mod.rs @@ -2,6 +2,7 @@ mod mean; mod min_max; mod quantile; mod sum; +mod variance; use super::*; use crate::utils::CustomIterTools; @@ -18,6 +19,7 @@ pub use mean::rolling_mean; pub use min_max::{rolling_max, rolling_min}; pub use quantile::{rolling_median, rolling_quantile}; pub use sum::rolling_sum; +pub use variance::rolling_var; pub(crate) trait RollingAggWindow<'a, T: NativeType> { fn new(slice: &'a [T], start: usize, end: usize) -> Self; @@ -106,83 +108,26 @@ where )) } -pub(super) fn rolling_apply( - values: &[T], - window_size: usize, - min_periods: usize, - det_offsets_fn: Fo, - aggregator: Fa, -) -> ArrayRef -where - Fo: Fn(Idx, WindowSize, Len) -> (Start, End), - Fa: Fn(&[T]) -> K, - K: NativeType, - T: Debug, -{ - let len = values.len(); - let out = (0..len) - .map(|idx| { - let (start, end) = 
det_offsets_fn(idx, window_size, len); - let vals = unsafe { values.get_unchecked(start..end) }; - aggregator(vals) - }) - .collect_trusted::>(); - - let validity = create_validity(min_periods, len as usize, window_size, det_offsets_fn); - Arc::new(PrimitiveArray::from_data( - K::PRIMITIVE.into(), - out.into(), - validity.map(|b| b.into()), - )) -} - -pub(crate) fn compute_var(vals: &[T]) -> T -where - T: Float + std::ops::AddAssign + std::fmt::Debug, -{ - let mut count = T::zero(); - let mut sum = T::zero(); - let mut sum_of_squares = T::zero(); - - for &val in vals { - sum += val; - sum_of_squares += val * val; - count += T::one(); - } - - let mean = sum / count; - // apply Bessel's correction - ((sum_of_squares / count) - mean * mean) / (count - T::one()) * count -} - fn compute_var_weights(vals: &[T], weights: &[T]) -> T where T: Float + std::ops::AddAssign, { let weighted_iter = vals.iter().zip(weights).map(|(x, y)| *x * *y); - let mut count = T::zero(); let mut sum = T::zero(); let mut sum_of_squares = T::zero(); for val in weighted_iter { sum += val; sum_of_squares += val * val; - count += T::one(); } + let count = NumCast::from(vals.len()).unwrap(); let mean = sum / count; // apply Bessel's correction ((sum_of_squares / count) - mean * mean) / (count - T::one()) * count } -pub(crate) fn compute_mean(values: &[T]) -> T -where - T: Float + std::iter::Sum, -{ - values.iter().copied().sum::() / T::from(values.len()).unwrap() -} - pub(crate) fn compute_mean_weights(values: &[T], weights: &[T]) -> T where T: Float + std::iter::Sum, @@ -205,47 +150,3 @@ where .map(|v| NumCast::from(*v).unwrap()) .collect::>() } - -pub fn rolling_var( - values: &[T], - window_size: usize, - min_periods: usize, - center: bool, - weights: Option<&[f64]>, -) -> ArrayRef -where - T: NativeType + Float + std::ops::AddAssign, -{ - match (center, weights) { - (true, None) => rolling_apply( - values, - window_size, - min_periods, - det_offsets_center, - compute_var, - ), - (false, None) => 
rolling_apply(values, window_size, min_periods, det_offsets, compute_var), - (true, Some(weights)) => { - let weights = coerce_weights(weights); - rolling_apply_weights( - values, - window_size, - min_periods, - det_offsets_center, - compute_var_weights, - &weights, - ) - } - (false, Some(weights)) => { - let weights = coerce_weights(weights); - rolling_apply_weights( - values, - window_size, - min_periods, - det_offsets, - compute_var_weights, - &weights, - ) - } - } -} diff --git a/polars/polars-arrow/src/kernels/rolling/no_nulls/variance.rs b/polars/polars-arrow/src/kernels/rolling/no_nulls/variance.rs new file mode 100644 index 000000000000..f2571af58d7a --- /dev/null +++ b/polars/polars-arrow/src/kernels/rolling/no_nulls/variance.rs @@ -0,0 +1,206 @@ +use super::mean::MeanWindow; +use super::*; +use no_nulls::{rolling_apply_agg_window, RollingAggWindow}; + +pub(super) struct SumSquaredWindow<'a, T> { + slice: &'a [T], + sum_of_squares: T, + last_start: usize, + last_end: usize, +} + +impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul> + RollingAggWindow<'a, T> for SumSquaredWindow<'a, T> +{ + fn new(slice: &'a [T], start: usize, end: usize) -> Self { + let sum = slice[start..end].iter().map(|v| *v * *v).sum::(); + Self { + slice, + sum_of_squares: sum, + last_start: start, + last_end: end, + } + } + + unsafe fn update(&mut self, start: usize, end: usize) -> T { + // remove elements that should leave the window + let mut recompute_sum = false; + for idx in self.last_start..start { + // safety + // we are in bounds + let leaving_value = self.slice.get_unchecked(idx); + + if T::is_float() && leaving_value.is_nan() { + recompute_sum = true; + break; + } + + self.sum_of_squares -= *leaving_value * *leaving_value; + } + self.last_start = start; + + // we traverse all values and compute + if T::is_float() && recompute_sum { + self.sum_of_squares = self + .slice + .get_unchecked(start..end) + .iter() + .map(|v| *v * *v) + .sum::(); + } +
// no nan has left the window, so we only add + // the squares of the entering values + else { + for idx in self.last_end..end { + let entering_value = *self.slice.get_unchecked(idx); + self.sum_of_squares += entering_value * entering_value; + } + } + self.last_end = end; + self.sum_of_squares + } +} + +// E[(xi - E[x])^2] +// can be expanded to +// E[x^2] - E[x]^2 +struct VarWindow<'a, T> { + mean: MeanWindow<'a, T>, + sum_of_squares: SumSquaredWindow<'a, T>, +} + +impl< + 'a, + T: NativeType + + IsFloat + + std::iter::Sum + + AddAssign + + SubAssign + + Div + + NumCast + + One + + Sub, + > RollingAggWindow<'a, T> for VarWindow<'a, T> +{ + fn new(slice: &'a [T], start: usize, end: usize) -> Self { + Self { + mean: MeanWindow::new(slice, start, end), + sum_of_squares: SumSquaredWindow::new(slice, start, end), + } + } + + unsafe fn update(&mut self, start: usize, end: usize) -> T { + let count = NumCast::from(end - start).unwrap(); + let sum_of_squares = self.sum_of_squares.update(start, end); + let mean_of_squares = sum_of_squares / count; + let mean = self.mean.update(start, end); + let var = mean_of_squares - mean * mean; + // apply Bessel's correction + var / (count - T::one()) * count + } +} + +pub fn rolling_var( + values: &[T], + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, +) -> ArrayRef +where + T: NativeType + + Float + + IsFloat + + std::iter::Sum + + AddAssign + + SubAssign + + Div + + NumCast + + One + + Sub, +{ + match (center, weights) { + (true, None) => rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + det_offsets_center, + ), + (false, None) => rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + det_offsets, + ), + (true, Some(weights)) => { + let weights = coerce_weights(weights); + super::rolling_apply_weights( + values, + window_size, + min_periods, + det_offsets_center, + compute_var_weights, + &weights, + ) + } + (false, Some(weights)) => { + let
weights = coerce_weights(weights); + super::rolling_apply_weights( + values, + window_size, + min_periods, + det_offsets, + compute_var_weights, + &weights, + ) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_rolling_var() { + let values = &[1.0f64, 5.0, 3.0, 4.0]; + + let out = rolling_var(values, 2, 2, false, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, Some(8.0), Some(2.0), Some(0.5)]); + + let out = rolling_var(values, 2, 1, false, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out + .into_iter() + .map(|v| v.copied().unwrap()) + .collect::>(); + // we cannot compare nans, so we compare the string values + assert_eq!( + format!("{:?}", out.as_slice()), + format!("{:?}", &[f64::nan(), 8.0, 2.0, 0.5]) + ); + // test nan handling. + let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0]; + let out = rolling_var(values, 3, 3, false, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + // we cannot compare nans, so we compare the string values + assert_eq!( + format!("{:?}", out.as_slice()), + format!( + "{:?}", + &[ + None, + None, + Some(52.33333333333333), + Some(f64::nan()), + Some(f64::nan()), + Some(f64::nan()), + Some(0.9999999999999964) + ] + ) + ); + } +} diff --git a/polars/polars-arrow/src/kernels/rolling/nulls/mean.rs b/polars/polars-arrow/src/kernels/rolling/nulls/mean.rs index 6ae7b894edd9..90910a68cd36 100644 --- a/polars/polars-arrow/src/kernels/rolling/nulls/mean.rs +++ b/polars/polars-arrow/src/kernels/rolling/nulls/mean.rs @@ -2,19 +2,13 @@ use super::sum::SumWindow; use super::*; use super::{rolling_apply_agg_window, RollingAggWindow}; -struct MeanWindow<'a, T> { +pub(super) struct MeanWindow<'a, T> { sum: SumWindow<'a, T>, } impl< 'a, - T: NativeType - + IsFloat - + PartialOrd - + Add - + Sub - + NumCast - + Div, + T: 
NativeType + IsFloat + Add + Sub + NumCast + Div, > RollingAggWindow<'a, T> for MeanWindow<'a, T> { unsafe fn new( @@ -31,7 +25,6 @@ impl< unsafe fn update(&mut self, start: usize, end: usize) -> Option { let sum = self.sum.update(start, end); - dbg!(sum); sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap()) } } diff --git a/polars/polars-arrow/src/kernels/rolling/nulls/mod.rs b/polars/polars-arrow/src/kernels/rolling/nulls/mod.rs index 54612d5e0311..d62eded9a10d 100644 --- a/polars/polars-arrow/src/kernels/rolling/nulls/mod.rs +++ b/polars/polars-arrow/src/kernels/rolling/nulls/mod.rs @@ -2,12 +2,14 @@ mod mean; mod min_max; mod quantile; mod sum; +mod variance; use super::*; pub use mean::rolling_mean; pub use min_max::{rolling_max, rolling_min}; pub use quantile::{rolling_median, rolling_quantile}; pub use sum::rolling_sum; +pub use variance::rolling_var; pub(crate) trait RollingAggWindow<'a, T: NativeType> { unsafe fn new( @@ -73,155 +75,6 @@ where )) } -fn rolling_apply( - values: &[T], - bitmap: &Bitmap, - window_size: usize, - min_periods: usize, - det_offsets_fn: Fo, - aggregator: Fa, -) -> ArrayRef -where - Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy, - // &[T] -> values of array - // &[u8] -> validity bytes - // usize -> offset in validity bytes array - // usize -> min_periods - Fa: Fn(&[T], &[u8], usize, usize) -> Option, - K: NativeType + Default, -{ - let len = values.len(); - let (validity_bytes, offset, _) = bitmap.as_slice(); - - let mut validity = match create_validity(min_periods, len as usize, window_size, det_offsets_fn) - { - Some(v) => v, - None => { - let mut validity = MutableBitmap::with_capacity(len); - validity.extend_constant(len, true); - validity - } - }; - - let out = (0..len) - .map(|idx| { - let (start, end) = det_offsets_fn(idx, window_size, len); - let vals = unsafe { values.get_unchecked(start..end) }; - match aggregator(vals, validity_bytes, offset + start, min_periods) { - Some(val) => val, - 
None => { - validity.set(idx, false); - K::default() - } - } - }) - .collect_trusted::>(); - - Arc::new(PrimitiveArray::from_data( - K::PRIMITIVE.into(), - out.into(), - Some(validity.into()), - )) -} - -fn compute_mean( - values: &[T], - validity_bytes: &[u8], - offset: usize, - min_periods: usize, -) -> Option -where - T: NativeType + std::iter::Sum + Zero + AddAssign + Float, -{ - let null_count = count_zeros(validity_bytes, offset, values.len()); - if null_count == 0 { - Some(no_nulls::compute_mean(values)) - } else if (values.len() - null_count) < min_periods { - None - } else { - let mut out = T::zero(); - let mut count = T::zero(); - for (i, val) in values.iter().enumerate() { - // Safety: - // in bounds - if unsafe { get_bit_unchecked(validity_bytes, offset + i) } { - out += *val; - count += One::one() - } - } - Some(out / count) - } -} - -pub(crate) fn compute_var( - values: &[T], - validity_bytes: &[u8], - offset: usize, - min_periods: usize, -) -> Option -where - T: NativeType + std::iter::Sum + Zero + AddAssign + Float, -{ - let null_count = count_zeros(validity_bytes, offset, values.len()); - if null_count == 0 { - Some(no_nulls::compute_var(values)) - } else if (values.len() - null_count) < min_periods { - None - } else { - match compute_mean(values, validity_bytes, offset, min_periods) { - None => None, - Some(mean) => { - let mut sum = T::zero(); - let mut count = T::zero(); - for (i, val) in values.iter().enumerate() { - // Safety: - // in bounds - if unsafe { get_bit_unchecked(validity_bytes, offset + i) } { - let v = *val - mean; - sum += v * v; - count += One::one() - } - } - Some(sum / (count - T::one())) - } - } - } -} - -pub fn rolling_var( - arr: &PrimitiveArray, - window_size: usize, - min_periods: usize, - center: bool, - weights: Option<&[f64]>, -) -> ArrayRef -where - T: NativeType + std::iter::Sum + Zero + AddAssign + Float, -{ - if weights.is_some() { - panic!("weights not yet supported on array with null values") - } - if center { - 
rolling_apply( - arr.values().as_slice(), - arr.validity().as_ref().unwrap(), - window_size, - min_periods, - det_offsets_center, - compute_var, - ) - } else { - rolling_apply( - arr.values().as_slice(), - arr.validity().as_ref().unwrap(), - window_size, - min_periods, - det_offsets, - compute_var, - ) - } -} - #[cfg(test)] mod test { use super::*; @@ -229,6 +82,16 @@ mod test { use arrow::buffer::Buffer; use arrow::datatypes::DataType; + fn get_null_arr() -> PrimitiveArray { + // 1, None, -1, 4 + let buf = Buffer::from(vec![1.0, 0.0, -1.0, 4.0]); + PrimitiveArray::from_data( + DataType::Float64, + buf, + Some(Bitmap::from(&[true, false, true, true])), + ) + } + #[test] fn test_rolling_sum_nulls() { let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]); @@ -266,13 +129,8 @@ mod test { #[test] fn test_rolling_mean_nulls() { - // 1, None, -1, 4 - let buf = Buffer::from(vec![1.0, 0.0, -1.0, 4.0]); - let arr = &PrimitiveArray::from_data( - DataType::Float64, - buf, - Some(Bitmap::from(&[true, false, true, true])), - ); + let arr = get_null_arr(); + let arr = &arr; let out = rolling_mean(arr, 2, 2, false, None); let out = out.as_any().downcast_ref::>().unwrap(); @@ -290,6 +148,36 @@ mod test { assert_eq!(out, &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]); } + #[test] + fn test_rolling_var_nulls() { + let arr = get_null_arr(); + let arr = &arr; + + let out = rolling_var(arr, 3, 1, false, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out + .into_iter() + .map(|v| v.copied().unwrap()) + .collect::>(); + + // we cannot compare nans, so we compare the string values + assert_eq!( + format!("{:?}", out.as_slice()), + format!("{:?}", &[f64::nan(), f64::nan(), 2.0, 12.5]) + ); + + let out = rolling_var(arr, 4, 1, false, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out + .into_iter() + .map(|v| v.copied().unwrap()) + .collect::>(); + assert_eq!( + format!("{:?}", out.as_slice()), + format!("{:?}", &[f64::nan(), f64::nan(), 
2.0, 6.333333333333334]) + ); + } + #[test] fn test_rolling_max_no_nulls() { let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]); diff --git a/polars/polars-arrow/src/kernels/rolling/nulls/sum.rs b/polars/polars-arrow/src/kernels/rolling/nulls/sum.rs index 7ec2409f3086..279d166fbdd7 100644 --- a/polars/polars-arrow/src/kernels/rolling/nulls/sum.rs +++ b/polars/polars-arrow/src/kernels/rolling/nulls/sum.rs @@ -12,9 +12,7 @@ pub(super) struct SumWindow<'a, T> { min_periods: usize, } -impl<'a, T: NativeType + IsFloat + PartialOrd + Add + Sub> - SumWindow<'a, T> -{ +impl<'a, T: NativeType + IsFloat + Add + Sub> SumWindow<'a, T> { // compute sum from the entire window unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option { let mut sum = None; @@ -37,8 +35,8 @@ impl<'a, T: NativeType + IsFloat + PartialOrd + Add + Sub + Sub> - RollingAggWindow<'a, T> for SumWindow<'a, T> +impl<'a, T: NativeType + IsFloat + Add + Sub> RollingAggWindow<'a, T> + for SumWindow<'a, T> { unsafe fn new( slice: &'a [T], diff --git a/polars/polars-arrow/src/kernels/rolling/nulls/variance.rs b/polars/polars-arrow/src/kernels/rolling/nulls/variance.rs new file mode 100644 index 000000000000..d12486170927 --- /dev/null +++ b/polars/polars-arrow/src/kernels/rolling/nulls/variance.rs @@ -0,0 +1,203 @@ +use super::*; +use mean::MeanWindow; +use nulls; +use nulls::{rolling_apply_agg_window, RollingAggWindow}; + +pub struct SumSquaredWindow<'a, T> { + slice: &'a [T], + validity: &'a Bitmap, + sum_of_squares: Option, + last_start: usize, + last_end: usize, + null_count: usize, + min_periods: usize, +} + +impl<'a, T: NativeType + IsFloat + Add + Sub + Mul> + SumSquaredWindow<'a, T> +{ + // compute sum from the entire window + unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option { + let mut sum_of_squares = None; + let mut idx = start; + self.null_count = 0; + for value in (&self.slice[start..end]).iter() { + let valid = 
self.validity.get_bit_unchecked(idx); + if valid { + match sum_of_squares { + None => sum_of_squares = Some(*value * *value), + Some(current) => sum_of_squares = Some(*value * *value + current), + } + } else { + self.null_count += 1; + } + idx += 1; + } + self.sum_of_squares = sum_of_squares; + sum_of_squares + } +} + +impl<'a, T: NativeType + IsFloat + Add + Sub + Mul> + RollingAggWindow<'a, T> for SumSquaredWindow<'a, T> +{ + unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + min_periods: usize, + ) -> Self { + let mut out = Self { + slice, + validity, + sum_of_squares: None, + last_start: start, + last_end: end, + null_count: 0, + min_periods, + }; + out.compute_sum_and_null_count(start, end); + out + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + // remove elements that should leave the window + let mut recompute_sum = false; + for idx in self.last_start..start { + // safety + // we are in bounds + let valid = self.validity.get_bit_unchecked(idx); + if valid { + let leaving_value = *self.slice.get_unchecked(idx); + + // if the leaving value is nan we need to recompute the window + if T::is_float() && leaving_value.is_nan() { + recompute_sum = true; + break; + } + self.sum_of_squares = self + .sum_of_squares + .map(|v| v - leaving_value * leaving_value) + } else { + // null value leaving the window + self.null_count -= 1; + + // self.sum_of_squares is None and the leaving value is null + // if the entering value is valid, we might get a new sum.
+ if self.sum_of_squares.is_none() { + recompute_sum = true; + break; + } + } + } + self.last_start = start; + + // we traverse all values and compute + if recompute_sum { + self.compute_sum_and_null_count(start, end); + } else { + for idx in self.last_end..end { + let valid = self.validity.get_bit_unchecked(idx); + + if valid { + let value = *self.slice.get_unchecked(idx); + let value = value * value; + match self.sum_of_squares { + None => self.sum_of_squares = Some(value), + Some(current) => self.sum_of_squares = Some(current + value), + } + } else { + // null value entering the window + self.null_count += 1; + } + } + } + self.last_end = end; + if ((end - start) - self.null_count) < self.min_periods { + None + } else { + self.sum_of_squares + } + } +} + +// E[(xi - E[x])^2] +// can be expanded to +// E[x^2] - E[x]^2 +struct VarWindow<'a, T> { + mean: MeanWindow<'a, T>, + sum_of_squares: SumSquaredWindow<'a, T>, +} + +impl< + 'a, + T: NativeType + + IsFloat + + std::iter::Sum + + AddAssign + + SubAssign + + Div + + NumCast + + One + + Add + + Sub, + > RollingAggWindow<'a, T> for VarWindow<'a, T> +{ + unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + min_periods: usize, + ) -> Self { + Self { + mean: MeanWindow::new(slice, validity, start, end, min_periods), + sum_of_squares: SumSquaredWindow::new(slice, validity, start, end, min_periods), + } + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + let sum_of_squares = self.sum_of_squares.update(start, end)?; + let null_count = self.sum_of_squares.null_count; + let count = NumCast::from(end - start - null_count).unwrap(); + + let mean_of_squares = sum_of_squares / count; + let mean = self.mean.update(start, end)?; + let var = mean_of_squares - mean * mean; + // apply Bessel's correction + Some(var / (count - T::one()) * count) + } +} + +pub fn rolling_var( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + weights: 
Option<&[f64]>, +) -> ArrayRef +where + T: NativeType + std::iter::Sum + Zero + AddAssign + SubAssign + IsFloat + Float, +{ + if weights.is_some() { + panic!("weights not yet supported on array with null values") + } + if center { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets_center, + ) + } else { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets, + ) + } +} diff --git a/polars/polars-core/src/chunked_array/ops/rolling_window.rs b/polars/polars-core/src/chunked_array/ops/rolling_window.rs index f367b2e4e6a6..b7246b6ef447 100644 --- a/polars/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/polars/polars-core/src/chunked_array/ops/rolling_window.rs @@ -390,7 +390,7 @@ mod inner_mod { where ChunkedArray: IntoSeries, T: PolarsFloatType, - T::Native: Float + IsFloat, + T::Native: Float + IsFloat + SubAssign, { /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch. 
pub fn rolling_apply_float(&self, window_size: usize, f: F) -> Result diff --git a/py-polars/tests/test_struct.py b/py-polars/tests/test_struct.py index 7538ec3d11cf..0b9ea23c582e 100644 --- a/py-polars/tests/test_struct.py +++ b/py-polars/tests/test_struct.py @@ -1,3 +1,4 @@ +import typing from datetime import datetime import pandas as pd @@ -499,13 +500,17 @@ def test_struct_arr_eval() -> None: } +@typing.no_type_check def test_arr_unique() -> None: df = pl.DataFrame( {"col_struct": [[{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 1, "b": 11}]]} ) - assert df.with_column(pl.col("col_struct").arr.unique().alias("unique")).to_dict( - False - ) == { - "col_struct": [[{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 1, "b": 11}]], - "unique": [[{"a": 2, "b": 12}, {"a": 1, "b": 11}]], - } + # the order is unpredictable + unique = df.with_column(pl.col("col_struct").arr.unique().alias("unique"))[ + "unique" + ].to_list() + assert len(unique) == 1 + unique_el = unique[0] + assert len(unique_el) == 2 + assert {"a": 2, "b": 12} in unique_el + assert {"a": 1, "b": 11} in unique_el