Skip to content

Commit

Permalink
Implement WeightedIndex, SliceRandom::choose_weighted and SliceRandom…
Browse files Browse the repository at this point in the history
…::choose_weighted_mut
  • Loading branch information
sicking committed Jun 27, 2018
1 parent af1303c commit 261d7b0
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 28 deletions.
5 changes: 5 additions & 0 deletions benches/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ distr_int!(distr_binomial, u64, Binomial::new(20, 0.7));
distr_int!(distr_poisson, u64, Poisson::new(4.0));
distr!(distr_bernoulli, bool, Bernoulli::new(0.18));

// Weighted
distr_int!(distr_weighted_i8, usize, WeightedIndex::new(&[1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());

// construct and sample from a range
macro_rules! gen_range_int {
Expand Down
47 changes: 25 additions & 22 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
//! numbers of the `char` type; in contrast [`Standard`] may sample any valid
//! `char`.
//!
//! [`WeightedIndex`] can be used to do weighted sampling from a set of items,
//! such as from an array.
//!
//! # Non-uniform probability distributions
//!
Expand Down Expand Up @@ -167,12 +169,15 @@
//! [`Uniform`]: struct.Uniform.html
//! [`Uniform::new`]: struct.Uniform.html#method.new
//! [`Uniform::new_inclusive`]: struct.Uniform.html#method.new_inclusive
//! [`WeightedIndex`]: struct.WeightedIndex.html

use Rng;

#[doc(inline)] pub use self::other::Alphanumeric;
#[doc(inline)] pub use self::uniform::Uniform;
#[doc(inline)] pub use self::float::{OpenClosed01, Open01};
#[cfg(feature="alloc")]
#[doc(inline)] pub use self::weighted::WeightedIndex;
#[cfg(feature="std")]
#[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
#[cfg(feature="std")]
Expand All @@ -192,6 +197,8 @@ use Rng;
#[doc(inline)] pub use self::dirichlet::Dirichlet;

pub mod uniform;
#[cfg(feature="alloc")]
#[doc(hidden)] pub mod weighted;
#[cfg(feature="std")]
#[doc(hidden)] pub mod gamma;
#[cfg(feature="std")]
Expand Down Expand Up @@ -372,6 +379,8 @@ pub struct Standard;


/// A value with a particular weight for use with `WeightedChoice`.
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
#[derive(Copy, Clone, Debug)]
pub struct Weighted<T> {
/// The numerical weight of this item
Expand All @@ -382,34 +391,18 @@ pub struct Weighted<T> {

/// A distribution that selects from a finite collection of weighted items.
///
/// Each item has an associated weight that influences how likely it
/// is to be chosen: higher weight is more likely.
///
/// The `Clone` restriction is a limitation of the `Distribution` trait.
/// Note that `&T` is (cheaply) `Clone` for all `T`, as is `u32`, so one can
/// store references or indices into another vector.
///
/// # Example
///
/// ```
/// use rand::distributions::{Weighted, WeightedChoice, Distribution};
///
/// let mut items = vec!(Weighted { weight: 2, item: 'a' },
/// Weighted { weight: 4, item: 'b' },
/// Weighted { weight: 1, item: 'c' });
/// let wc = WeightedChoice::new(&mut items);
/// let mut rng = rand::thread_rng();
/// for _ in 0..16 {
/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
/// println!("{}", wc.sample(&mut rng));
/// }
/// ```
/// Deprecated: use [`WeightedIndex`] instead.
/// [`WeightedIndex`]: distributions/struct.WeightedIndex.html
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
#[derive(Debug)]
pub struct WeightedChoice<'a, T:'a> {
items: &'a mut [Weighted<T>],
weight_range: Uniform<u32>,
}

#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
impl<'a, T: Clone> WeightedChoice<'a, T> {
/// Create a new `WeightedChoice`.
///
Expand Down Expand Up @@ -447,6 +440,8 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
}
}

#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
// we want to find the first element that has cumulative
Expand Down Expand Up @@ -556,9 +551,11 @@ fn ziggurat<R: Rng + ?Sized, P, Z>(
#[cfg(test)]
mod tests {
use rngs::mock::StepRng;
#[allow(deprecated)]
use super::{WeightedChoice, Weighted, Distribution};

#[test]
#[allow(deprecated)]
fn test_weighted_choice() {
// this makes assumptions about the internal implementation of
// WeightedChoice. It may fail when the implementation in
Expand Down Expand Up @@ -618,6 +615,7 @@ mod tests {
}

#[test]
#[allow(deprecated)]
fn test_weighted_clone_initialization() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let clone = initial.clone();
Expand All @@ -626,6 +624,7 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_clone_change_weight() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let mut clone = initial.clone();
Expand All @@ -634,6 +633,7 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_clone_change_item() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let mut clone = initial.clone();
Expand All @@ -643,15 +643,18 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_no_items() {
WeightedChoice::<isize>::new(&mut []);
}
#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_zero_weight() {
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
Weighted { weight: 0, item: 1}]);
}
#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_weight_overflows() {
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
Expand Down
182 changes: 182 additions & 0 deletions src/distributions/weighted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use Rng;
use distributions::Distribution;
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
use ::core::cmp::PartialOrd;
use ::{Error, ErrorKind};

// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature="std"))] use alloc::Vec;

/// A distribution using weighted sampling to pick an discretely selected item.
///
/// When a `WeightedIndex` is sampled from, it returns the index
/// of a random element from the iterator used when the `WeightedIndex` was
/// created. The chance of a given element being picked is proportional to the
/// value of the element. The weights can use any type `X` for which an
/// implementaiton of [`Uniform<X>`] exists.
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand::distributions::WeightedIndex;
///
/// let choices = ['a', 'b', 'c'];
/// let weights = [2, 1, 1];
/// let dist = WeightedIndex::new(&weights).unwrap();
/// let mut rng = thread_rng();
/// for _ in 0..100 {
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
/// println!("{}", choices[dist.sample(&mut rng)]);
/// }
///
/// let items = [('a', 0), ('b', 3), ('c', 7)];
/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
/// for _ in 0..100 {
/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
/// println!("{}", items[dist2.sample(&mut rng)].0);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
cumulative_weights: Vec<X>,
weight_distribution: X::Sampler,
}

impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
/// Creates a new a `WeightedIndex` [`Distribution`] using the values
/// in `weights`. The weights can use any type `X` for which an
/// implementaiton of [`Uniform<X>`] exists.
///
/// Returns an error if the iterator is empty, or its total value is 0.
///
/// # Panics
///
/// If a value in the iterator is `< 0`.
///
/// [`Distribution`]: trait.Distribution.html
/// [`Uniform<X>`]: struct.Uniform.html
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
where I: IntoIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> +
Clone +
Default {
let mut iter = weights.into_iter();
let mut total_weight: X = iter.next()
.ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))?
.borrow()
.clone();

let zero = <X as Default>::default();
let weights = iter.map(|w| {
assert!(*w.borrow() >= zero, "Negative weight in WeightedIndex::new");
let prev_weight = total_weight.clone();
total_weight += w.borrow();
prev_weight
}).collect::<Vec<X>>();

if total_weight == zero {
return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new"));
}
let distr = X::Sampler::new(zero, total_weight);

Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
}
}

