From 00dd89e3c8fec158a5547f415545d7bd0f1a0d8c Mon Sep 17 00:00:00 2001 From: Janis Goldschmidt Date: Tue, 19 Oct 2021 10:11:47 +0200 Subject: [PATCH] Full update of weighted index by assigning weights BREAKING CHANGE: This commit adds a variant to `WeightedError`. --- benches/weighted.rs | 36 +++++++++- src/distributions/weighted_index.rs | 104 ++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 2 deletions(-) diff --git a/benches/weighted.rs b/benches/weighted.rs index 68722908a9e..94d5f608406 100644 --- a/benches/weighted.rs +++ b/benches/weighted.rs @@ -14,21 +14,53 @@ use rand::distributions::WeightedIndex; use rand::Rng; use test::Bencher; +#[bench] +fn weighted_index_assignment(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let mut distr = WeightedIndex::new(weights).unwrap(); + b.iter(|| { + distr.assign_new_weights(weights).unwrap(); + rng.sample(&distr) + }) +} + +#[bench] +fn weighted_index_assignment_large(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights: Vec<u64> = (0u64..1000).collect(); + let mut distr = WeightedIndex::new(&weights).unwrap(); + b.iter(|| { + distr.assign_new_weights(&weights).unwrap(); + rng.sample(&distr) + }) +} + #[bench] fn weighted_index_creation(b: &mut Bencher) { let mut rng = rand::thread_rng(); let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; b.iter(|| { - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + let distr = WeightedIndex::new(weights).unwrap(); rng.sample(distr) }) } +#[bench] +fn weighted_index_creation_large(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights: Vec<u64> = (0u64..1000).collect(); + b.iter(|| { + let distr = WeightedIndex::new(&weights).unwrap(); + rng.sample(&distr) + }) +} + #[bench] fn weighted_index_modification(b: &mut Bencher) { let mut rng = rand::thread_rng(); let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let mut distr = 
WeightedIndex::new(weights.to_vec()).unwrap(); + let mut distr = WeightedIndex::new(weights).unwrap(); b.iter(|| { distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); rng.sample(&distr) diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 32da37f6cd3..8b51513b9dd 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -13,6 +13,7 @@ use crate::distributions::Distribution; use crate::Rng; use core::cmp::PartialOrd; use core::fmt; +use core::iter::ExactSizeIterator; // Note that this whole module is only imported if feature="alloc" is enabled. use alloc::vec::Vec; @@ -130,6 +131,55 @@ impl WeightedIndex { }) } + /// Reuses the weighted index by assigning a new set of weights without changing the number of + /// weights. + /// + /// Returns an error if: + /// + /// + the number of items in the iterator does not match the number of items used to create the + /// distribution; + /// + the iterator yields invalid values (such as `f64::NAN`); + /// + the weights yielded by the iterator sum up to zero. + /// + /// NOTE: If this method fails the distribution should no longer be used for sampling, because + /// results of sampling from it are undefined. 
+ pub fn assign_new_weights<I>(&mut self, weights: I) -> Result<(), WeightedError > + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + I::Item: SampleBorrow<X>, + X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + { + let mut iter = weights.into_iter(); + + if iter.len() != self.cumulative_weights.len() + 1 { + return Err(WeightedError::LenMismatch); + } + + let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); + let zero = <X as Default>::default(); + + if !(total_weight >= zero) { + return Err(WeightedError::InvalidWeight); + } + + for (w, c) in iter.zip(self.cumulative_weights.iter_mut()) { + if !(w.borrow() >= &zero) { + return Err(WeightedError::InvalidWeight); + } + *c = total_weight.clone(); + total_weight += w.borrow(); + } + + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + }; + + self.weight_distribution = X::Sampler::new(zero, total_weight.clone()); + self.total_weight = total_weight; + Ok(()) + } + /// Update a subset of weights, without changing the number of weights. /// /// `new_weights` must be sorted by the index. 
@@ -389,6 +439,55 @@ mod test { } } + #[test] + fn test_assign_new_weights() { + let data = [ + ( + &[10u32, 2, 3, 4][..], + &[10, 100, 4, 4][..], + ), + ( + &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], + ), + ]; + + for &(weights, new_weights) in data.iter() { + let total_weight = weights.iter().sum::<u32>(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.assign_new_weights(weights).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + let new_total_weight = new_weights.iter().sum::<u32>(); + let new_distr = WeightedIndex::new(new_weights.to_vec()).unwrap(); + distr.assign_new_weights(new_weights).unwrap(); + assert_eq!(new_total_weight, new_distr.total_weight); + assert_eq!(new_total_weight, distr.total_weight); + assert_eq!(new_distr.cumulative_weights, distr.cumulative_weights); + } + } + + #[test] + fn assigning_error_states() { + { + let mut distr = WeightedIndex::new([1.0f64, 2.0, 3.0, 0.0]).unwrap(); + let res = distr.assign_new_weights([1.0f64, 2.0, 3.0]); + assert_eq!(res, Err(WeightedError::LenMismatch)); + } + { + let mut distr = WeightedIndex::new([1.0f64, 2.0, 3.0, 0.0]).unwrap(); + let res = distr.assign_new_weights([1.0f64, 2.0, f64::NAN, 0.0]); + assert_eq!(res, Err(WeightedError::InvalidWeight)); + } + { + let mut distr = WeightedIndex::new([1u32, 2, 3, 0]).unwrap(); + let res = distr.assign_new_weights([0u32, 0, 0, 0]); + assert_eq!(res, Err(WeightedError::AllWeightsZero)); + } + } + #[test] fn value_stability() { fn test_samples( @@ -436,6 +535,10 @@ pub enum WeightedError { /// Too many weights are provided (length greater than `u32::MAX`) TooMany, + + /// Have to provide exactly as many weights when assigning as were present when constructing + /// the weighted index. 
+ LenMismatch, } #[cfg(feature = "std")] @@ -448,6 +551,7 @@ impl fmt::Display for WeightedError { WeightedError::InvalidWeight => "A weight is invalid in distribution", WeightedError::AllWeightsZero => "All weights are zero in distribution", WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", + WeightedError::LenMismatch => "Length mismatch between previous and provided weights", }) } }