From a76ab4155a587617c372aa5f39a0aab0a181656d Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 9 Mar 2020 13:10:40 +0000 Subject: [PATCH 1/5] Move alias method WeightedIndex to rand_distr --- rand_distr/src/lib.rs | 6 +++- .../src}/weighted/alias_method.rs | 8 +++++ rand_distr/src/weighted/mod.rs | 21 ++++++++++++ src/distributions/weighted/mod.rs | 33 ++++++++++++++++++- 4 files changed, 66 insertions(+), 2 deletions(-) rename {src/distributions => rand_distr/src}/weighted/alias_method.rs (97%) create mode 100644 rand_distr/src/weighted/mod.rs diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 0e7beb91b5d..de1c10e1dd1 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -68,7 +68,7 @@ //! - [`UnitDisc`] distribution pub use rand::distributions::{ - uniform, weighted, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, + uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, Standard, Uniform, }; @@ -91,6 +91,10 @@ pub use self::unit_disc::UnitDisc; pub use self::unit_sphere::UnitSphere; pub use self::utils::Float; pub use self::weibull::{Error as WeibullError, Weibull}; +#[cfg(feature = "alloc")] +pub use self::weighted::{WeightedError, WeightedIndex}; + +#[cfg(feature = "alloc")] pub mod weighted; mod binomial; mod cauchy; diff --git a/src/distributions/weighted/alias_method.rs b/rand_distr/src/weighted/alias_method.rs similarity index 97% rename from src/distributions/weighted/alias_method.rs rename to rand_distr/src/weighted/alias_method.rs index 7d42a35267b..77e342cb1e8 100644 --- a/src/distributions/weighted/alias_method.rs +++ b/rand_distr/src/weighted/alias_method.rs @@ -1,3 +1,11 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + //! This module contains an implementation of alias method for sampling random //! indices with probabilities proportional to a collection of weights. diff --git a/rand_distr/src/weighted/mod.rs b/rand_distr/src/weighted/mod.rs new file mode 100644 index 00000000000..9a23cb3e819 --- /dev/null +++ b/rand_distr/src/weighted/mod.rs @@ -0,0 +1,21 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted index sampling +//! +//! This module provides two implementations for sampling indices: +//! +//! * [`WeightedIndex`] allows `O(log N)` sampling +//! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with +//! much greater set-up cost +//! +//! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html + +pub mod alias_method; + +pub use pub use rand::distributions::weighted::{WeightedIndex, WeightedError}; diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index 357e3a9f024..47c10c50e1b 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -16,7 +16,38 @@ //! //! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html -pub mod alias_method; +#[allow(missing_docs)] +#[deprecated(since = "0.8.0", note = "moved to rand_distr crate")] +pub mod alias_method { + // This module exists to provide a deprecation warning which minimises + // compile errors, but still fails to compile if ever used. + use std::marker::PhantomData; + use super::WeightedError; + + #[derive(Debug)] + pub struct WeightedIndex { + _phantom: PhantomData, + } + impl WeightedIndex { + pub fn new(_weights: Vec) -> Result { + Err(WeightedError::NoItem) + } + } + + pub trait Weight {} + macro_rules! impl_weight { + () => {}; + ($T:ident, $($more:ident,)*) => { + impl Weight for $T {} + impl_weight!($($more,)*); + }; + } + impl_weight!(f64, f32,); + impl_weight!(u8, u16, u32, u64, usize,); + impl_weight!(i8, i16, i32, i64, isize,); + #[cfg(not(target_os = "emscripten"))] + impl_weight!(u128, i128,); +} use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; use crate::distributions::Distribution; From 2061960a99bf822adf5ff8738e39ea5123a78782 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 9 Mar 2020 13:11:23 +0000 Subject: [PATCH 2/5] Move rand::distributions::weighted module out of sub-dir --- src/distributions/{weighted/mod.rs => weighted.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/distributions/{weighted/mod.rs => weighted.rs} (100%) diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted.rs similarity index 100% rename from src/distributions/weighted/mod.rs rename to src/distributions/weighted.rs From 967c22ee6d4a5534f925c7ca0a8d629821ef9583 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 9 Mar 2020 13:29:39 +0000 Subject: [PATCH 3/5] Fixes: remove cfg, adjust doc, fix imports --- rand_distr/src/lib.rs | 7 +++---- rand_distr/src/weighted/alias_method.rs | 14 +++++--------- rand_distr/src/weighted/mod.rs | 2 +- src/distributions/weighted.rs | 14 ++++++++------ 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index de1c10e1dd1..a41c579cbcc 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -68,8 +68,8 @@ //! - [`UnitDisc`] distribution pub use rand::distributions::{ - uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, - OpenClosed01, Standard, Uniform, + uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, + Standard, Uniform, }; pub use self::binomial::{Binomial, Error as BinomialError}; @@ -91,10 +91,9 @@ pub use self::unit_disc::UnitDisc; pub use self::unit_sphere::UnitSphere; pub use self::utils::Float; pub use self::weibull::{Error as WeibullError, Weibull}; -#[cfg(feature = "alloc")] pub use self::weighted::{WeightedError, WeightedIndex}; -#[cfg(feature = "alloc")] pub mod weighted; +pub mod weighted; mod binomial; mod cauchy; diff --git a/rand_distr/src/weighted/alias_method.rs b/rand_distr/src/weighted/alias_method.rs index 77e342cb1e8..8e094029f84 100644 --- a/rand_distr/src/weighted/alias_method.rs +++ b/rand_distr/src/weighted/alias_method.rs @@ -10,15 +10,11 @@ //! indices with probabilities proportional to a collection of weights. use super::WeightedError; -#[cfg(not(feature = "std"))] use crate::alloc::vec; -#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec; -use crate::distributions::uniform::SampleUniform; -use crate::distributions::Distribution; -use crate::distributions::Uniform; -use crate::Rng; +use crate::{uniform::SampleUniform, Distribution, Uniform}; use core::fmt; use core::iter::Sum; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; +use rand::Rng; /// A distribution using weighted sampling to pick a discretely selected item. /// @@ -42,7 +38,7 @@ use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; /// # Example /// /// ``` -/// use rand::distributions::weighted::alias_method::WeightedIndex; +/// use rand_distr::weighted::alias_method::WeightedIndex; /// use rand::prelude::*; /// /// let choices = vec!['a', 'b', 'c']; @@ -408,7 +404,7 @@ mod test { test_weighted_index(|x: u128| x as f64); } - #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[cfg(not(target_os = "emscripten"))] #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_weighted_index_i128() { @@ -456,7 +452,7 @@ mod test { let weights = { let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize); - let random_weight_distribution = crate::distributions::Uniform::new_inclusive( + let random_weight_distribution = Uniform::new_inclusive( W::ZERO, W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), ); diff --git a/rand_distr/src/weighted/mod.rs b/rand_distr/src/weighted/mod.rs index 9a23cb3e819..75acacd5792 100644 --- a/rand_distr/src/weighted/mod.rs +++ b/rand_distr/src/weighted/mod.rs @@ -18,4 +18,4 @@ pub mod alias_method; -pub use pub use rand::distributions::weighted::{WeightedIndex, WeightedError}; +pub use rand::distributions::weighted::{WeightedError, WeightedIndex}; diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 47c10c50e1b..090f7c9771b 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -8,13 +8,15 @@ //! Weighted index sampling //! -//! This module provides two implementations for sampling indices: +//! The [`WeightedIndex`] distribution allows `O(log N)` sampling from a +//! sequence of weights. As the name implies, the result is the index in that +//! sequence, which may be used to look up an associated value. //! -//! * [`WeightedIndex`] allows `O(log N)` sampling -//! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with -//! much greater set-up cost -//! -//! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html +//! Note also that the `rand_distr` crate provides +//! `rand_distr::alias_method::WeightedIndex`, which allows `O(1)` sampling; +//! this distribution however has a much greater set-up cost, thus is only +//! recommended where *many* samples are required. +// TODO: link alias_method impl when published in rand_distr #[allow(missing_docs)] #[deprecated(since = "0.8.0", note = "moved to rand_distr crate")] From 6a3d3a20b0c0d40a4574ca346021a10776ddf69a Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 10 Mar 2020 12:18:35 +0000 Subject: [PATCH 4/5] Fix alloc-no-std build --- src/distributions/weighted.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 090f7c9771b..9b1f91e2e61 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -23,7 +23,8 @@ pub mod alias_method { // This module exists to provide a deprecation warning which minimises // compile errors, but still fails to compile if ever used. - use std::marker::PhantomData; + use core::marker::PhantomData; + #[cfg(not(feature = "std"))] use crate::alloc::vec::Vec; use super::WeightedError; #[derive(Debug)] From 29689f55fcab69a6d035b8c8fcc7ed3cc56f185d Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Fri, 20 Mar 2020 09:16:13 +0000 Subject: [PATCH 5/5] Deprecate use of rand::distributions::weighted module --- src/distributions/mod.rs | 6 +- src/distributions/weighted.rs | 407 +--------------------------- src/distributions/weighted_index.rs | 404 +++++++++++++++++++++++++++ 3 files changed, 413 insertions(+), 404 deletions(-) create mode 100644 src/distributions/weighted_index.rs diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index e7865c27b90..77a29b06802 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -100,8 +100,9 @@ pub use self::bernoulli::{Bernoulli, BernoulliError}; pub use self::float::{Open01, OpenClosed01}; pub use self::other::Alphanumeric; #[doc(inline)] pub use self::uniform::Uniform; + #[cfg(feature = "alloc")] -pub use self::weighted::{WeightedError, WeightedIndex}; +pub use self::weighted_index::{WeightedError, WeightedIndex}; // The following are all deprecated after being moved to rand_distr #[allow(deprecated)] @@ -155,7 +156,10 @@ pub mod uniform; #[cfg(feature = "std")] mod unit_circle; #[cfg(feature = "std")] mod unit_sphere; #[cfg(feature = "std")] mod weibull; + +#[deprecated(since = "0.8.0", note = "use rand::distributions::{WeightedIndex, WeightedError} instead")] #[cfg(feature = "alloc")] pub mod weighted; +#[cfg(feature = "alloc")] mod weighted_index; mod float; #[doc(hidden)] diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 9b1f91e2e61..20248f51375 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -8,15 +8,10 @@ //! Weighted index sampling //! -//! The [`WeightedIndex`] distribution allows `O(log N)` sampling from a -//! sequence of weights. As the name implies, the result is the index in that -//! sequence, which may be used to look up an associated value. -//! -//! Note also that the `rand_distr` crate provides -//! `rand_distr::alias_method::WeightedIndex`, which allows `O(1)` sampling; -//! this distribution however has a much greater set-up cost, thus is only -//! recommended where *many* samples are required. -// TODO: link alias_method impl when published in rand_distr +//! This module is deprecated. Use [`crate::distributions::WeightedIndex`] and +//! [`crate::distributions::WeightedError`] instead. + +pub use super::{WeightedIndex, WeightedError}; #[allow(missing_docs)] #[deprecated(since = "0.8.0", note = "moved to rand_distr crate")] @@ -51,397 +46,3 @@ pub mod alias_method { #[cfg(not(target_os = "emscripten"))] impl_weight!(u128, i128,); } - -use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; -use crate::distributions::Distribution; -use crate::Rng; -use core::cmp::PartialOrd; -use core::fmt; - -// Note that this whole module is only imported if feature="alloc" is enabled. -#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec; - -/// A distribution using weighted sampling to pick a discretely selected -/// item. -/// -/// Sampling a `WeightedIndex` distribution returns the index of a randomly -/// selected element from the iterator used when the `WeightedIndex` was -/// created. The chance of a given element being picked is proportional to the -/// value of the element. The weights can use any type `X` for which an -/// implementation of [`Uniform`] exists. -/// -/// # Performance -/// -/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its -/// size is the sum of the size of those objects, possibly plus some alignment. -/// -/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` -/// weights of type `X`, where `N` is the number of weights. However, since -/// `Vec` doesn't guarantee a particular growth strategy, additional memory -/// might be allocated but not used. Since the `WeightedIndex` object also -/// contains, this might cause additional allocations, though for primitive -/// types, ['Uniform`] doesn't allocate any memory. -/// -/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. -/// -/// Sampling from `WeightedIndex` will result in a single call to -/// `Uniform::sample` (method of the [`Distribution`] trait), which typically -/// will request a single value from the underlying [`RngCore`], though the -/// exact number depends on the implementaiton of `Uniform::sample`. -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::WeightedIndex; -/// -/// let choices = ['a', 'b', 'c']; -/// let weights = [2, 1, 1]; -/// let dist = WeightedIndex::new(&weights).unwrap(); -/// let mut rng = thread_rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`Uniform`]: crate::distributions::uniform::Uniform -/// [`RngCore`]: crate::RngCore -#[derive(Debug, Clone)] -pub struct WeightedIndex { - cumulative_weights: Vec, - total_weight: X, - weight_distribution: X::Sampler, -} - -impl WeightedIndex { - /// Creates a new a `WeightedIndex` [`Distribution`] using the values - /// in `weights`. The weights can use any type `X` for which an - /// implementation of [`Uniform`] exists. - /// - /// Returns an error if the iterator is empty, if any weight is `< 0`, or - /// if its total value is 0. - /// - /// [`Uniform`]: crate::distributions::uniform::Uniform - pub fn new(weights: I) -> Result, WeightedError> - where - I: IntoIterator, - I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, - { - let mut iter = weights.into_iter(); - let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); - - let zero = ::default(); - if total_weight < zero { - return Err(WeightedError::InvalidWeight); - } - - let mut weights = Vec::::with_capacity(iter.size_hint().0); - for w in iter { - if *w.borrow() < zero { - return Err(WeightedError::InvalidWeight); - } - weights.push(total_weight.clone()); - total_weight += w.borrow(); - } - - if total_weight == zero { - return Err(WeightedError::AllWeightsZero); - } - let distr = X::Sampler::new(zero, total_weight.clone()); - - Ok(WeightedIndex { - cumulative_weights: weights, - total_weight, - weight_distribution: distr, - }) - } - - /// Update a subset of weights, without changing the number of weights. - /// - /// `new_weights` must be sorted by the index. - /// - /// Using this method instead of `new` might be more efficient if only a small number of - /// weights is modified. No allocations are performed, unless the weight type `X` uses - /// allocation internally. - /// - /// In case of error, `self` is not modified. - pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> - where X: for<'a> ::core::ops::AddAssign<&'a X> - + for<'a> ::core::ops::SubAssign<&'a X> - + Clone - + Default { - if new_weights.is_empty() { - return Ok(()); - } - - let zero = ::default(); - - let mut total_weight = self.total_weight.clone(); - - // Check for errors first, so we don't modify `self` in case something - // goes wrong. - let mut prev_i = None; - for &(i, w) in new_weights { - if let Some(old_i) = prev_i { - if old_i >= i { - return Err(WeightedError::InvalidWeight); - } - } - if *w < zero { - return Err(WeightedError::InvalidWeight); - } - if i >= self.cumulative_weights.len() + 1 { - return Err(WeightedError::TooMany); - } - - let mut old_w = if i < self.cumulative_weights.len() { - self.cumulative_weights[i].clone() - } else { - self.total_weight.clone() - }; - if i > 0 { - old_w -= &self.cumulative_weights[i - 1]; - } - - total_weight -= &old_w; - total_weight += w; - prev_i = Some(i); - } - if total_weight == zero { - return Err(WeightedError::AllWeightsZero); - } - - // Update the weights. Because we checked all the preconditions in the - // previous loop, this should never panic. - let mut iter = new_weights.iter(); - - let mut prev_weight = zero.clone(); - let mut next_new_weight = iter.next(); - let &(first_new_index, _) = next_new_weight.unwrap(); - let mut cumulative_weight = if first_new_index > 0 { - self.cumulative_weights[first_new_index - 1].clone() - } else { - zero.clone() - }; - for i in first_new_index..self.cumulative_weights.len() { - match next_new_weight { - Some(&(j, w)) if i == j => { - cumulative_weight += w; - next_new_weight = iter.next(); - } - _ => { - let mut tmp = self.cumulative_weights[i].clone(); - tmp -= &prev_weight; // We know this is positive. - cumulative_weight += &tmp; - } - } - prev_weight = cumulative_weight.clone(); - core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); - } - - self.total_weight = total_weight; - self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); - - Ok(()) - } -} - -impl Distribution for WeightedIndex -where X: SampleUniform + PartialOrd -{ - fn sample(&self, rng: &mut R) -> usize { - use ::core::cmp::Ordering; - let chosen_weight = self.weight_distribution.sample(rng); - // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights - .binary_search_by(|w| { - if *w <= chosen_weight { - Ordering::Less - } else { - Ordering::Greater - } - }) - .unwrap_err() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weightedindex() { - let mut r = crate::test::rng(700); - const N_REPS: u32 = 5000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // WeightedIndex from vec - let mut chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from slice - chosen = [0i32; 14]; - let distr = WeightedIndex::new(&weights[..]).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from iterator - chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.iter()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - for _ in 0..5 { - assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); - assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); - assert_eq!( - WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) - .unwrap() - .sample(&mut r), - 4 - ); - } - - assert_eq!( - WeightedIndex::new(&[10][0..0]).unwrap_err(), - WeightedError::NoItem - ); - assert_eq!( - WeightedIndex::new(&[0]).unwrap_err(), - WeightedError::AllWeightsZero - ); - assert_eq!( - WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(&[-10]).unwrap_err(), - WeightedError::InvalidWeight - ); - } - - #[test] - fn test_update_weights() { - let data = [ - ( - &[10u32, 2, 3, 4][..], - &[(1, &100), (2, &4)][..], // positive change - &[10, 100, 4, 4][..], - ), - ( - &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], - &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element - &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], - ), - ]; - - for (weights, update, expected_weights) in data.iter() { - let total_weight = weights.iter().sum::(); - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, total_weight); - - distr.update_weights(update).unwrap(); - let expected_total_weight = expected_weights.iter().sum::(); - let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, expected_total_weight); - assert_eq!(distr.total_weight, expected_distr.total_weight); - assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); - } - } - - #[test] - fn value_stability() { - fn test_samples( - weights: I, buf: &mut [usize], expected: &[usize], - ) where - I: IntoIterator, - I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, - { - assert_eq!(buf.len(), expected.len()); - let distr = WeightedIndex::new(weights).unwrap(); - let mut rng = crate::test::rng(701); - for r in buf.iter_mut() { - *r = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - let mut buf = [0; 10]; - test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, - ]); - test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, - ]); - test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, - ]); - } -} - -/// Error type returned from `WeightedIndex::new`. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WeightedError { - /// The provided weight collection contains no items. - NoItem, - - /// A weight is either less than zero, greater than the supported maximum or - /// otherwise invalid. - InvalidWeight, - - /// All items in the provided weight collection are zero. - AllWeightsZero, - - /// Too many weights are provided (length greater than `u32::MAX`) - TooMany, -} - -#[cfg(feature = "std")] -impl ::std::error::Error for WeightedError {} - -impl fmt::Display for WeightedError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - WeightedError::NoItem => write!(f, "No weights provided."), - WeightedError::InvalidWeight => write!(f, "A weight is invalid."), - WeightedError::AllWeightsZero => write!(f, "All weights are zero."), - WeightedError::TooMany => write!(f, "Too many weights (hit u32::MAX)"), - } - } -} diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs new file mode 100644 index 00000000000..5a3ecaa2099 --- /dev/null +++ b/src/distributions/weighted_index.rs @@ -0,0 +1,404 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted index sampling + +use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; +use crate::distributions::Distribution; +use crate::Rng; +use core::cmp::PartialOrd; +use core::fmt; + +// Note that this whole module is only imported if feature="alloc" is enabled. +#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec; + +/// A distribution using weighted sampling of discrete items +/// +/// Sampling a `WeightedIndex` distribution returns the index of a randomly +/// selected element from the iterator used when the `WeightedIndex` was +/// created. The chance of a given element being picked is proportional to the +/// value of the element. The weights can use any type `X` for which an +/// implementation of [`Uniform`] exists. +/// +/// # Performance +/// +/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where +/// `N` is the number of weights. As an alternative, +/// [`rand_distr::weighted::alias_method`](https://docs.rs/rand_distr/*/rand_distr/weighted/alias_method/index.html) +/// supports `O(1)` sampling, but with much higher initialisation cost. +/// +/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its +/// size is the sum of the size of those objects, possibly plus some alignment. +/// +/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` +/// weights of type `X`, where `N` is the number of weights. However, since +/// `Vec` doesn't guarantee a particular growth strategy, additional memory +/// might be allocated but not used. Since the `WeightedIndex` object also +/// contains, this might cause additional allocations, though for primitive +/// types, ['Uniform`] doesn't allocate any memory. +/// +/// Sampling from `WeightedIndex` will result in a single call to +/// `Uniform::sample` (method of the [`Distribution`] trait), which typically +/// will request a single value from the underlying [`RngCore`], though the +/// exact number depends on the implementaiton of `Uniform::sample`. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::WeightedIndex; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let dist = WeightedIndex::new(&weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`Uniform`]: crate::distributions::uniform::Uniform +/// [`RngCore`]: crate::RngCore +#[derive(Debug, Clone)] +pub struct WeightedIndex { + cumulative_weights: Vec, + total_weight: X, + weight_distribution: X::Sampler, +} + +impl WeightedIndex { + /// Creates a new a `WeightedIndex` [`Distribution`] using the values + /// in `weights`. The weights can use any type `X` for which an + /// implementation of [`Uniform`] exists. + /// + /// Returns an error if the iterator is empty, if any weight is `< 0`, or + /// if its total value is 0. + /// + /// [`Uniform`]: crate::distributions::uniform::Uniform + pub fn new(weights: I) -> Result, WeightedError> + where + I: IntoIterator, + I::Item: SampleBorrow, + X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + { + let mut iter = weights.into_iter(); + let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); + + let zero = ::default(); + if total_weight < zero { + return Err(WeightedError::InvalidWeight); + } + + let mut weights = Vec::::with_capacity(iter.size_hint().0); + for w in iter { + if *w.borrow() < zero { + return Err(WeightedError::InvalidWeight); + } + weights.push(total_weight.clone()); + total_weight += w.borrow(); + } + + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + } + let distr = X::Sampler::new(zero, total_weight.clone()); + + Ok(WeightedIndex { + cumulative_weights: weights, + total_weight, + weight_distribution: distr, + }) + } + + /// Update a subset of weights, without changing the number of weights. + /// + /// `new_weights` must be sorted by the index. + /// + /// Using this method instead of `new` might be more efficient if only a small number of + /// weights is modified. No allocations are performed, unless the weight type `X` uses + /// allocation internally. + /// + /// In case of error, `self` is not modified. + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> + where X: for<'a> ::core::ops::AddAssign<&'a X> + + for<'a> ::core::ops::SubAssign<&'a X> + + Clone + + Default { + if new_weights.is_empty() { + return Ok(()); + } + + let zero = ::default(); + + let mut total_weight = self.total_weight.clone(); + + // Check for errors first, so we don't modify `self` in case something + // goes wrong. + let mut prev_i = None; + for &(i, w) in new_weights { + if let Some(old_i) = prev_i { + if old_i >= i { + return Err(WeightedError::InvalidWeight); + } + } + if *w < zero { + return Err(WeightedError::InvalidWeight); + } + if i >= self.cumulative_weights.len() + 1 { + return Err(WeightedError::TooMany); + } + + let mut old_w = if i < self.cumulative_weights.len() { + self.cumulative_weights[i].clone() + } else { + self.total_weight.clone() + }; + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } + + total_weight -= &old_w; + total_weight += w; + prev_i = Some(i); + } + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + } + + // Update the weights. Because we checked all the preconditions in the + // previous loop, this should never panic. + let mut iter = new_weights.iter(); + + let mut prev_weight = zero.clone(); + let mut next_new_weight = iter.next(); + let &(first_new_index, _) = next_new_weight.unwrap(); + let mut cumulative_weight = if first_new_index > 0 { + self.cumulative_weights[first_new_index - 1].clone() + } else { + zero.clone() + }; + for i in first_new_index..self.cumulative_weights.len() { + match next_new_weight { + Some(&(j, w)) if i == j => { + cumulative_weight += w; + next_new_weight = iter.next(); + } + _ => { + let mut tmp = self.cumulative_weights[i].clone(); + tmp -= &prev_weight; // We know this is positive. + cumulative_weight += &tmp; + } + } + prev_weight = cumulative_weight.clone(); + core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); + } + + self.total_weight = total_weight; + self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); + + Ok(()) + } +} + +impl Distribution for WeightedIndex +where X: SampleUniform + PartialOrd +{ + fn sample(&self, rng: &mut R) -> usize { + use ::core::cmp::Ordering; + let chosen_weight = self.weight_distribution.sample(rng); + // Find the first item which has a weight *higher* than the chosen weight. + self.cumulative_weights + .binary_search_by(|w| { + if *w <= chosen_weight { + Ordering::Less + } else { + Ordering::Greater + } + }) + .unwrap_err() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weightedindex() { + let mut r = crate::test::rng(700); + const N_REPS: u32 = 5000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // WeightedIndex from vec + let mut chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from slice + chosen = [0i32; 14]; + let distr = WeightedIndex::new(&weights[..]).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from iterator + chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.iter()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + for _ in 0..5 { + assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); + assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); + assert_eq!( + WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) + .unwrap() + .sample(&mut r), + 4 + ); + } + + assert_eq!( + WeightedIndex::new(&[10][0..0]).unwrap_err(), + WeightedError::NoItem + ); + assert_eq!( + WeightedIndex::new(&[0]).unwrap_err(), + WeightedError::AllWeightsZero + ); + assert_eq!( + WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(&[-10]).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + #[test] + fn test_update_weights() { + let data = [ + ( + &[10u32, 2, 3, 4][..], + &[(1, &100), (2, &4)][..], // positive change + &[10, 100, 4, 4][..], + ), + ( + &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], + ), + ]; + + for (weights, update, expected_weights) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(update).unwrap(); + let expected_total_weight = expected_weights.iter().sum::(); + let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } + } + + #[test] + fn value_stability() { + fn test_samples( + weights: I, buf: &mut [usize], expected: &[usize], + ) where + I: IntoIterator, + I::Item: SampleBorrow, + X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + { + assert_eq!(buf.len(), expected.len()); + let distr = WeightedIndex::new(weights).unwrap(); + let mut rng = crate::test::rng(701); + for r in buf.iter_mut() { + *r = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + let mut buf = [0; 10]; + test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ + 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, + ]); + test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ + 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, + ]); + test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ + 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, + ]); + } +} + +/// Error type returned from `WeightedIndex::new`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WeightedError { + /// The provided weight collection contains no items. + NoItem, + + /// A weight is either less than zero, greater than the supported maximum or + /// otherwise invalid. + InvalidWeight, + + /// All items in the provided weight collection are zero. + AllWeightsZero, + + /// Too many weights are provided (length greater than `u32::MAX`) + TooMany, +} + +#[cfg(feature = "std")] +impl ::std::error::Error for WeightedError {} + +impl fmt::Display for WeightedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + WeightedError::NoItem => write!(f, "No weights provided."), + WeightedError::InvalidWeight => write!(f, "A weight is invalid."), + WeightedError::AllWeightsZero => write!(f, "All weights are zero."), + WeightedError::TooMany => write!(f, "Too many weights (hit u32::MAX)"), + } + } +}