Skip to content

Commit

Permalink
Full update of weighted index by assigning weights
Browse files Browse the repository at this point in the history
BREAKING CHANGE: This commit adds a variant to `WeightedError`.
  • Loading branch information
SuperFluffy committed Oct 22, 2021
1 parent 320acef commit f6187ec
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 4 deletions.
40 changes: 36 additions & 4 deletions benches/weighted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,53 @@ use rand::distributions::WeightedIndex;
use rand::Rng;
use test::Bencher;

/// Benchmarks re-assigning all weights of an existing 14-element `WeightedIndex<u32>` and then
/// drawing one sample from it.
#[bench]
fn weighted_index_assignment(b: &mut Bencher) {
    let mut generator = rand::thread_rng();
    let weights: &[u32] = &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
    // The index is built once outside the timed closure; only re-assignment and sampling are
    // measured.
    let mut index = WeightedIndex::new(weights).unwrap();
    b.iter(|| {
        index.assign_new_weights(weights).unwrap();
        generator.sample(&index)
    })
}

/// Benchmarks re-assigning all weights of an existing 1000-element `WeightedIndex<u64>`
/// (weights `0..1000`) and then drawing one sample from it.
#[bench]
fn weighted_index_assignment_large(b: &mut Bencher) {
    let mut generator = rand::thread_rng();
    let weights = (0u64..1000).collect::<Vec<u64>>();
    // The index is built once outside the timed closure; only re-assignment and sampling are
    // measured.
    let mut index = WeightedIndex::new(&weights).unwrap();
    b.iter(|| {
        index.assign_new_weights(&weights).unwrap();
        generator.sample(&index)
    })
}

// Benchmarks building a `WeightedIndex` from a 14-element `u32` weight list and drawing one
// sample from it on every iteration.
#[bench]
fn weighted_index_creation(b: &mut Bencher) {
let mut rng = rand::thread_rng();
// NOTE(review): each `let weights` / `let distr` pair below looks like the removed and the
// added line of a diff shown back to back; as written, the second binding shadows the first.
// Confirm which line of each pair is meant to remain.
let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
let weights = &[1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7][..];
b.iter(|| {
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
let distr = WeightedIndex::new(weights).unwrap();
rng.sample(distr)
})
}

/// Benchmarks building a 1000-element `WeightedIndex<u64>` (weights `0..1000`) from scratch and
/// drawing one sample from it on every iteration.
#[bench]
fn weighted_index_creation_large(b: &mut Bencher) {
    let mut generator = rand::thread_rng();
    let weights = (0u64..1000).collect::<Vec<u64>>();
    b.iter(|| {
        // Construction happens inside the timed closure on purpose: it is what is being measured.
        let index = WeightedIndex::new(&weights).unwrap();
        generator.sample(&index)
    })
}

#[bench]
fn weighted_index_modification(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
let weights = &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..];
let mut distr = WeightedIndex::new(weights).unwrap();
b.iter(|| {
distr.update_weights(&[(2, &4), (5, &1)]).unwrap();
rng.sample(&distr)
Expand Down
104 changes: 104 additions & 0 deletions src/distributions/weighted_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::distributions::Distribution;
use crate::Rng;
use core::cmp::PartialOrd;
use core::fmt;
use core::iter::ExactSizeIterator;

// Note that this whole module is only imported if feature="alloc" is enabled.
use alloc::vec::Vec;
Expand Down Expand Up @@ -130,6 +131,55 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
})
}

/// Replaces every weight of this index with the values yielded by `weights`, keeping the number
/// of weights unchanged and writing the new cumulative weights into the existing storage.
///
/// # Errors
///
/// + `WeightedError::LenMismatch` if `weights` does not report exactly as many items as were
///   used to create the distribution;
/// + `WeightedError::InvalidWeight` if any weight does not compare `>=` to `X::default()`
///   (for example a negative value or `f64::NAN`);
/// + `WeightedError::AllWeightsZero` if the new weights sum up to `X::default()`;
/// + `WeightedError::NoItem` if the iterator yields nothing although the length check passed.
///
/// NOTE: The stored cumulative weights are overwritten while the input is still being
/// validated, whereas the total weight and the sampler are only replaced on success. If this
/// method fails with `InvalidWeight` or `AllWeightsZero`, the distribution should therefore no
/// longer be used for sampling, because results of sampling from it are undefined.
pub fn assign_new_weights<I>(&mut self, weights: I) -> Result<(), WeightedError>
where
    I: IntoIterator,
    I::IntoIter: ExactSizeIterator,
    I::Item: SampleBorrow<X>,
    X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
{
    let mut remaining = weights.into_iter();

    // `cumulative_weights` holds one entry fewer than the number of weights.
    let expected_len = self.cumulative_weights.len() + 1;
    if remaining.len() != expected_len {
        return Err(WeightedError::LenMismatch);
    }

    let mut running_total: X = match remaining.next() {
        Some(first) => first.borrow().clone(),
        None => return Err(WeightedError::NoItem),
    };
    let zero = <X as Default>::default();

    // Written as a negated `>=` so that incomparable values (e.g. NaN) are rejected as well.
    if !(running_total >= zero) {
        return Err(WeightedError::InvalidWeight);
    }

    for (item, slot) in remaining.zip(self.cumulative_weights.iter_mut()) {
        let weight = item.borrow();
        if !(weight >= &zero) {
            return Err(WeightedError::InvalidWeight);
        }
        // Each slot stores the sum of all weights preceding the current one.
        *slot = running_total.clone();
        running_total += weight;
    }

    if running_total == zero {
        return Err(WeightedError::AllWeightsZero);
    }

    self.weight_distribution = X::Sampler::new(zero, running_total.clone());
    self.total_weight = running_total;
    Ok(())
}

