Skip to content

Commit

Permalink
Add WeightedError
Browse files Browse the repository at this point in the history
  • Loading branch information
sicking committed Jul 12, 2018
1 parent 40d8c39 commit 56a339d
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ use Rng;
#[doc(inline)] pub use self::uniform::Uniform;
#[doc(inline)] pub use self::float::{OpenClosed01, Open01};
#[cfg(feature="alloc")]
#[doc(inline)] pub use self::weighted::WeightedIndex;
#[doc(inline)] pub use self::weighted::{WeightedIndex, WeightedError};
#[cfg(feature="std")]
#[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
#[cfg(feature="std")]
Expand Down
48 changes: 37 additions & 11 deletions src/distributions/weighted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use Rng;
use distributions::Distribution;
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
use ::core::cmp::PartialOrd;
use ::{Error, ErrorKind};
use core::fmt;

// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature="std"))] use alloc::vec::Vec;
Expand Down Expand Up @@ -63,34 +63,34 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
///
/// [`Distribution`]: trait.Distribution.html
/// [`Uniform<X>`]: struct.Uniform.html
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
where I: IntoIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> +
Clone +
Default {
let mut iter = weights.into_iter();
let mut total_weight: X = iter.next()
.ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))?
.ok_or(WeightedError::NoItem)?
.borrow()
.clone();

let zero = <X as Default>::default();
if total_weight < zero {
return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new"));
return Err(WeightedError::NegativeWeight);
}

let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
for w in iter {
if *w.borrow() < zero {
return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new"));
return Err(WeightedError::NegativeWeight);
}
weights.push(total_weight.clone());
total_weight += w.borrow();
}

if total_weight == zero {
return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new"));
return Err(WeightedError::AllWeightsZero);
}
let distr = X::Sampler::new(zero, total_weight);

Expand Down Expand Up @@ -161,10 +161,36 @@ mod test {
assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4);
}

assert!(WeightedIndex::new(&[10][0..0]).is_err());
assert!(WeightedIndex::new(&[0]).is_err());
assert!(WeightedIndex::new(&[10, 20, -1, 30]).is_err());
assert!(WeightedIndex::new(&[-10, 20, 1, 30]).is_err());
assert!(WeightedIndex::new(&[-10]).is_err());
assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem);
assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero);
assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight);
assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight);
assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight);
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightedError {
NoItem,
NegativeWeight,
AllWeightsZero,
}

#[cfg(feature="std")]
impl ::std::error::Error for WeightedError {
fn description(&self) -> &str {
match *self {
WeightedError::NoItem => "No items found",
WeightedError::NegativeWeight => "Item has negative weight",
WeightedError::AllWeightsZero => "All items had weight zero",
}
}
fn cause(&self) -> Option<&::std::error::Error> { None }
}

impl fmt::Display for WeightedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use ::std::error::Error;
write!(f, "{}", self.description())
}
}
27 changes: 15 additions & 12 deletions src/seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#[cfg(feature="std")] use std::collections::HashMap;
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeMap;

#[cfg(feature = "alloc")] use distributions::WeightedError;

use super::Rng;
#[cfg(feature="alloc")] use distributions::uniform::{SampleUniform, SampleBorrow};

Expand Down Expand Up @@ -109,7 +111,7 @@ pub trait SliceRandom {
/// ```
/// [`choose`]: trait.SliceRandom.html#method.choose
#[cfg(feature = "alloc")]
fn choose_weighted<R, F, B, X>(&self, rng: &mut R, weight: F) -> Option<&Self::Item>
fn choose_weighted<R, F, B, X>(&self, rng: &mut R, weight: F) -> Result<&Self::Item, WeightedError>
where R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
Expand All @@ -129,7 +131,7 @@ pub trait SliceRandom {
/// [`choose_mut`]: trait.SliceRandom.html#method.choose_mut
/// [`choose_weighted`]: trait.SliceRandom.html#method.choose_weighted
#[cfg(feature = "alloc")]
fn choose_weighted_mut<R, F, B, X>(&mut self, rng: &mut R, weight: F) -> Option<&mut Self::Item>
fn choose_weighted_mut<R, F, B, X>(&mut self, rng: &mut R, weight: F) -> Result<&mut Self::Item, WeightedError>
where R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
Expand Down Expand Up @@ -327,7 +329,7 @@ impl<T> SliceRandom for [T] {
}

#[cfg(feature = "alloc")]
fn choose_weighted<R, F, B, X>(&self, rng: &mut R, weight: F) -> Option<&Self::Item>
fn choose_weighted<R, F, B, X>(&self, rng: &mut R, weight: F) -> Result<&Self::Item, WeightedError>
where R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
Expand All @@ -337,12 +339,12 @@ impl<T> SliceRandom for [T] {
Clone +
Default {
use distributions::{Distribution, WeightedIndex};
WeightedIndex::new(self.iter().map(weight)).ok()
.map(|distr| &self[distr.sample(rng)])
let distr = WeightedIndex::new(self.iter().map(weight))?;
Ok(&self[distr.sample(rng)])
}

#[cfg(feature = "alloc")]
fn choose_weighted_mut<R, F, B, X>(&mut self, rng: &mut R, weight: F) -> Option<&mut Self::Item>
fn choose_weighted_mut<R, F, B, X>(&mut self, rng: &mut R, weight: F) -> Result<&mut Self::Item, WeightedError>
where R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
Expand All @@ -352,9 +354,8 @@ impl<T> SliceRandom for [T] {
Clone +
Default {
use distributions::{Distribution, WeightedIndex};
WeightedIndex::new(self.iter().map(weight)).ok()
.map(|distr| distr.sample(rng))
.map(move |ix| &mut self[ix])
let distr = WeightedIndex::new(self.iter().map(weight))?;
Ok(&mut self[distr.sample(rng)])
}

fn shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized
Expand Down Expand Up @@ -868,8 +869,10 @@ mod test {

// Check error cases
let empty_slice = &mut [10][0..0];
assert_eq!(empty_slice.choose_weighted(&mut r, |_| 1), None);
assert_eq!(empty_slice.choose_weighted_mut(&mut r, |_| 1), None);
assert_eq!(['x'].choose_weighted_mut(&mut r, |_| 0), None);
assert_eq!(empty_slice.choose_weighted(&mut r, |_| 1), Err(WeightedError::NoItem));
assert_eq!(empty_slice.choose_weighted_mut(&mut r, |_| 1), Err(WeightedError::NoItem));
assert_eq!(['x'].choose_weighted_mut(&mut r, |_| 0), Err(WeightedError::AllWeightsZero));
assert_eq!([0, -1].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight));
assert_eq!([-1, 0].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight));
}
}

0 comments on commit 56a339d

Please sign in to comment.