Skip to content

Commit

Permalink
Implement WeightedIndex, SliceRandom::choose_weighted and SliceRandom…
Browse files Browse the repository at this point in the history
…::choose_weighted_mut
  • Loading branch information
sicking committed Jun 27, 2018
1 parent af1303c commit b1169d4
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 28 deletions.
44 changes: 22 additions & 22 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ use Rng;
#[doc(inline)] pub use self::other::Alphanumeric;
#[doc(inline)] pub use self::uniform::Uniform;
#[doc(inline)] pub use self::float::{OpenClosed01, Open01};
#[cfg(feature="alloc")]
#[doc(inline)] pub use self::weighted::WeightedIndex;
#[cfg(feature="std")]
#[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
#[cfg(feature="std")]
Expand All @@ -192,6 +194,8 @@ use Rng;
#[doc(inline)] pub use self::dirichlet::Dirichlet;

pub mod uniform;
#[cfg(feature="alloc")]
#[doc(hidden)] pub mod weighted;
#[cfg(feature="std")]
#[doc(hidden)] pub mod gamma;
#[cfg(feature="std")]
Expand Down Expand Up @@ -372,6 +376,8 @@ pub struct Standard;


/// A value with a particular weight for use with `WeightedChoice`.
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
#[derive(Copy, Clone, Debug)]
pub struct Weighted<T> {
/// The numerical weight of this item
Expand All @@ -382,34 +388,18 @@ pub struct Weighted<T> {

/// A distribution that selects from a finite collection of weighted items.
///
/// Each item has an associated weight that influences how likely it
/// is to be chosen: higher weight is more likely.
///
/// The `Clone` restriction is a limitation of the `Distribution` trait.
/// Note that `&T` is (cheaply) `Clone` for all `T`, as is `u32`, so one can
/// store references or indices into another vector.
///
/// # Example
///
/// ```
/// use rand::distributions::{Weighted, WeightedChoice, Distribution};
///
/// let mut items = vec!(Weighted { weight: 2, item: 'a' },
/// Weighted { weight: 4, item: 'b' },
/// Weighted { weight: 1, item: 'c' });
/// let wc = WeightedChoice::new(&mut items);
/// let mut rng = rand::thread_rng();
/// for _ in 0..16 {
/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
/// println!("{}", wc.sample(&mut rng));
/// }
/// ```
/// Deprecated: use [`WeightedIndex`] instead.
/// [`WeightedIndex`]: distributions/struct.WeightedIndex.html
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
#[derive(Debug)]
pub struct WeightedChoice<'a, T:'a> {
items: &'a mut [Weighted<T>],
weight_range: Uniform<u32>,
}

#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
impl<'a, T: Clone> WeightedChoice<'a, T> {
/// Create a new `WeightedChoice`.
///
Expand Down Expand Up @@ -447,6 +437,8 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
}
}

#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
// we want to find the first element that has cumulative
Expand Down Expand Up @@ -556,9 +548,11 @@ fn ziggurat<R: Rng + ?Sized, P, Z>(
#[cfg(test)]
mod tests {
use rngs::mock::StepRng;
#[allow(deprecated)]
use super::{WeightedChoice, Weighted, Distribution};

#[test]
#[allow(deprecated)]
fn test_weighted_choice() {
// this makes assumptions about the internal implementation of
// WeightedChoice. It may fail when the implementation in
Expand Down Expand Up @@ -618,6 +612,7 @@ mod tests {
}

#[test]
#[allow(deprecated)]
fn test_weighted_clone_initialization() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let clone = initial.clone();
Expand All @@ -626,6 +621,7 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_clone_change_weight() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let mut clone = initial.clone();
Expand All @@ -634,6 +630,7 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_clone_change_item() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let mut clone = initial.clone();
Expand All @@ -643,15 +640,18 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_no_items() {
WeightedChoice::<isize>::new(&mut []);
}
#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_zero_weight() {
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
Weighted { weight: 0, item: 1}]);
}
#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_weight_overflows() {
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
Expand Down
179 changes: 179 additions & 0 deletions src/distributions/weighted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! A distribution using weighted sampling to pick an discretely selected item.
//!
//! When a `WeightedIndex` is sampled from, it returns the index
//! of a random element from the iterator used when the `WeightedIndex` was
//! created. The chance of a given element being picked is proportional to the
//! value of the element. The weights can use any type `X` for which an
//! implementaiton of [`Uniform<X>`] exists.
//!
//! # Example
//!
//! ```
//! use rand::prelude::*;
//! use rand::distributions::WeightedIndex;
//!
//! let choices = ['a', 'b', 'c'];
//! let weights = [2, 1, 1];
//! let dist = WeightedIndex::new(&weights).unwrap();
//! let mut rng = thread_rng();
//! for _ in 0..100 {
//! // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
//! println!("{}", choices[dist.sample(&mut rng)]);
//! }
//!
//! let items = [('a', 0), ('b', 3), ('c', 7)];
//! let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
//! for _ in 0..100 {
//! // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
//! println!("{}", items[dist2.sample(&mut rng)].0);
//! }
//! ```

use Rng;
use distributions::Distribution;
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
use ::core::cmp::PartialOrd;
use ::{Error, ErrorKind};

#[cfg(feature = "alloc")]
#[derive(Debug, Clone)]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
cumulative_weights: Vec<X>,
weight_distribution: X::Sampler,
}

impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
/// Creates a new a [`WeightedIndex`] [`Distribution`] using the values
/// in `weights`. The weights can use any type `X` for which an
/// implementaiton of [`Uniform<X>`] exists.
///
/// Returns an error if the iterator is empty, or its total value is 0.
///
/// # Panics
///
/// If a value in the iterator is `< 0`.
///
/// [`WeightedIndex`]: struct.WeightedIndex.html
/// [`Distribution`]: trait.Distribution.html
/// [`Uniform<X>`]: struct.Uniform.html
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
where I: IntoIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> +
Clone +
Default {
let mut iter = weights.into_iter();
let mut total_weight: X = iter.next()
.ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))?
.borrow()
.clone();

let zero = <X as Default>::default();
let weights = iter.map(|w| {
assert!(*w.borrow() >= zero, "Negative weight in WeightedIndex::new");
let prev_weight = total_weight.clone();
total_weight += w.borrow();
prev_weight
}).collect::<Vec<X>>();

if total_weight == zero {
return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new"));
}
let distr = X::Sampler::new(zero, total_weight);

Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
}
}

impl<X> Distribution<usize> for WeightedIndex<X> where
X: SampleUniform + PartialOrd {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
let chosen_weight = self.weight_distribution.sample(rng);
// Invariants: indexes in range [start, end] (inclusive) are candidate indexes
// cumulative_weights[start-1] <= chosen_weight
// chosen_weight < cumulative_weights[end]
// The returned index is the first one whose value is >= chosen_weight
let mut start = 0usize;
let mut end = self.cumulative_weights.len();
while start < end {
let mid = (start + end) / 2;
if chosen_weight >= * unsafe { self.cumulative_weights.get_unchecked(mid) } {
start = mid + 1;
} else {
end = mid;
}
}
debug_assert_eq!(start, end);
start
}
}

#[cfg(test)]
mod test {
use super::*;
#[cfg(feature="std")]
use core::panic::catch_unwind;

#[test]
fn test_weightedindex() {
let mut r = ::test::rng(700);
const N_REPS: u32 = 5000;
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
let total_weight = weights.iter().sum::<u32>() as f32;

let verify = |result: [i32; 14]| {
for (i, count) in result.iter().enumerate() {
let exp = (weights[i] * N_REPS) as f32 / total_weight;
let mut err = (*count as f32 - exp).abs();
if err != 0.0 {
err /= exp;
}
assert!(err <= 0.25);
}
};

// WeightedIndex from vec
let mut chosen = [0i32; 14];
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);

// WeightedIndex from slice
chosen = [0i32; 14];
let distr = WeightedIndex::new(&weights[..]).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);

// WeightedIndex from iterator
chosen = [0i32; 14];
let distr = WeightedIndex::new(weights.iter()).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);
}

#[test]
#[cfg(all(feature="std",
not(target_arch = "wasm32"),
not(target_arch = "asmjs")))]
fn test_weighted_assertions() {
assert!(catch_unwind(|| WeightedIndex::new(&[1, 2, 3])).is_ok());
assert!(catch_unwind(|| WeightedIndex::new(&[10, -1, 10])).is_err());
assert!(catch_unwind(|| WeightedIndex::new(&[1, -1])).is_err());
}
}
5 changes: 0 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,6 @@
//!
//! For more slice/sequence related functionality, look in the [`seq` module].
//!
//! There is also [`distributions::WeightedChoice`], which can be used to pick
//! elements at random with some probability. But it does not work well at the
//! moment and is going through a redesign.
//!
//!
//! # Error handling
//!
Expand Down Expand Up @@ -187,7 +183,6 @@
//!
//!
//! [`distributions` module]: distributions/index.html
//! [`distributions::WeightedChoice`]: distributions/struct.WeightedChoice.html
//! [`EntropyRng`]: rngs/struct.EntropyRng.html
//! [`Error`]: struct.Error.html
//! [`gen_range`]: trait.Rng.html#method.gen_range
Expand Down
Loading

0 comments on commit b1169d4

Please sign in to comment.