diff --git a/benches/weighted.rs b/benches/weighted.rs index 68722908a9e..6c61bf09022 100644 --- a/benches/weighted.rs +++ b/benches/weighted.rs @@ -15,20 +15,106 @@ use rand::Rng; use test::Bencher; #[bench] -fn weighted_index_creation(b: &mut Bencher) { +fn weighted_index_assignment(b: &mut Bencher) { let mut rng = rand::thread_rng(); - let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; + let weights = &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..]; + let mut distr = WeightedIndex::new(weights).unwrap(); b.iter(|| { - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + distr.assign_new_weights(weights).unwrap(); + rng.sample(&distr) + }) +} + +#[bench] +fn weighted_index_assignment_large(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights: Vec = (0u64..1000).collect(); + let mut distr = WeightedIndex::new(&weights).unwrap(); + b.iter(|| { + distr.assign_new_weights(&weights).unwrap(); + rng.sample(&distr) + }) +} + +#[bench] +fn weighted_index_new(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights = &[1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7][..]; + b.iter(|| { + let distr = WeightedIndex::new(weights).unwrap(); rng.sample(distr) }) } +#[bench] +fn weighted_index_new_large(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights: Vec = (0u64..1000).collect(); + b.iter(|| { + let distr = WeightedIndex::new(&weights).unwrap(); + rng.sample(&distr) + }) +} + +#[bench] +fn weighted_index_from_cumulative_weights(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights = vec![1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; + let cumulative_weights: Vec<_> = weights + .into_iter() + .scan(0, |acc, item| { + *acc += item; + Some(*acc) + }) + .collect(); + b.iter(|| { + let distr = WeightedIndex::from_cumulative_weights_unchecked(cumulative_weights.clone()); + rng.sample(distr) + }) +} + +#[bench] +fn weighted_index_from_cumulative_weights_large(b: &mut Bencher) { + let mut rng =
rand::thread_rng(); + let weights: Vec = (0u64..1000).collect(); + let cumulative_weights: Vec<_> = weights + .into_iter() + .scan(0, |acc, item| { + *acc += item; + Some(*acc) + }) + .collect(); + b.iter(|| { + let distr = WeightedIndex::from_cumulative_weights_unchecked(cumulative_weights.clone()); + rng.sample(&distr) + }) +} + +#[bench] +fn weighted_index_from_weights(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights = vec![1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; + b.iter(|| { + let distr = WeightedIndex::from_weights(weights.clone()).unwrap(); + rng.sample(distr) + }) +} + +#[bench] +fn weighted_index_from_weights_large(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights: Vec = (0u64..1000).collect(); + b.iter(|| { + let distr = WeightedIndex::from_weights(weights.clone()).unwrap(); + rng.sample(&distr) + }) +} + #[bench] fn weighted_index_modification(b: &mut Bencher) { let mut rng = rand::thread_rng(); - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + let weights = &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..]; + let mut distr = WeightedIndex::new(weights).unwrap(); b.iter(|| { distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); rng.sample(&distr) diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 32da37f6cd3..dd8d147810e 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -13,6 +13,7 @@ use crate::distributions::Distribution; use crate::Rng; use core::cmp::PartialOrd; use core::fmt; +use core::iter::ExactSizeIterator; // Note that this whole module is only imported if feature="alloc" is enabled. use alloc::vec::Vec; @@ -130,6 +131,119 @@ impl WeightedIndex { }) } + /// Reuses the weighted index by assigning a new set of weights without changing the number of + /// weights.
+ /// + /// Returns an error if: + /// + /// + the number of items in the iterator does not match the number of items used to create the + /// distribution; + /// + the iterator yields invalid values (such as `f64::NAN`); + /// + the weights yielded by the iterator sum up to zero. + /// + /// NOTE: If this method fails the distribution should no longer be used for sampling, because + /// results of sampling from it are undefined. + pub fn assign_new_weights(&mut self, weights: I) -> Result<(), WeightedError > + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + I::Item: SampleBorrow, + X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + { + let mut iter = weights.into_iter(); + + if iter.len() != self.cumulative_weights.len() + 1 { + return Err(WeightedError::LenMismatch); + } + + let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); + let zero = ::default(); + + if !(total_weight >= zero) { + return Err(WeightedError::InvalidWeight); + } + + for (w, c) in iter.zip(self.cumulative_weights.iter_mut()) { + if !(w.borrow() >= &zero) { + return Err(WeightedError::InvalidWeight); + } + *c = total_weight.clone(); + total_weight += w.borrow(); + } + + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + }; + + self.weight_distribution = X::Sampler::new(zero, total_weight.clone()); + self.total_weight = total_weight; + Ok(()) + } + + /// Create a `WeightedIndex` from unverified cumulative weights. Panics if `ws` is empty. + pub fn from_cumulative_weights_unchecked(ws: Vec) -> Self + where + X: Clone + Default, + { + let mut cumulative_weights = ws; + let total_weight = cumulative_weights.pop().unwrap(); + let zero = ::default(); + let weight_distribution = X::Sampler::new(zero, total_weight.clone()); + Self { + cumulative_weights, + total_weight, + weight_distribution, + } + } + + /// Create a `WeightedIndex` from a vector of weights, validating them and reusing its storage.
+ pub fn from_weights(ws: Vec) -> Result + where + X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + for<'a> &'a X: SampleBorrow, + { + let mut cumulative_weights = ws; + let mut iter = cumulative_weights.iter_mut(); + + let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); + let zero = ::default(); + + if !(total_weight >= zero) { + return Err(WeightedError::InvalidWeight); + } + + for w in iter { + if !(w.borrow() >= &zero) { + return Err(WeightedError::InvalidWeight); + } + total_weight += w.borrow(); + *w = total_weight.clone(); + } + + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + }; + + let weight_distribution = X::Sampler::new(zero, total_weight.clone()); + cumulative_weights.pop().unwrap(); + Ok(Self { + cumulative_weights, + total_weight, + weight_distribution, + }) + } + + /// Consume `self`, returning the cumulative weights with the total weight as the last element. + pub fn into_cumulative_weights(self) -> Vec { + let Self { + mut cumulative_weights, + total_weight, + .. + } = self; + cumulative_weights.push(total_weight); + cumulative_weights + } + /// Update a subset of weights, without changing the number of weights. /// /// `new_weights` must be sorted by the index.
@@ -389,6 +503,55 @@ mod test { } } + #[test] + fn test_assign_new_weights() { + let data = [ + ( + &[10u32, 2, 3, 4][..], + &[10, 100, 4, 4][..], + ), + ( + &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], + ), + ]; + + for &(weights, new_weights) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.assign_new_weights(weights).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + let new_total_weight = new_weights.iter().sum::(); + let new_distr = WeightedIndex::new(new_weights.to_vec()).unwrap(); + distr.assign_new_weights(new_weights).unwrap(); + assert_eq!(new_total_weight, new_distr.total_weight); + assert_eq!(new_total_weight, distr.total_weight); + assert_eq!(new_distr.cumulative_weights, distr.cumulative_weights); + } + } + + #[test] + fn assigning_error_states() { + { + let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap(); + let res = distr.assign_new_weights(&[1.0f64, 2.0, 3.0][..]); + assert_eq!(res, Err(WeightedError::LenMismatch)); + } + { + let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap(); + let res = distr.assign_new_weights(&[1.0f64, 2.0, ::core::f64::NAN, 0.0][..]); + assert_eq!(res, Err(WeightedError::InvalidWeight)); + } + { + let mut distr = WeightedIndex::new(&[1u32, 2, 3, 0][..]).unwrap(); + let res = distr.assign_new_weights(&[0u32, 0, 0, 0][..]); + assert_eq!(res, Err(WeightedError::AllWeightsZero)); + } + } + #[test] fn value_stability() { fn test_samples( @@ -436,6 +599,10 @@ pub enum WeightedError { /// Too many weights are provided (length greater than `u32::MAX`) TooMany, + + /// The number of weights provided for assignment differs from the number the weighted index + /// was constructed with.
+ LenMismatch, } #[cfg(feature = "std")] @@ -448,6 +615,7 @@ impl fmt::Display for WeightedError { WeightedError::InvalidWeight => "A weight is invalid in distribution", WeightedError::AllWeightsZero => "All weights are zero in distribution", WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", + WeightedError::LenMismatch => "Length mismatch between previous and provided weights", }) } }