Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full update of weighted index by assigning weights #1194

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 91 additions & 5 deletions benches/weighted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,106 @@ use rand::Rng;
use test::Bencher;

#[bench]
fn weighted_index_creation(b: &mut Bencher) {
fn weighted_index_assignment(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
let weights = &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..];
let mut distr = WeightedIndex::new(weights).unwrap();
b.iter(|| {
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
distr.assign_new_weights(weights).unwrap();
SuperFluffy marked this conversation as resolved.
Show resolved Hide resolved
rng.sample(&distr)
})
}

#[bench]
fn weighted_index_assignment_large(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights: Vec<u64> = (0u64..1000).collect();
let mut distr = WeightedIndex::new(&weights).unwrap();
b.iter(|| {
distr.assign_new_weights(&weights).unwrap();
rng.sample(&distr)
})
}

#[bench]
fn weighted_index_new(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights = &[1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7][..];
b.iter(|| {
let distr = WeightedIndex::new(weights).unwrap();
rng.sample(distr)
})
}

#[bench]
fn weighted_index_new_large(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights: Vec<u64> = (0u64..1000).collect();
b.iter(|| {
let distr = WeightedIndex::new(&weights).unwrap();
rng.sample(&distr)
})
}

#[bench]
fn weighted_index_from_cumulative_weights(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights = vec![1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
let cumulative_weights: Vec<_> = weights
.into_iter()
.scan(0, |acc, item| {
*acc += item;
Some(*acc)
})
.collect();
b.iter(|| {
let distr = WeightedIndex::from_cumulative_weights_unchecked(cumulative_weights.clone());
rng.sample(distr)
})
}

#[bench]
fn weighted_index_from_cumulative_weights_large(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights: Vec<u64> = (0u64..1000).collect();
let cumulative_weights: Vec<_> = weights
.into_iter()
.scan(0, |acc, item| {
*acc += item;
Some(*acc)
})
.collect();
b.iter(|| {
let distr = WeightedIndex::from_cumulative_weights_unchecked(cumulative_weights.clone());
rng.sample(&distr)
})
}

#[bench]
fn weighted_index_from_weights(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights = vec![1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
b.iter(|| {
let distr = WeightedIndex::from_weights(weights.clone()).unwrap();
rng.sample(distr)
})
}

#[bench]
fn weighted_index_from_weights_large(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights: Vec<u64> = (0u64..1000).collect();
b.iter(|| {
let distr = WeightedIndex::from_weights(weights.clone()).unwrap();
rng.sample(&distr)
})
}

#[bench]
fn weighted_index_modification(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
let weights = &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..];
let mut distr = WeightedIndex::new(weights).unwrap();
b.iter(|| {
distr.update_weights(&[(2, &4), (5, &1)]).unwrap();
rng.sample(&distr)
Expand Down
168 changes: 168 additions & 0 deletions src/distributions/weighted_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::distributions::Distribution;
use crate::Rng;
use core::cmp::PartialOrd;
use core::fmt;
use core::iter::ExactSizeIterator;

// Note that this whole module is only imported if feature="alloc" is enabled.
use alloc::vec::Vec;
Expand Down Expand Up @@ -130,6 +131,119 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
})
}

/// Reuses the weighted index by assigning a new set of weights without changing the number of
/// weights.
///
/// Returns an error if:
///
/// + the number of items in the iterator does not match the number of items used to create the
/// distribution;
/// + the iterator yields invalid values (such as `f64::NAN`);
/// + the weights yielded by the iterator sum up to zero.
///
/// NOTE: If this method fails the distribution should no longer be used for sampling, because
/// results of sampling from it are undefined.
pub fn assign_new_weights<I>(&mut self, weights: I) -> Result<(), WeightedError >
vks marked this conversation as resolved.
Show resolved Hide resolved
where
I: IntoIterator,
I::IntoIter: ExactSizeIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
{
let mut iter = weights.into_iter();

if iter.len() != self.cumulative_weights.len() + 1 {
return Err(WeightedError::LenMismatch);
}

let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
let zero = <X as Default>::default();

if !(total_weight >= zero) {
return Err(WeightedError::InvalidWeight);
}

for (w, c) in iter.zip(self.cumulative_weights.iter_mut()) {
if !(w.borrow() >= &zero) {
return Err(WeightedError::InvalidWeight);
}
*c = total_weight.clone();
total_weight += w.borrow();
}

if total_weight == zero {
dhardy marked this conversation as resolved.
Show resolved Hide resolved
return Err(WeightedError::AllWeightsZero);
};

self.weight_distribution = X::Sampler::new(zero, total_weight.clone());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's still a problem: this panics if total_weight is +inf, and we don't catch panics.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but here WeightedIndex::new suffers from the same issue. So if we want to address this, both assign_new_weights and new should be changed in a new PR, I think.

self.total_weight = total_weight;
Ok(())
}

/// Create a `WeightedIndex` from a vector of cumulative weights without verification.
pub fn from_cumulative_weights_unchecked(ws: Vec<X>) -> Self
where
X: Clone + Default,
{
let mut cumulative_weights = ws;
let total_weight = cumulative_weights.pop().unwrap();
let zero = <X as Default>::default();
let weight_distribution = X::Sampler::new(zero, total_weight.clone());
Self {
cumulative_weights,
total_weight,
weight_distribution,
}
}

/// Create a `WeightedIndex` from a vector of cumulative weights without verification.
pub fn from_weights(ws: Vec<X>) -> Result<Self, WeightedError>
where
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
for<'a> &'a X: SampleBorrow<X>,
{
let mut cumulative_weights = ws;
let mut iter = cumulative_weights.iter_mut();

let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
let zero = <X as Default>::default();

if !(total_weight >= zero) {
return Err(WeightedError::InvalidWeight);
}

for w in iter {
if !(w.borrow() >= &zero) {
return Err(WeightedError::InvalidWeight);
}
total_weight += w.borrow();
*w = total_weight.clone();
}

if total_weight == zero {
return Err(WeightedError::AllWeightsZero);
};

let weight_distribution = X::Sampler::new(zero, total_weight.clone());
cumulative_weights.pop().unwrap();
Ok(Self {
cumulative_weights,
total_weight,
weight_distribution,
})
}

/// Remove the inner vector containing the cumulative weights.
pub fn into_cumulative_weights(self) -> Vec<X> {
let Self {
mut cumulative_weights,
total_weight,
..
} = self;
cumulative_weights.push(total_weight);
cumulative_weights
}

/// Update a subset of weights, without changing the number of weights.
///
/// `new_weights` must be sorted by the index.
Expand Down Expand Up @@ -389,6 +503,55 @@ mod test {
}
}