impl<X> Distribution<usize> for WeightedIndex<X> where
X: SampleUniform + PartialOrd {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
let chosen_weight = self.weight_distribution.sample(rng);
// Invariants: indexes in range [start, end] (inclusive) are candidate indexes
// cumulative_weights[start-1] <= chosen_weight
// chosen_weight < cumulative_weights[end]
// The returned index is the first one whose value is >= chosen_weight
let mut start = 0usize;
let mut end = self.cumulative_weights.len();
while start < end {
let mid = (start + end) / 2;
if chosen_weight >= * unsafe { self.cumulative_weights.get_unchecked(mid) } {
start = mid + 1;
} else {
end = mid;
}
}
debug_assert_eq!(start, end);
start
}
}

#[cfg(test)]
mod test {
use super::*;
#[cfg(feature="std")]
use core::panic::catch_unwind;

#[test]
fn test_weightedindex() {
let mut r = ::test::rng(700);
const N_REPS: u32 = 5000;
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
let total_weight = weights.iter().sum::<u32>() as f32;

let verify = |result: [i32; 14]| {
for (i, count) in result.iter().enumerate() {
let exp = (weights[i] * N_REPS) as f32 / total_weight;
let mut err = (*count as f32 - exp).abs();
if err != 0.0 {
err /= exp;
}
assert!(err <= 0.25);
}
};

// WeightedIndex from vec
let mut chosen = [0i32; 14];
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);

// WeightedIndex from slice
chosen = [0i32; 14];
let distr = WeightedIndex::new(&weights[..]).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);

// WeightedIndex from iterator
chosen = [0i32; 14];
let distr = WeightedIndex::new(weights.iter()).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);

assert!(WeightedIndex::new(&[10][0..0]).is_err());
assert!(WeightedIndex::new(&[0]).is_err());
}

#[test]
#[cfg(all(feature="std",
not(target_arch = "wasm32"),
not(target_arch = "asmjs")))]
fn test_weighted_assertions() {
assert!(catch_unwind(|| WeightedIndex::new(&[1, 2, 3])).is_ok());
assert!(catch_unwind(|| WeightedIndex::new(&[10, -1, 10])).is_err());
assert!(catch_unwind(|| WeightedIndex::new(&[1, -1])).is_err());
}
}
5 changes: 0 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,6 @@
//!
//! For more slice/sequence related functionality, look in the [`seq` module].
//!
//! There is also [`distributions::WeightedChoice`], which can be used to pick
//! elements at random with some probability. But it does not work well at the
//! moment and is going through a redesign.
//!
//!
//! # Error handling
//!
Expand Down Expand Up @@ -187,7 +183,6 @@
//!
//!
//! [`distributions` module]: distributions/index.html
//! [`distributions::WeightedChoice`]: distributions/struct.WeightedChoice.html
//! [`EntropyRng`]: rngs/struct.EntropyRng.html
//! [`Error`]: struct.Error.html
//! [`gen_range`]: trait.Rng.html#method.gen_range
Expand Down
Loading

0 comments on commit 261d7b0

Please sign in to comment.