/// Update a subset of weights, without changing the number of weights.
///
/// `new_weights` must be sorted by the index.
Expand Down Expand Up @@ -389,6 +439,55 @@ mod test {
}
}

/// Checks that `assign_new_weights` leaves the index in the same state as building a fresh
/// `WeightedIndex` from the same weights.
#[test]
fn test_assign_new_weights() {
    // Each case is (initial weights, replacement weights); both slices have the same length.
    let cases: [(&[u32], &[u32]); 2] = [
        (&[10, 2, 3, 4], &[10, 100, 4, 4]),
        (
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7],
            &[1, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100],
        ),
    ];

    for &(initial, replacement) in &cases {
        let initial_total: u32 = initial.iter().sum();
        let mut distr = WeightedIndex::new(initial.to_vec()).unwrap();
        assert_eq!(distr.total_weight, initial_total);

        // Re-assigning the very same weights must not change the total.
        distr.assign_new_weights(initial).unwrap();
        assert_eq!(distr.total_weight, initial_total);

        // Assigning different weights must match a freshly constructed index.
        let replacement_total: u32 = replacement.iter().sum();
        let reference = WeightedIndex::new(replacement.to_vec()).unwrap();
        distr.assign_new_weights(replacement).unwrap();
        assert_eq!(replacement_total, reference.total_weight);
        assert_eq!(replacement_total, distr.total_weight);
        assert_eq!(reference.cumulative_weights, distr.cumulative_weights);
    }
}

/// Checks that `assign_new_weights` reports the expected error for each kind of bad input.
#[test]
fn assigning_error_states() {
    // Fewer weights than the index was built with.
    {
        let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap();
        let res = distr.assign_new_weights(&[1.0f64, 2.0, 3.0][..]);
        assert_eq!(res, Err(WeightedError::LenMismatch));
    }
    // More weights than the index was built with.
    {
        let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap();
        let res = distr.assign_new_weights(&[1.0f64, 2.0, 3.0, 0.0, 5.0][..]);
        assert_eq!(res, Err(WeightedError::LenMismatch));
    }
    // No weights at all: the length check fires before any item is requested.
    {
        let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap();
        let empty: &[f64] = &[];
        let res = distr.assign_new_weights(empty);
        assert_eq!(res, Err(WeightedError::LenMismatch));
    }
    // NaN among the later weights (validated inside the loop).
    {
        let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap();
        let res = distr.assign_new_weights(&[1.0f64, 2.0, ::core::f64::NAN, 0.0][..]);
        assert_eq!(res, Err(WeightedError::InvalidWeight));
    }
    // NaN as the very first weight (validated before the loop).
    {
        let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap();
        let res = distr.assign_new_weights(&[::core::f64::NAN, 2.0, 3.0, 0.0][..]);
        assert_eq!(res, Err(WeightedError::InvalidWeight));
    }
    // Negative weights are rejected just like NaN.
    {
        let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap();
        let res = distr.assign_new_weights(&[1.0f64, -2.0, 3.0, 0.0][..]);
        assert_eq!(res, Err(WeightedError::InvalidWeight));
    }
    // All weights zero.
    {
        let mut distr = WeightedIndex::new(&[1u32, 2, 3, 0][..]).unwrap();
        let res = distr.assign_new_weights(&[0u32, 0, 0, 0][..]);
        assert_eq!(res, Err(WeightedError::AllWeightsZero));
    }
}

#[test]
fn value_stability() {
fn test_samples<X: SampleUniform + PartialOrd, I>(
Expand Down Expand Up @@ -436,6 +535,10 @@ pub enum WeightedError {

/// Too many weights are provided (length greater than `u32::MAX`)
TooMany,

/// The number of weights provided when assigning must equal the number of weights used to
/// construct the weighted index.
LenMismatch,
}

#[cfg(feature = "std")]
Expand All @@ -448,6 +551,7 @@ impl fmt::Display for WeightedError {
WeightedError::InvalidWeight => "A weight is invalid in distribution",
WeightedError::AllWeightsZero => "All weights are zero in distribution",
WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
WeightedError::LenMismatch => "Length mismatch between previous and provided weights",
})
}
}

0 comments on commit f6187ec

Please sign in to comment.