#[test]
fn test_assign_new_weights() {
let data = [
(
&[10u32, 2, 3, 4][..],
&[10, 100, 4, 4][..],
),
(
&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
&[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
),
];

for &(weights, new_weights) in data.iter() {
let total_weight = weights.iter().sum::<u32>();
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
assert_eq!(distr.total_weight, total_weight);

distr.assign_new_weights(weights).unwrap();
assert_eq!(distr.total_weight, total_weight);

let new_total_weight = new_weights.iter().sum::<u32>();
let new_distr = WeightedIndex::new(new_weights.to_vec()).unwrap();
distr.assign_new_weights(new_weights).unwrap();
assert_eq!(new_total_weight, new_distr.total_weight);
assert_eq!(new_total_weight, distr.total_weight);
assert_eq!(new_distr.cumulative_weights, distr.cumulative_weights);
}
}

#[test]
fn assigning_error_states() {
{
let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap();
let res = distr.assign_new_weights(&[1.0f64, 2.0, 3.0][..]);
assert_eq!(res, Err(WeightedError::LenMismatch));
}
{
let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap();
let res = distr.assign_new_weights(&[1.0f64, 2.0, ::core::f64::NAN, 0.0][..]);
assert_eq!(res, Err(WeightedError::InvalidWeight));
}
{
let mut distr = WeightedIndex::new(&[1u32, 2, 3, 0][..]).unwrap();
let res = distr.assign_new_weights(&[0u32, 0, 0, 0][..]);
assert_eq!(res, Err(WeightedError::AllWeightsZero));
}
}

#[test]
fn value_stability() {
fn test_samples<X: SampleUniform + PartialOrd, I>(
Expand Down Expand Up @@ -436,6 +599,10 @@ pub enum WeightedError {

/// Too many weights are provided (length greater than `u32::MAX`)
TooMany,

/// Have to provide exactly as many weights when assigning as were present when constructing
/// the weighted index.
LenMismatch,
SuperFluffy marked this conversation as resolved.
Show resolved Hide resolved
}

#[cfg(feature = "std")]
Expand All @@ -448,6 +615,7 @@ impl fmt::Display for WeightedError {
WeightedError::InvalidWeight => "A weight is invalid in distribution",
WeightedError::AllWeightsZero => "All weights are zero in distribution",
WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
WeightedError::LenMismatch => "Length mismatch between previous and provided weights",
})
}
}