diff --git a/benches/distributions.rs b/benches/distributions.rs index 478619c5112..4e215e857fa 100644 --- a/benches/distributions.rs +++ b/benches/distributions.rs @@ -115,6 +115,11 @@ distr_int!(distr_binomial, u64, Binomial::new(20, 0.7)); distr_int!(distr_poisson, u64, Poisson::new(4.0)); distr!(distr_bernoulli, bool, Bernoulli::new(0.18)); +// Weighted +distr_int!(distr_weighted_i8, usize, WeightedIndex::new(&[1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); +distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); // construct and sample from a range macro_rules! gen_range_int { diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 9d093f1d509..888c514a290 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -73,6 +73,8 @@ //! numbers of the `char` type; in contrast [`Standard`] may sample any valid //! `char`. //! +//! [`WeightedIndex`] can be used to do weighted sampling from a set of items, +//! such as from an array. //! //! # Non-uniform probability distributions //! @@ -167,12 +169,15 @@ //! [`Uniform`]: struct.Uniform.html //! [`Uniform::new`]: struct.Uniform.html#method.new //! [`Uniform::new_inclusive`]: struct.Uniform.html#method.new_inclusive +//! [`WeightedIndex`]: struct.WeightedIndex.html use Rng; #[doc(inline)] pub use self::other::Alphanumeric; #[doc(inline)] pub use self::uniform::Uniform; #[doc(inline)] pub use self::float::{OpenClosed01, Open01}; +#[cfg(feature="alloc")] +#[doc(inline)] pub use self::weighted::WeightedIndex; #[cfg(feature="std")] #[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT}; #[cfg(feature="std")] @@ -192,6 +197,8 @@ use Rng; #[doc(inline)] pub use self::dirichlet::Dirichlet; pub mod uniform; +#[cfg(feature="alloc")] +#[doc(hidden)] pub mod weighted; #[cfg(feature="std")] #[doc(hidden)] pub mod gamma; #[cfg(feature="std")] @@ -372,6 +379,8 @@ pub struct Standard; /// A value with a particular weight for use with `WeightedChoice`. +#[deprecated(since="0.6.0", note="use WeightedIndex instead")] +#[allow(deprecated)] #[derive(Copy, Clone, Debug)] pub struct Weighted { /// The numerical weight of this item @@ -382,34 +391,18 @@ pub struct Weighted { /// A distribution that selects from a finite collection of weighted items. /// -/// Each item has an associated weight that influences how likely it -/// is to be chosen: higher weight is more likely. -/// -/// The `Clone` restriction is a limitation of the `Distribution` trait. -/// Note that `&T` is (cheaply) `Clone` for all `T`, as is `u32`, so one can -/// store references or indices into another vector. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Weighted, WeightedChoice, Distribution}; -/// -/// let mut items = vec!(Weighted { weight: 2, item: 'a' }, -/// Weighted { weight: 4, item: 'b' }, -/// Weighted { weight: 1, item: 'c' }); -/// let wc = WeightedChoice::new(&mut items); -/// let mut rng = rand::thread_rng(); -/// for _ in 0..16 { -/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice. -/// println!("{}", wc.sample(&mut rng)); -/// } -/// ``` +/// Deprecated: use [`WeightedIndex`] instead. +/// [`WeightedIndex`]: distributions/struct.WeightedIndex.html +#[deprecated(since="0.6.0", note="use WeightedIndex instead")] +#[allow(deprecated)] #[derive(Debug)] pub struct WeightedChoice<'a, T:'a> { items: &'a mut [Weighted], weight_range: Uniform, } +#[deprecated(since="0.6.0", note="use WeightedIndex instead")] +#[allow(deprecated)] impl<'a, T: Clone> WeightedChoice<'a, T> { /// Create a new `WeightedChoice`. /// @@ -447,6 +440,8 @@ impl<'a, T: Clone> WeightedChoice<'a, T> { } } +#[deprecated(since="0.6.0", note="use WeightedIndex instead")] +#[allow(deprecated)] impl<'a, T: Clone> Distribution for WeightedChoice<'a, T> { fn sample(&self, rng: &mut R) -> T { // we want to find the first element that has cumulative @@ -556,9 +551,11 @@ fn ziggurat( #[cfg(test)] mod tests { use rngs::mock::StepRng; + #[allow(deprecated)] use super::{WeightedChoice, Weighted, Distribution}; #[test] + #[allow(deprecated)] fn test_weighted_choice() { // this makes assumptions about the internal implementation of // WeightedChoice. It may fail when the implementation in @@ -618,6 +615,7 @@ mod tests { } #[test] + #[allow(deprecated)] fn test_weighted_clone_initialization() { let initial : Weighted = Weighted {weight: 1, item: 1}; let clone = initial.clone(); @@ -626,6 +624,7 @@ mod tests { } #[test] #[should_panic] + #[allow(deprecated)] fn test_weighted_clone_change_weight() { let initial : Weighted = Weighted {weight: 1, item: 1}; let mut clone = initial.clone(); @@ -634,6 +633,7 @@ mod tests { } #[test] #[should_panic] + #[allow(deprecated)] fn test_weighted_clone_change_item() { let initial : Weighted = Weighted {weight: 1, item: 1}; let mut clone = initial.clone(); @@ -643,15 +643,18 @@ mod tests { } #[test] #[should_panic] + #[allow(deprecated)] fn test_weighted_choice_no_items() { WeightedChoice::::new(&mut []); } #[test] #[should_panic] + #[allow(deprecated)] fn test_weighted_choice_zero_weight() { WeightedChoice::new(&mut [Weighted { weight: 0, item: 0}, Weighted { weight: 0, item: 1}]); } #[test] #[should_panic] + #[allow(deprecated)] fn test_weighted_choice_weight_overflows() { let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow WeightedChoice::new(&mut [Weighted { weight: x, item: 0 }, diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs new file mode 100644 index 00000000000..c77e485b95f --- /dev/null +++ b/src/distributions/weighted.rs @@ -0,0 +1,182 @@ +// Copyright 2017 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// https://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use Rng; +use distributions::Distribution; +use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; +use ::core::cmp::PartialOrd; +use ::{Error, ErrorKind}; + +// Note that this whole module is only imported if feature="alloc" is enabled. +#[cfg(not(feature="std"))] use alloc::Vec; + +/// A distribution using weighted sampling to pick an discretely selected item. +/// +/// When a `WeightedIndex` is sampled from, it returns the index +/// of a random element from the iterator used when the `WeightedIndex` was +/// created. The chance of a given element being picked is proportional to the +/// value of the element. The weights can use any type `X` for which an +/// implementaiton of [`Uniform`] exists. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::WeightedIndex; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let dist = WeightedIndex::new(&weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +#[derive(Debug, Clone)] +pub struct WeightedIndex { + cumulative_weights: Vec, + weight_distribution: X::Sampler, +} + +impl WeightedIndex { + /// Creates a new a `WeightedIndex` [`Distribution`] using the values + /// in `weights`. The weights can use any type `X` for which an + /// implementaiton of [`Uniform`] exists. + /// + /// Returns an error if the iterator is empty, or its total value is 0. + /// + /// # Panics + /// + /// If a value in the iterator is `< 0`. + /// + /// [`Distribution`]: trait.Distribution.html + /// [`Uniform`]: struct.Uniform.html + pub fn new(weights: I) -> Result, Error> + where I: IntoIterator, + I::Item: SampleBorrow, + X: for<'a> ::core::ops::AddAssign<&'a X> + + Clone + + Default { + let mut iter = weights.into_iter(); + let mut total_weight: X = iter.next() + .ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))? + .borrow() + .clone(); + + let zero = ::default(); + let weights = iter.map(|w| { + assert!(*w.borrow() >= zero, "Negative weight in WeightedIndex::new"); + let prev_weight = total_weight.clone(); + total_weight += w.borrow(); + prev_weight + }).collect::>(); + + if total_weight == zero { + return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new")); + } + let distr = X::Sampler::new(zero, total_weight); + + Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }) + } +} + +impl Distribution for WeightedIndex where + X: SampleUniform + PartialOrd { + fn sample(&self, rng: &mut R) -> usize { + let chosen_weight = self.weight_distribution.sample(rng); + // Invariants: indexes in range [start, end] (inclusive) are candidate indexes + // cumulative_weights[start-1] <= chosen_weight + // chosen_weight < cumulative_weights[end] + // The returned index is the first one whose value is >= chosen_weight + let mut start = 0usize; + let mut end = self.cumulative_weights.len(); + while start < end { + let mid = (start + end) / 2; + if chosen_weight >= * unsafe { self.cumulative_weights.get_unchecked(mid) } { + start = mid + 1; + } else { + end = mid; + } + } + debug_assert_eq!(start, end); + start + } +} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(feature="std")] + use core::panic::catch_unwind; + + #[test] + fn test_weightedindex() { + let mut r = ::test::rng(700); + const N_REPS: u32 = 5000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // WeightedIndex from vec + let mut chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from slice + chosen = [0i32; 14]; + let distr = WeightedIndex::new(&weights[..]).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from iterator + chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.iter()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + assert!(WeightedIndex::new(&[10][0..0]).is_err()); + assert!(WeightedIndex::new(&[0]).is_err()); + } + + #[test] + #[cfg(all(feature="std", + not(target_arch = "wasm32"), + not(target_arch = "asmjs")))] + fn test_weighted_assertions() { + assert!(catch_unwind(|| WeightedIndex::new(&[1, 2, 3])).is_ok()); + assert!(catch_unwind(|| WeightedIndex::new(&[10, -1, 10])).is_err()); + assert!(catch_unwind(|| WeightedIndex::new(&[1, -1])).is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index e75b72992bd..aa1123a49bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,10 +134,6 @@ //! //! For more slice/sequence related functionality, look in the [`seq` module]. //! -//! There is also [`distributions::WeightedChoice`], which can be used to pick -//! elements at random with some probability. But it does not work well at the -//! moment and is going through a redesign. -//! //! //! # Error handling //! @@ -187,7 +183,6 @@ //! //! //! [`distributions` module]: distributions/index.html -//! [`distributions::WeightedChoice`]: distributions/struct.WeightedChoice.html //! [`EntropyRng`]: rngs/struct.EntropyRng.html //! [`Error`]: struct.Error.html //! [`gen_range`]: trait.Rng.html#method.gen_range diff --git a/src/seq.rs b/src/seq.rs index 53476bf92ef..71ae416b988 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -22,6 +22,7 @@ use super::Rng; +#[cfg(feature="alloc")] use distributions::uniform::{SampleUniform, SampleBorrow}; /// Extension trait on slices, providing random mutation and sampling methods. /// @@ -91,6 +92,53 @@ pub trait SliceRandom { fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter where R: Rng + ?Sized; + /// Similar to [`choose`], where the likelihood of each outcome may be + /// specified. The specified function `weight` maps items `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1)]; + /// let mut rng = thread_rng(); + /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' + /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); + /// ``` + /// [`choose`]: trait.SliceRandom.html#method.choose + #[cfg(feature = "alloc")] + fn choose_weighted(&self, rng: &mut R, weight: F) -> Option<&Self::Item> + where R: Rng + ?Sized, + F: Fn(&Self::Item) -> B, + B: SampleBorrow, + X: SampleUniform + + for<'a> ::core::ops::AddAssign<&'a X> + + ::core::cmp::PartialOrd + + Clone + + Default; + + /// Similar to [`choose_mut`], where the likelihood of each outcome may be + /// specified. The specified function `weight` maps items `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// See also [`choose_weighted`]. + /// + /// [`choose_mut`]: trait.SliceRandom.html#method.choose_mut + /// [`choose_weighted`]: trait.SliceRandom.html#method.choose_weighted + #[cfg(feature = "alloc")] + fn choose_weighted_mut(&mut self, rng: &mut R, weight: F) -> Option<&mut Self::Item> + where R: Rng + ?Sized, + F: Fn(&Self::Item) -> B, + B: SampleBorrow, + X: SampleUniform + + for<'a> ::core::ops::AddAssign<&'a X> + + ::core::cmp::PartialOrd + + Clone + + Default; + /// Shuffle a mutable slice in place. /// /// Depending on the implementation, complexity is expected to be `O(1)`. @@ -141,6 +189,7 @@ pub trait IteratorRandom: Iterator + Sized { /// Complexity is `O(n)`, where `n` is the length of the iterator. /// This likely consumes multiple random numbers, but the exact number /// is unspecified. + /// /// [`choose`]: trait.SliceRandom.html#method.choose /// [`choose_mut`]: trait.SliceRandom.html#method.choose_mut fn choose(mut self, rng: &mut R) -> Option @@ -277,6 +326,39 @@ impl SliceRandom for [T] { } } + #[cfg(feature = "alloc")] + fn choose_weighted(&self, rng: &mut R, weight: F) -> Option<&Self::Item> + where R: Rng + ?Sized, + F: Fn(&Self::Item) -> B, + B: SampleBorrow, + X: SampleUniform + + for<'a> ::core::ops::AddAssign<&'a X> + + ::core::cmp::PartialOrd + + Clone + + Default { + use distributions::weighted::WeightedIndex; + use distributions::Distribution; + WeightedIndex::new(self.iter().map(weight)).ok() + .map(|distr| &self[distr.sample(rng)]) + } + + #[cfg(feature = "alloc")] + fn choose_weighted_mut(&mut self, rng: &mut R, weight: F) -> Option<&mut Self::Item> + where R: Rng + ?Sized, + F: Fn(&Self::Item) -> B, + B: SampleBorrow, + X: SampleUniform + + for<'a> ::core::ops::AddAssign<&'a X> + + ::core::cmp::PartialOrd + + Clone + + Default { + use distributions::weighted::WeightedIndex; + use distributions::Distribution; + WeightedIndex::new(self.iter().map(weight)).ok() + .map(|distr| distr.sample(rng)) + .map(move |ix| &mut self[ix]) + } + fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized { for i in (1..self.len()).rev() { @@ -738,4 +820,52 @@ mod test { } } } + + #[test] + #[cfg(feature = "alloc")] + fn test_weighted() { + let mut r = ::test::rng(406); + const N_REPS: u32 = 3000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // choose_weighted + fn get_weight(item: &(u32, T)) -> u32 { + item.0 + } + let mut chosen = [0i32; 14]; + let mut items = [(0u32, 0usize); 14]; // (weight, index) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], i); + } + for _ in 0..N_REPS { + let item = items.choose_weighted(&mut r, get_weight).unwrap(); + chosen[item.1] += 1; + } + verify(chosen); + + // choose_weighted_mut + let mut items = [(0u32, 0i32); 14]; // (weight, count) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], 0); + } + for _ in 0..N_REPS { + items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; + } + for (ch, item) in chosen.iter_mut().zip(items.iter()) { + *ch = item.1; + } + verify(chosen); + } }