From 1c9c65a1406bd520734abd88f8ccd8643c7a31d8 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 26 May 2018 18:42:34 +0100 Subject: [PATCH 01/15] Sequence functionality: API revisions --- src/seq.rs | 154 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/src/seq.rs b/src/seq.rs index 68f7ab08edc..b8ecf493c23 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -9,6 +9,8 @@ // except according to those terms. //! Functions for randomly accessing and sampling sequences. +//! +//! TODO: module doc use super::Rng; @@ -19,6 +21,158 @@ use super::Rng; #[cfg(not(feature="std"))] use alloc::Vec; + +/// Extension trait on slices, providing random mutation and sampling methods. +/// +/// An implementation is provided for slices. This may also be implementable for +/// other types. +pub trait SliceExt { + /// The element type. + type Item; + + /// Returns a reference to one random element of the slice, or `None` if the + /// slice is empty. + /// + /// Depending on the implementation, complexity is expected to be `O(1)`. + fn choose(&self, rng: &mut R) -> Option<&Self::Item> + where R: Rng + ?Sized; + + /// Returns a mutable reference to one random element of the slice, or + /// `None` if the slice is empty. + /// + /// Depending on the implementation, complexity is expected to be `O(1)`. + fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> + where R: Rng + ?Sized; + + /// Produces an iterator that chooses `amount` elements from the slice at + /// random without repeating any. + /// + /// In case this API is not sufficiently flexible, use `sample_indices` then + /// apply the indices to the slice. + /// + /// Although the elements are selected randomly, the order of returned + /// elements is neither stable nor fully random. If random ordering is + /// desired, either use `partial_shuffle` or use this method and shuffle + /// the result. 
If stable order is desired, use `sample_indices`, sort the + /// result, then apply to the slice. + /// + /// Complexity is expected to be the same as `sample_indices`. + fn choose_multiple(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item> + where R: Rng + ?Sized; + + /// Shuffle a slice in place. + /// + /// Depending on the implementation, complexity is expected to be `O(1)`. + fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized; + + /// Shuffle a slice in place, but exit early. + /// + /// Returns two mutable slices from the source slice. The first contains + /// `amount` elements randomly permuted. The second has the remaining + /// elements that are not fully shuffled. + /// + /// This is an efficient method to select `amount` elements at random from + /// the slice, provided the slice may be mutated. + /// + /// If you only need to choose elements randomly and `amount > self.len()/2` + /// then you may improve performance by taking + /// `amount = values.len() - amount` and using only the second slice. + /// + /// If `amount` is greater than the number of elements in the slice, this + /// will perform a full shuffle. + /// + /// Depending on the implementation, complexity is expected to be `O(m)`, + /// where `m = amount`. + fn partial_shuffle(&mut self, rng: &mut R, amount: usize) + -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized; +} + +/// Extension trait on iterators, providing random sampling methods. +pub trait IteratorExt: Iterator + Sized { + /// Choose one element at random from the iterator. + /// + /// Returns `None` if and only if the iterator is empty. + /// + /// Complexity is `O(n)`, where `n` is the length of the iterator. + /// This likely consumes multiple random numbers, but the exact number + /// is unspecified. + fn choose(self, rng: &mut R) -> Option + where R: Rng + ?Sized + { + unimplemented!() + } + + /// Collects `amount` values at random from the iterator into a supplied + /// buffer. 
+ /// + /// Returns the number of elements added to the buffer. This equals `amount` + /// unless the iterator contains insufficient elements, in which case this + /// equals the number of elements available. + /// + /// Complexity is TODO + fn choose_multiple_fill(self, rng: &mut R, amount: usize) -> usize + where R: Rng + ?Sized + { + unimplemented!() + } + + /// Collects `amount` values at random from the iterator into a vector. + /// + /// This is a convenience wrapper around `choose_multiple_fill`. + /// + /// The length of the returned vector equals `amount` unless the iterator + /// contains insufficient elements, in which case it equals the number of + /// elements available. + /// + /// Complexity is TODO + #[cfg(any(feature="std", feature="alloc"))] + fn choose_multiple(self, rng: &mut R, amount: usize) -> Vec + where R: Rng + ?Sized + { + // Note: I think this must use unsafe to create an uninitialised buffer, then restrict length + unimplemented!() + } +} + + +impl SliceExt for [T] { + type Item = T; + + fn choose(&self, rng: &mut R) -> Option<&Self::Item> + where R: Rng + ?Sized + { + unimplemented!() + } + + fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> + where R: Rng + ?Sized + { + unimplemented!() + } + + fn choose_multiple(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item> + where R: Rng + ?Sized + { + unimplemented!() + } + + fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized + { + unimplemented!() + } + + fn partial_shuffle(&mut self, rng: &mut R, amount: usize) + -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized + { + unimplemented!() + } +} + +// ——— +// TODO: remove below methods once implemented above +// TODO: also revise signature of `sample_indices`? +// ——— + /// Randomly sample `amount` elements from a finite iterator. 
/// /// The following can be returned: From e88075ece65a5de41401c3e611681404a6c24601 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sun, 27 May 2018 12:28:20 +0100 Subject: [PATCH 02/15] seq: enable module without std --- src/lib.rs | 2 +- src/seq.rs | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6d6d762b3cb..3458ba3207b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -268,7 +268,7 @@ pub mod distributions; pub mod prelude; pub mod prng; pub mod rngs; -#[cfg(feature = "alloc")] pub mod seq; +pub mod seq; //////////////////////////////////////////////////////////////////////////////// // Compatibility re-exports. Documentation is hidden; will be removed eventually. diff --git a/src/seq.rs b/src/seq.rs index b8ecf493c23..b3172dcc4be 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -14,12 +14,11 @@ use super::Rng; -// This crate is only enabled when either std or alloc is available. // BTreeMap is not as fast in tests, but better than nothing. #[cfg(feature="std")] use std::collections::HashMap; -#[cfg(not(feature="std"))] use alloc::btree_map::BTreeMap; +#[cfg(all(feature="alloc", not(feature="std")))] use alloc::btree_map::BTreeMap; -#[cfg(not(feature="std"))] use alloc::Vec; +#[cfg(all(feature="alloc", not(feature="std")))] use alloc::Vec; /// Extension trait on slices, providing random mutation and sampling methods. @@ -57,6 +56,7 @@ pub trait SliceExt { /// result, then apply to the slice. /// /// Complexity is expected to be the same as `sample_indices`. + #[cfg(feature = "alloc")] fn choose_multiple(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item> where R: Rng + ?Sized; @@ -125,7 +125,7 @@ pub trait IteratorExt: Iterator + Sized { /// elements available. 
/// /// Complexity is TODO - #[cfg(any(feature="std", feature="alloc"))] + #[cfg(feature = "alloc")] fn choose_multiple(self, rng: &mut R, amount: usize) -> Vec where R: Rng + ?Sized { @@ -150,6 +150,7 @@ impl SliceExt for [T] { unimplemented!() } + #[cfg(feature = "alloc")] fn choose_multiple(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item> where R: Rng + ?Sized { @@ -193,6 +194,7 @@ impl SliceExt for [T] { /// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap(); /// println!("{:?}", sample); /// ``` +#[cfg(feature = "alloc")] pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result, Vec> where I: IntoIterator, R: Rng + ?Sized, @@ -238,6 +240,7 @@ pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result(rng: &mut R, slice: &[T], amount: usize) -> Vec where R: Rng + ?Sized, T: Clone @@ -266,6 +269,7 @@ pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec /// let values = vec![5, 6, 1, 3, 4, 6, 7]; /// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3)); /// ``` +#[cfg(feature = "alloc")] pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng + ?Sized { @@ -286,6 +290,7 @@ pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> /// have the indices themselves so this is provided as an alternative. /// /// Panics if `amount > length` +#[cfg(feature = "alloc")] pub fn sample_indices(rng: &mut R, length: usize, amount: usize) -> Vec where R: Rng + ?Sized, { @@ -319,6 +324,7 @@ pub fn sample_indices(rng: &mut R, length: usize, amount: usize) -> Vec= length / 2` /// since it does not require allocating an extra cache and is much faster. 
+#[cfg(feature = "alloc")] fn sample_indices_inplace(rng: &mut R, length: usize, amount: usize) -> Vec where R: Rng + ?Sized, { @@ -340,6 +346,7 @@ fn sample_indices_inplace(rng: &mut R, length: usize, amount: usize) -> Vec( rng: &mut R, length: usize, @@ -379,11 +386,13 @@ fn sample_indices_cache( #[cfg(test)] mod test { use super::*; + #[cfg(feature = "alloc")] use {XorShiftRng, Rng, SeedableRng}; - #[cfg(not(feature="std"))] + #[cfg(all(feature="alloc", not(feature="std")))] use alloc::Vec; #[test] + #[cfg(feature = "alloc")] fn test_sample_iter() { let min_val = 1; let max_val = 100; @@ -403,6 +412,7 @@ mod test { })); } #[test] + #[cfg(feature = "alloc")] fn test_sample_slice_boundaries() { let empty: &[u8] = &[]; @@ -447,6 +457,7 @@ mod test { } #[test] + #[cfg(feature = "alloc")] fn test_sample_slice() { let xor_rng = XorShiftRng::from_seed; From 8a07bb30d0f328869bddcc86e8e832e4a229ae88 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sun, 27 May 2018 12:22:15 +0100 Subject: [PATCH 03/15] Move choose, choose_mut and shuffle from Rng to SliceExt --- src/lib.rs | 102 ++++++++------------------------------------- src/rngs/thread.rs | 4 -- src/seq.rs | 77 ++++++++++++++++++++++++++++++++-- 3 files changed, 91 insertions(+), 92 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3458ba3207b..5156cbe89e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -563,64 +563,35 @@ pub trait Rng: RngCore { /// Return a random element from `values`. /// - /// Return `None` if `values` is empty. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let choices = [1, 2, 4, 8, 16, 32]; - /// let mut rng = thread_rng(); - /// println!("{:?}", rng.choose(&choices)); - /// assert_eq!(rng.choose(&choices[..0]), None); - /// ``` + /// Deprecated: use [`SliceExt::choose`] instead. 
+ /// + /// [`SliceExt::choose`]: seq/trait.SliceExt.html#method.choose + #[deprecated(since="0.6.0", note="use SliceExt::choose instead")] fn choose<'a, T>(&mut self, values: &'a [T]) -> Option<&'a T> { - if values.is_empty() { - None - } else { - Some(&values[self.gen_range(0, values.len())]) - } + use seq::SliceExt; + values.choose(self) } /// Return a mutable pointer to a random element from `values`. /// - /// Return `None` if `values` is empty. + /// Deprecated: use [`SliceExt::choose_mut`] instead. + /// + /// [`SliceExt::choose_mut`]: seq/trait.SliceExt.html#method.choose_mut + #[deprecated(since="0.6.0", note="use SliceExt::choose_mut instead")] fn choose_mut<'a, T>(&mut self, values: &'a mut [T]) -> Option<&'a mut T> { - if values.is_empty() { - None - } else { - let len = values.len(); - Some(&mut values[self.gen_range(0, len)]) - } + use seq::SliceExt; + values.choose_mut(self) } /// Shuffle a mutable slice in place. /// - /// This applies Durstenfeld's algorithm for the [Fisher–Yates shuffle]( - /// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) - /// which produces an unbiased permutation. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// let mut y = [1, 2, 3]; - /// rng.shuffle(&mut y); - /// println!("{:?}", y); - /// rng.shuffle(&mut y); - /// println!("{:?}", y); - /// ``` + /// Deprecated: use [`SliceExt::shuffle`] instead. + /// + /// [`SliceExt::shuffle`]: seq/trait.SliceExt.html#method.shuffle + #[deprecated(since="0.6.0", note="use SliceExt::shuffle instead")] fn shuffle(&mut self, values: &mut [T]) { - let mut i = values.len(); - while i >= 2 { - // invariant: elements with index >= i have been locked in place. - i -= 1; - // lock element i in place. 
- values.swap(i, self.gen_range(0, i + 1)); - } + use seq::SliceExt; + values.shuffle(self) } } @@ -973,35 +944,6 @@ mod test { } } - #[test] - fn test_choose() { - let mut r = rng(107); - assert_eq!(r.choose(&[1, 1, 1]).map(|&x|x), Some(1)); - - let v: &[isize] = &[]; - assert_eq!(r.choose(v), None); - } - - #[test] - fn test_shuffle() { - let mut r = rng(108); - let empty: &mut [isize] = &mut []; - r.shuffle(empty); - let mut one = [1]; - r.shuffle(&mut one); - let b: &[_] = &[1]; - assert_eq!(one, b); - - let mut two = [1, 2]; - r.shuffle(&mut two); - assert!(two == [1, 2] || two == [2, 1]); - - let mut x = [1, 1, 1]; - r.shuffle(&mut x); - let b: &[_] = &[1, 1, 1]; - assert_eq!(x, b); - } - #[test] fn test_rng_trait_object() { use distributions::{Distribution, Standard}; @@ -1009,10 +951,6 @@ mod test { let mut r = &mut rng as &mut RngCore; r.next_u32(); r.gen::(); - let mut v = [1, 1, 1]; - r.shuffle(&mut v); - let b: &[_] = &[1, 1, 1]; - assert_eq!(v, b); assert_eq!(r.gen_range(0, 1), 0); let _c: u8 = Standard.sample(&mut r); } @@ -1025,10 +963,6 @@ mod test { let mut r = Box::new(rng) as Box; r.next_u32(); r.gen::(); - let mut v = [1, 1, 1]; - r.shuffle(&mut v); - let b: &[_] = &[1, 1, 1]; - assert_eq!(v, b); assert_eq!(r.gen_range(0, 1), 0); let _c: u8 = Standard.sample(&mut r); } diff --git a/src/rngs/thread.rs b/src/rngs/thread.rs index 863b79d31c7..c6611db92ff 100644 --- a/src/rngs/thread.rs +++ b/src/rngs/thread.rs @@ -132,10 +132,6 @@ mod test { use Rng; let mut r = ::thread_rng(); r.gen::(); - let mut v = [1, 1, 1]; - r.shuffle(&mut v); - let b: &[_] = &[1, 1, 1]; - assert_eq!(v, b); assert_eq!(r.gen_range(0, 1), 0); } } diff --git a/src/seq.rs b/src/seq.rs index b3172dcc4be..cad52f7d079 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -33,6 +33,18 @@ pub trait SliceExt { /// slice is empty. /// /// Depending on the implementation, complexity is expected to be `O(1)`. 
+ /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::seq::SliceExt; + /// + /// let choices = [1, 2, 4, 8, 16, 32]; + /// let mut rng = thread_rng(); + /// println!("{:?}", choices.choose(&mut rng)); + /// assert_eq!(choices[..0].choose(&mut rng), None); + /// ``` fn choose(&self, rng: &mut R) -> Option<&Self::Item> where R: Rng + ?Sized; @@ -60,9 +72,22 @@ pub trait SliceExt { fn choose_multiple(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item> where R: Rng + ?Sized; - /// Shuffle a slice in place. + /// Shuffle a mutable slice in place. /// /// Depending on the implementation, complexity is expected to be `O(1)`. + /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::seq::SliceExt; + /// + /// let mut rng = thread_rng(); + /// let mut y = [1, 2, 3, 4, 5]; + /// println!("Unshuffled: {:?}", y); + /// y.shuffle(&mut rng); + /// println!("Shuffled: {:?}", y); + /// ``` fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized; /// Shuffle a slice in place, but exit early. @@ -141,13 +166,22 @@ impl SliceExt for [T] { fn choose(&self, rng: &mut R) -> Option<&Self::Item> where R: Rng + ?Sized { - unimplemented!() + if self.is_empty() { + None + } else { + Some(&self[rng.gen_range(0, self.len())]) + } } fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> where R: Rng + ?Sized { - unimplemented!() + if self.is_empty() { + None + } else { + let len = self.len(); + Some(&mut self[rng.gen_range(0, len)]) + } } #[cfg(feature = "alloc")] @@ -159,7 +193,13 @@ impl SliceExt for [T] { fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized { - unimplemented!() + let mut i = self.len(); + while i >= 2 { + // invariant: elements with index >= i have been locked in place. + i -= 1; + // lock element i in place. 
+ self.swap(i, rng.gen_range(0, i + 1)); + } } fn partial_shuffle(&mut self, rng: &mut R, amount: usize) @@ -391,6 +431,35 @@ mod test { #[cfg(all(feature="alloc", not(feature="std")))] use alloc::Vec; + #[test] + fn test_choose() { + let mut r = ::test::rng(107); + assert_eq!([1, 1, 1].choose(&mut r).map(|&x|x), Some(1)); + + let v: &[isize] = &[]; + assert_eq!(v.choose(&mut r), None); + } + + #[test] + fn test_shuffle() { + let mut r = ::test::rng(108); + let empty: &mut [isize] = &mut []; + empty.shuffle(&mut r); + let mut one = [1]; + one.shuffle(&mut r); + let b: &[_] = &[1]; + assert_eq!(one, b); + + let mut two = [1, 2]; + two.shuffle(&mut r); + assert!(two == [1, 2] || two == [2, 1]); + + let mut x = [1, 1, 1]; + x.shuffle(&mut r); + let b: &[_] = &[1, 1, 1]; + assert_eq!(x, b); + } + #[test] #[cfg(feature = "alloc")] fn test_sample_iter() { From 3f5d562968edd39833c40623c9b86cf6cf5605c7 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sun, 27 May 2018 14:57:50 +0100 Subject: [PATCH 04/15] Impl new choose_multiple functions --- src/seq.rs | 165 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 108 insertions(+), 57 deletions(-) diff --git a/src/seq.rs b/src/seq.rs index cad52f7d079..f787111f3bd 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -12,14 +12,16 @@ //! //! TODO: module doc -use super::Rng; +#[cfg(feature="alloc")] use core::ops::Index; + +#[cfg(feature="std")] use std::vec; +#[cfg(all(feature="alloc", not(feature="std")))] use alloc::{vec, Vec}; // BTreeMap is not as fast in tests, but better than nothing. #[cfg(feature="std")] use std::collections::HashMap; #[cfg(all(feature="alloc", not(feature="std")))] use alloc::btree_map::BTreeMap; -#[cfg(all(feature="alloc", not(feature="std")))] use alloc::Vec; - +use super::Rng; /// Extension trait on slices, providing random mutation and sampling methods. /// @@ -69,7 +71,7 @@ pub trait SliceExt { /// /// Complexity is expected to be the same as `sample_indices`. 
#[cfg(feature = "alloc")] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item> + fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter where R: Rng + ?Sized; /// Shuffle a mutable slice in place. @@ -129,33 +131,78 @@ pub trait IteratorExt: Iterator + Sized { /// Collects `amount` values at random from the iterator into a supplied /// buffer. - /// + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// /// Returns the number of elements added to the buffer. This equals `amount` /// unless the iterator contains insufficient elements, in which case this /// equals the number of elements available. /// - /// Complexity is TODO - fn choose_multiple_fill(self, rng: &mut R, amount: usize) -> usize - where R: Rng + ?Sized + /// Complexity is `O(n)` where `n` is the length of the iterator. + fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) + -> usize where R: Rng + ?Sized { - unimplemented!() + let amount = buf.len(); + let mut len = 0; + while len < amount { + if let Some(elem) = self.next() { + buf[len] = elem; + len += 1; + } else { + // Iterator exhausted; stop early + return len; + } + } + + // Continue, since the iterator was not exhausted + for (i, elem) in self.enumerate() { + let k = rng.gen_range(0, i + 1 + amount); + if k < amount { + buf[k] = elem; + } + } + len } /// Collects `amount` values at random from the iterator into a vector. /// - /// This is a convenience wrapper around `choose_multiple_fill`. + /// This is equivalent to `choose_multiple_fill` except for the result type. /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. 
+ /// + /// The length of the returned vector equals `amount` unless the iterator /// contains insufficient elements, in which case it equals the number of /// elements available. /// - /// Complexity is TODO + /// Complexity is `O(n)` where `n` is the length of the iterator. #[cfg(feature = "alloc")] - fn choose_multiple(self, rng: &mut R, amount: usize) -> Vec + fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec where R: Rng + ?Sized { - // Note: I think this must use unsafe to create an uninitialised buffer, then restrict length - unimplemented!() + let mut reservoir = Vec::with_capacity(amount); + reservoir.extend(self.by_ref().take(amount)); + + // Continue unless the iterator was exhausted + // + // note: this prevents iterators that "restart" from causing problems. + // If the iterator stops once, then so do we. + if reservoir.len() == amount { + for (i, elem) in self.enumerate() { + let k = rng.gen_range(0, i + 1 + amount); + if let Some(spot) = reservoir.get_mut(k) { + *spot = elem; + } + } + } else { + // Don't hang onto extra memory. There is a corner case where + // `amount` was much less than `self.len()`. + reservoir.shrink_to_fit(); + } + reservoir } } @@ -185,10 +232,15 @@ impl SliceExt for [T] { } #[cfg(feature = "alloc")] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> Vec<&Self::Item> + fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter where R: Rng + ?Sized { - unimplemented!() + let amount = ::core::cmp::min(amount, self.len()); + SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: sample_indices(rng, self.len(), amount).into_iter(), + } } fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized @@ -209,57 +261,55 @@ impl SliceExt for [T] { } } +impl IteratorExt for I where I: Iterator + Sized {} + + +/// Iterator over multiple choices, as returned by [`SliceExt::choose_multiple`]( +/// trait.SliceExt.html#method.choose_multiple). 
+#[cfg(feature = "alloc")] +#[derive(Debug)] +pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { + slice: &'a S, + _phantom: ::core::marker::PhantomData, + indices: vec::IntoIter, +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + self.indices.next().map(|i| &(*self.slice)[i]) + } + + fn size_hint(&self) -> (usize, Option) { + (self.indices.len(), Some(self.indices.len())) + } +} + + // ——— -// TODO: remove below methods once implemented above // TODO: also revise signature of `sample_indices`? // ——— /// Randomly sample `amount` elements from a finite iterator. /// -/// The following can be returned: -/// -/// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random. -/// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the -/// length of `iterable` was less than `amount`. This is considered an error since exactly -/// `amount` elements is typically expected. -/// -/// This implementation uses `O(len(iterable))` time and `O(amount)` memory. -/// -/// # Example -/// -/// ``` -/// use rand::{thread_rng, seq}; -/// -/// let mut rng = thread_rng(); -/// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap(); -/// println!("{:?}", sample); -/// ``` +/// Deprecated: use [`IteratorExt::choose_multiple`] instead. 
+/// +/// [`IteratorExt::choose_multiple`]: trait.IteratorExt.html#method.choose_multiple #[cfg(feature = "alloc")] +#[deprecated(since="0.6.0", note="use IteratorExt::choose_multiple instead")] pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result, Vec> where I: IntoIterator, R: Rng + ?Sized, { - let mut iter = iterable.into_iter(); - let mut reservoir = Vec::with_capacity(amount); - reservoir.extend(iter.by_ref().take(amount)); - - // Continue unless the iterator was exhausted - // - // note: this prevents iterators that "restart" from causing problems. - // If the iterator stops once, then so do we. - if reservoir.len() == amount { - for (i, elem) in iter.enumerate() { - let k = rng.gen_range(0, i + 1 + amount); - if let Some(spot) = reservoir.get_mut(k) { - *spot = elem; - } - } - Ok(reservoir) + use seq::IteratorExt; + let iter = iterable.into_iter(); + let result = iter.choose_multiple(rng, amount); + if result.len() == amount { + Ok(result) } else { - // Don't hang onto extra memory. There is a corner case where - // `amount` was much less than `len(iterable)`. 
- reservoir.shrink_to_fit(); - Err(reservoir) + Err(result) } } @@ -426,6 +476,7 @@ fn sample_indices_cache( #[cfg(test)] mod test { use super::*; + use super::IteratorExt; #[cfg(feature = "alloc")] use {XorShiftRng, Rng, SeedableRng}; #[cfg(all(feature="alloc", not(feature="std")))] @@ -468,8 +519,8 @@ mod test { let mut r = ::test::rng(401); let vals = (min_val..max_val).collect::>(); - let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap(); - let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err(); + let small_sample = vals.iter().choose_multiple(&mut r, 5); + let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); assert_eq!(small_sample.len(), 5); assert_eq!(large_sample.len(), vals.len()); From fe03aeddddd1569d4093fdbdbc3f149ba5faebc1 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sun, 27 May 2018 15:03:03 +0100 Subject: [PATCH 05/15] Implement IteratorExt::choose --- src/seq.rs | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/seq.rs b/src/seq.rs index f787111f3bd..b4751f25137 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -123,10 +123,24 @@ pub trait IteratorExt: Iterator + Sized { /// Complexity is `O(n)`, where `n` is the length of the iterator. /// This likely consumes multiple random numbers, but the exact number /// is unspecified. 
- fn choose(self, rng: &mut R) -> Option + fn choose(mut self, rng: &mut R) -> Option where R: Rng + ?Sized { - unimplemented!() + if let Some(elem) = self.next() { + let mut result = elem; + + // Continue until the iterator is exhausted + for (i, elem) in self.enumerate() { + let k = rng.gen_range(0, i + 2); + if k == 0 { + result = elem; + } + } + + Some(result) + } else { + None + } } /// Collects `amount` values at random from the iterator into a supplied @@ -485,7 +499,14 @@ mod test { #[test] fn test_choose() { let mut r = ::test::rng(107); - assert_eq!([1, 1, 1].choose(&mut r).map(|&x|x), Some(1)); + assert_eq!([1, 1, 1].choose(&mut r), Some(&1)); + + let mut v = [2]; + v.choose_mut(&mut r).map(|x| *x = 5); + assert_eq!(v[0], 5); + + let v = [3, 3, 3, 3]; + assert_eq!(v.iter().choose(&mut r), Some(&3)); let v: &[isize] = &[]; assert_eq!(v.choose(&mut r), None); From 27436e9f8ee887fca783b9bc2da051b1b93efb08 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sun, 27 May 2018 15:24:23 +0100 Subject: [PATCH 06/15] Implement SliceExt::partial_shuffle --- src/seq.rs | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/seq.rs b/src/seq.rs index b4751f25137..e74ea909c37 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -271,7 +271,23 @@ impl SliceExt for [T] { fn partial_shuffle(&mut self, rng: &mut R, amount: usize) -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized { - unimplemented!() + // This applies Durstenfeld's algorithm for the + // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) + // for an unbiased permutation, but exits early after choosing `amount` + // elements. + + let len = self.len(); + let mut i = len; + let end = if amount >= len { 0 } else { len - amount }; + + while i > end { + // invariant: elements with index > i have been locked in place. + i -= 1; + // lock element i in place. 
+ self.swap(i, rng.gen_range(0, i + 1)); + } + let r = self.split_at_mut(i); + (r.1, r.0) } } @@ -531,6 +547,22 @@ mod test { let b: &[_] = &[1, 1, 1]; assert_eq!(x, b); } + + #[test] + fn test_partial_shuffle() { + let mut r = ::test::rng(118); + + let mut empty: [u32; 0] = []; + let res = empty.partial_shuffle(&mut r, 10); + assert_eq!((res.0.len(), res.1.len()), (0, 0)); + + let mut v = [1, 2, 3, 4, 5]; + let res = v.partial_shuffle(&mut r, 2); + assert_eq!((res.0.len(), res.1.len()), (2, 3)); + assert!(res.0[0] != res.0[1]); + // First elements are only modified if selected, so at least one isn't modified: + assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); + } #[test] #[cfg(feature = "alloc")] From 68d902e4320102f0c41092705a4973dfa76fd0c9 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 9 Jun 2018 10:01:09 +0100 Subject: [PATCH 07/15] Move seq benches to new file --- benches/misc.rs | 54 ------------------------------------------ benches/seq.rs | 62 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 54 deletions(-) create mode 100644 benches/seq.rs diff --git a/benches/misc.rs b/benches/misc.rs index 4e9cbda37ae..93a5c506fdc 100644 --- a/benches/misc.rs +++ b/benches/misc.rs @@ -8,7 +8,6 @@ const RAND_BENCH_N: u64 = 1000; use test::Bencher; use rand::prelude::*; -use rand::seq::*; #[bench] fn misc_gen_bool_const(b: &mut Bencher) { @@ -108,59 +107,6 @@ sample_binomial!(misc_binomial_100, 100, 0.99); sample_binomial!(misc_binomial_1000, 1000, 0.01); sample_binomial!(misc_binomial_1e12, 1000_000_000_000, 0.2); -#[bench] -fn misc_shuffle_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &mut [usize] = &mut [1; 100]; - b.iter(|| { - rng.shuffle(x); - x[0] - }) -} - -#[bench] -fn misc_sample_iter_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &[usize] = &[1; 100]; - b.iter(|| { - sample_iter(&mut rng, x, 10).unwrap_or_else(|e| 
e) - }) -} - -#[bench] -fn misc_sample_slice_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &[usize] = &[1; 100]; - b.iter(|| { - sample_slice(&mut rng, x, 10) - }) -} - -#[bench] -fn misc_sample_slice_ref_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &[usize] = &[1; 100]; - b.iter(|| { - sample_slice_ref(&mut rng, x, 10) - }) -} - -macro_rules! sample_indices { - ($name:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - b.iter(|| { - sample_indices(&mut rng, $length, $amount) - }) - } - } -} - -sample_indices!(misc_sample_indices_10_of_1k, 10, 1000); -sample_indices!(misc_sample_indices_50_of_1k, 50, 1000); -sample_indices!(misc_sample_indices_100_of_1k, 100, 1000); - #[bench] fn gen_1k_iter_repeat(b: &mut Bencher) { use std::iter; diff --git a/benches/seq.rs b/benches/seq.rs new file mode 100644 index 00000000000..ebcfd68c30f --- /dev/null +++ b/benches/seq.rs @@ -0,0 +1,62 @@ +#![feature(test)] + +extern crate test; +extern crate rand; + +use test::Bencher; + +use rand::prelude::*; +use rand::seq::*; + +#[bench] +fn misc_shuffle_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &mut [usize] = &mut [1; 100]; + b.iter(|| { + rng.shuffle(x); + x[0] + }) +} + +#[bench] +fn misc_sample_iter_10_of_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 100]; + b.iter(|| { + sample_iter(&mut rng, x, 10).unwrap_or_else(|e| e) + }) +} + +#[bench] +fn misc_sample_slice_10_of_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 100]; + b.iter(|| { + sample_slice(&mut rng, x, 10) + }) +} + +#[bench] +fn misc_sample_slice_ref_10_of_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 
100]; + b.iter(|| { + sample_slice_ref(&mut rng, x, 10) + }) +} + +macro_rules! sample_indices { + ($name:ident, $amount:expr, $length:expr) => { + #[bench] + fn $name(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + b.iter(|| { + sample_indices(&mut rng, $length, $amount) + }) + } + } +} + +sample_indices!(misc_sample_indices_10_of_1k, 10, 1000); +sample_indices!(misc_sample_indices_50_of_1k, 50, 1000); +sample_indices!(misc_sample_indices_100_of_1k, 100, 1000); From 1616f24f73855ea4b343b1804373d067e7035cfb Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 9 Jun 2018 10:09:07 +0100 Subject: [PATCH 08/15] Revise seq benches --- benches/seq.rs | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/benches/seq.rs b/benches/seq.rs index ebcfd68c30f..7d3b4fd0299 100644 --- a/benches/seq.rs +++ b/benches/seq.rs @@ -9,39 +9,44 @@ use rand::prelude::*; use rand::seq::*; #[bench] -fn misc_shuffle_100(b: &mut Bencher) { +fn seq_shuffle_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let x : &mut [usize] = &mut [1; 100]; b.iter(|| { - rng.shuffle(x); + x.shuffle(&mut rng); x[0] }) } #[bench] -fn misc_sample_iter_10_of_100(b: &mut Bencher) { +fn seq_slice_choose_multiple_10_of_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let x : &[usize] = &[1; 100]; + let mut buf = [0; 10]; b.iter(|| { - sample_iter(&mut rng, x, 10).unwrap_or_else(|e| e) + for (v, slot) in x.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { + *slot = *v; + } + buf }) } #[bench] -fn misc_sample_slice_10_of_100(b: &mut Bencher) { +fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let x : &[usize] = &[1; 100]; b.iter(|| { - sample_slice(&mut rng, x, 10) + x.iter().cloned().choose_multiple(&mut rng, 10) /*.unwrap_or_else(|e| e)*/ }) } #[bench] -fn misc_sample_slice_ref_10_of_100(b: &mut Bencher) { 
+fn seq_iter_choose_multiple_fill_10_of_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let x : &[usize] = &[1; 100]; + let mut buf = [0; 10]; b.iter(|| { - sample_slice_ref(&mut rng, x, 10) + x.iter().cloned().choose_multiple_fill(&mut rng, &mut buf) }) } @@ -57,6 +62,6 @@ macro_rules! sample_indices { } } -sample_indices!(misc_sample_indices_10_of_1k, 10, 1000); -sample_indices!(misc_sample_indices_50_of_1k, 50, 1000); -sample_indices!(misc_sample_indices_100_of_1k, 100, 1000); +sample_indices!(seq_sample_indices_10_of_1k, 10, 1000); +sample_indices!(seq_sample_indices_50_of_1k, 50, 1000); +sample_indices!(seq_sample_indices_100_of_1k, 100, 1000); From d8e9de151bab865580f9dca736bce35add2d7068 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 9 Jun 2018 10:33:38 +0100 Subject: [PATCH 09/15] Deprecate seq::sample_slice (and _ref) --- examples/monty-hall.rs | 4 ++-- src/seq.rs | 28 ++++++++++------------------ 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/examples/monty-hall.rs b/examples/monty-hall.rs index 3750f8fabe2..fad1b772343 100644 --- a/examples/monty-hall.rs +++ b/examples/monty-hall.rs @@ -63,8 +63,8 @@ fn simulate(random_door: &Uniform, rng: &mut R) // Returns the door the game host opens given our choice and knowledge of // where the car is. The game host will never open the door with the car. 
fn game_host_open(car: u32, choice: u32, rng: &mut R) -> u32 { - let choices = free_doors(&[car, choice]); - rand::seq::sample_slice(rng, &choices, 1)[0] + use rand::seq::SliceExt; + *free_doors(&[car, choice]).choose(rng).unwrap() } // Returns the door we switch to, given our current choice and diff --git a/src/seq.rs b/src/seq.rs index e74ea909c37..e244cd50fc2 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -351,16 +351,11 @@ pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result slice.len()` /// -/// # Example -/// -/// ``` -/// use rand::{thread_rng, seq}; -/// -/// let mut rng = thread_rng(); -/// let values = vec![5, 6, 1, 3, 4, 6, 7]; -/// println!("{:?}", seq::sample_slice(&mut rng, &values, 3)); -/// ``` +/// Deprecated: use [`SliceExt::choose_multiple`] instead. +/// +/// [`SliceExt::choose_multiple`]: trait.SliceExt.html#method.choose_multiple #[cfg(feature = "alloc")] +#[deprecated(since="0.6.0", note="use SliceExt::choose_multiple instead")] pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec where R: Rng + ?Sized, T: Clone @@ -380,16 +375,11 @@ pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec /// /// Panics if `amount > slice.len()` /// -/// # Example -/// -/// ``` -/// use rand::{thread_rng, seq}; -/// -/// let mut rng = thread_rng(); -/// let values = vec![5, 6, 1, 3, 4, 6, 7]; -/// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3)); -/// ``` +/// Deprecated: use [`SliceExt::choose_multiple`] instead. 
+/// +/// [`SliceExt::choose_multiple`]: trait.SliceExt.html#method.choose_multiple #[cfg(feature = "alloc")] +#[deprecated(since="0.6.0", note="use SliceExt::choose_multiple instead")] pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng + ?Sized { @@ -586,6 +576,7 @@ mod test { } #[test] #[cfg(feature = "alloc")] + #[allow(deprecated)] fn test_sample_slice_boundaries() { let empty: &[u8] = &[]; @@ -631,6 +622,7 @@ mod test { #[test] #[cfg(feature = "alloc")] + #[allow(deprecated)] fn test_sample_slice() { let xor_rng = XorShiftRng::from_seed; From 737df7573f5b4ce6c292f5c345031d4c7ecb5122 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 9 Jun 2018 10:34:42 +0100 Subject: [PATCH 10/15] Address review comments and doc example for choose_multiple --- src/seq.rs | 50 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/src/seq.rs b/src/seq.rs index e244cd50fc2..2c20c1fba5e 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -70,6 +70,23 @@ pub trait SliceExt { /// result, then apply to the slice. /// /// Complexity is expected to be the same as `sample_indices`. 
+ /// + /// # Example + /// ``` + /// use rand::seq::SliceExt; + /// + /// let mut rng = &mut rand::thread_rng(); + /// let sample = "Hello, audience!".as_bytes(); + /// + /// // collect the results into a vector: + /// let v: Vec = sample.choose_multiple(&mut rng, 3).cloned().collect(); + /// + /// // store in a buffer: + /// let mut buf = [0u8; 5]; + /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { + /// *slot = *b; + /// } + /// ``` #[cfg(feature = "alloc")] fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter where R: Rng + ?Sized; @@ -131,8 +148,8 @@ pub trait IteratorExt: Iterator + Sized { // Continue until the iterator is exhausted for (i, elem) in self.enumerate() { - let k = rng.gen_range(0, i + 2); - if k == 0 { + // TODO: benchmark using gen_ratio instead + if rng.gen_range(0, i + 2) == 0 { result = elem; } } @@ -198,7 +215,7 @@ pub trait IteratorExt: Iterator + Sized { where R: Rng + ?Sized { let mut reservoir = Vec::with_capacity(amount); - reservoir.extend(self.by_ref().take(amount)); + reservoir.extend(self.by_ref().take(amount)); // Continue unless the iterator was exhausted // @@ -240,8 +257,7 @@ impl SliceExt for [T] { if self.is_empty() { None } else { - let len = self.len(); - Some(&mut self[rng.gen_range(0, len)]) + Some(&mut self[rng.gen_range(0, self.len())]) } } @@ -259,11 +275,8 @@ impl SliceExt for [T] { fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized { - let mut i = self.len(); - while i >= 2 { - // invariant: elements with index >= i have been locked in place. - i -= 1; - // lock element i in place. + for i in (1..self.len()).rev() { + // invariant: elements with index > i have been locked in place. self.swap(i, rng.gen_range(0, i + 1)); } } @@ -277,16 +290,13 @@ impl SliceExt for [T] { // elements. 
let len = self.len(); - let mut i = len; let end = if amount >= len { 0 } else { len - amount }; - while i > end { + for i in (end..len).rev() { // invariant: elements with index > i have been locked in place. - i -= 1; - // lock element i in place. self.swap(i, rng.gen_range(0, i + 1)); } - let r = self.split_at_mut(i); + let r = self.split_at_mut(end); (r.1, r.0) } } @@ -309,6 +319,7 @@ impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceCho type Item = &'a T; fn next(&mut self) -> Option { + // TODO: get_unchecked? self.indices.next().map(|i| &(*self.slice)[i]) } @@ -317,6 +328,15 @@ impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceCho } } +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator + for SliceChooseIter<'a, S, T> +{ + fn len(&self) -> usize { + self.indices.len() + } +} + // ——— // TODO: also revise signature of `sample_indices`? From 263d97c9d306768d32abe91ae168dbc716cba295 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 12 Jun 2018 10:30:41 +0100 Subject: [PATCH 11/15] Optimise IteratorExt::choose --- benches/seq.rs | 9 +++++++++ src/seq.rs | 5 +++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/benches/seq.rs b/benches/seq.rs index 7d3b4fd0299..150ce87dc3e 100644 --- a/benches/seq.rs +++ b/benches/seq.rs @@ -31,6 +31,15 @@ fn seq_slice_choose_multiple_10_of_100(b: &mut Bencher) { }) } +#[bench] +fn seq_iter_choose_from_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 100]; + b.iter(|| { + x.iter().cloned().choose(&mut rng) + }) +} + #[bench] fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); diff --git a/src/seq.rs b/src/seq.rs index 2c20c1fba5e..3df0b7b44a5 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -148,8 +148,9 @@ pub trait IteratorExt: Iterator + Sized { // Continue until the iterator is exhausted for (i, elem) in self.enumerate() { - // TODO: benchmark 
using gen_ratio instead - if rng.gen_range(0, i + 2) == 0 { + let denom = (i + 2) as u32; + assert_eq!(denom as usize, i + 2); // check against overflow + if rng.gen_ratio(1, denom) { result = elem; } } From ee52a640fcc29d1aa1866c8dfcb2ebd654393764 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 12 Jun 2018 10:42:43 +0100 Subject: [PATCH 12/15] Rename seq traits and export from prelude MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SliceExt → SliceRandom, IteratorExt → IteratorRandom --- examples/monty-hall.rs | 2 +- src/lib.rs | 24 ++++++++++++------------ src/prelude.rs | 1 + src/seq.rs | 40 ++++++++++++++++++++-------------------- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/examples/monty-hall.rs b/examples/monty-hall.rs index fad1b772343..faf795cb873 100644 --- a/examples/monty-hall.rs +++ b/examples/monty-hall.rs @@ -63,7 +63,7 @@ fn simulate(random_door: &Uniform, rng: &mut R) // Returns the door the game host opens given our choice and knowledge of // where the car is. The game host will never open the door with the car. fn game_host_open(car: u32, choice: u32, rng: &mut R) -> u32 { - use rand::seq::SliceExt; + use rand::seq::SliceRandom; *free_doors(&[car, choice]).choose(rng).unwrap() } diff --git a/src/lib.rs b/src/lib.rs index 5156cbe89e9..3639f70fc89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -563,34 +563,34 @@ pub trait Rng: RngCore { /// Return a random element from `values`. /// - /// Deprecated: use [`SliceExt::choose`] instead. + /// Deprecated: use [`SliceRandom::choose`] instead. 
/// - /// [`SliceExt::choose`]: seq/trait.SliceExt.html#method.choose - #[deprecated(since="0.6.0", note="use SliceExt::choose instead")] + /// [`SliceRandom::choose`]: seq/trait.SliceRandom.html#method.choose + #[deprecated(since="0.6.0", note="use SliceRandom::choose instead")] fn choose<'a, T>(&mut self, values: &'a [T]) -> Option<&'a T> { - use seq::SliceExt; + use seq::SliceRandom; values.choose(self) } /// Return a mutable pointer to a random element from `values`. /// - /// Deprecated: use [`SliceExt::choose_mut`] instead. + /// Deprecated: use [`SliceRandom::choose_mut`] instead. /// - /// [`SliceExt::choose_mut`]: seq/trait.SliceExt.html#method.choose_mut - #[deprecated(since="0.6.0", note="use SliceExt::choose_mut instead")] + /// [`SliceRandom::choose_mut`]: seq/trait.SliceRandom.html#method.choose_mut + #[deprecated(since="0.6.0", note="use SliceRandom::choose_mut instead")] fn choose_mut<'a, T>(&mut self, values: &'a mut [T]) -> Option<&'a mut T> { - use seq::SliceExt; + use seq::SliceRandom; values.choose_mut(self) } /// Shuffle a mutable slice in place. /// - /// Deprecated: use [`SliceExt::shuffle`] instead. + /// Deprecated: use [`SliceRandom::shuffle`] instead. 
/// - /// [`SliceExt::shuffle`]: seq/trait.SliceExt.html#method.shuffle - #[deprecated(since="0.6.0", note="use SliceExt::shuffle instead")] + /// [`SliceRandom::shuffle`]: seq/trait.SliceRandom.html#method.shuffle + #[deprecated(since="0.6.0", note="use SliceRandom::shuffle instead")] fn shuffle(&mut self, values: &mut [T]) { - use seq::SliceExt; + use seq::SliceRandom; values.shuffle(self) } } diff --git a/src/prelude.rs b/src/prelude.rs index 358c2370823..ace5c89d700 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -26,3 +26,4 @@ #[doc(no_inline)] #[cfg(feature="std")] pub use rngs::ThreadRng; #[doc(no_inline)] pub use {Rng, RngCore, CryptoRng, SeedableRng}; #[doc(no_inline)] #[cfg(feature="std")] pub use {FromEntropy, random, thread_rng}; +#[doc(no_inline)] pub use seq::{SliceRandom, IteratorRandom}; diff --git a/src/seq.rs b/src/seq.rs index 3df0b7b44a5..d53ea9b24b5 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -27,7 +27,7 @@ use super::Rng; /// /// An implementation is provided for slices. This may also be implementable for /// other types. -pub trait SliceExt { +pub trait SliceRandom { /// The element type. type Item; @@ -40,7 +40,7 @@ pub trait SliceExt { /// /// ``` /// use rand::thread_rng; - /// use rand::seq::SliceExt; + /// use rand::seq::SliceRandom; /// /// let choices = [1, 2, 4, 8, 16, 32]; /// let mut rng = thread_rng(); @@ -73,7 +73,7 @@ pub trait SliceExt { /// /// # Example /// ``` - /// use rand::seq::SliceExt; + /// use rand::seq::SliceRandom; /// /// let mut rng = &mut rand::thread_rng(); /// let sample = "Hello, audience!".as_bytes(); @@ -99,7 +99,7 @@ pub trait SliceExt { /// /// ``` /// use rand::thread_rng; - /// use rand::seq::SliceExt; + /// use rand::seq::SliceRandom; /// /// let mut rng = thread_rng(); /// let mut y = [1, 2, 3, 4, 5]; @@ -132,7 +132,7 @@ pub trait SliceExt { } /// Extension trait on iterators, providing random sampling methods. 
-pub trait IteratorExt: Iterator + Sized { +pub trait IteratorRandom: Iterator + Sized { /// Choose one element at random from the iterator. /// /// Returns `None` if and only if the iterator is empty. @@ -239,7 +239,7 @@ pub trait IteratorExt: Iterator + Sized { } -impl SliceExt for [T] { +impl SliceRandom for [T] { type Item = T; fn choose(&self, rng: &mut R) -> Option<&Self::Item> @@ -302,11 +302,11 @@ impl SliceExt for [T] { } } -impl IteratorExt for I where I: Iterator + Sized {} +impl IteratorRandom for I where I: Iterator + Sized {} -/// Iterator over multiple choices, as returned by [`SliceExt::choose_multiple]( -/// trait.SliceExt.html#method.choose_multiple). +/// Iterator over multiple choices, as returned by [`SliceRandom::choose_multiple]( +/// trait.SliceRandom.html#method.choose_multiple). #[cfg(feature = "alloc")] #[derive(Debug)] pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { @@ -345,16 +345,16 @@ impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator /// Randomly sample `amount` elements from a finite iterator. /// -/// Deprecated: use [`IteratorExt::choose_multiple`] instead. +/// Deprecated: use [`IteratorRandom::choose_multiple`] instead. 
/// -/// [`IteratorExt::choose_multiple`]: trait.IteratorExt.html#method.choose_multiple +/// [`IteratorRandom::choose_multiple`]: trait.IteratorRandom.html#method.choose_multiple #[cfg(feature = "alloc")] -#[deprecated(since="0.6.0", note="use IteratorExt::choose_multiple instead")] +#[deprecated(since="0.6.0", note="use IteratorRandom::choose_multiple instead")] pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result, Vec> where I: IntoIterator, R: Rng + ?Sized, { - use seq::IteratorExt; + use seq::IteratorRandom; let iter = iterable.into_iter(); let result = iter.choose_multiple(rng, amount); if result.len() == amount { @@ -372,11 +372,11 @@ pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result slice.len()` /// -/// Deprecated: use [`SliceExt::choose_multiple`] instead. +/// Deprecated: use [`SliceRandom::choose_multiple`] instead. /// -/// [`SliceExt::choose_multiple`]: trait.SliceExt.html#method.choose_multiple +/// [`SliceRandom::choose_multiple`]: trait.SliceRandom.html#method.choose_multiple #[cfg(feature = "alloc")] -#[deprecated(since="0.6.0", note="use SliceExt::choose_multiple instead")] +#[deprecated(since="0.6.0", note="use SliceRandom::choose_multiple instead")] pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec where R: Rng + ?Sized, T: Clone @@ -396,11 +396,11 @@ pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec /// /// Panics if `amount > slice.len()` /// -/// Deprecated: use [`SliceExt::choose_multiple`] instead. +/// Deprecated: use [`SliceRandom::choose_multiple`] instead. 
/// -/// [`SliceExt::choose_multiple`]: trait.SliceExt.html#method.choose_multiple +/// [`SliceRandom::choose_multiple`]: trait.SliceRandom.html#method.choose_multiple #[cfg(feature = "alloc")] -#[deprecated(since="0.6.0", note="use SliceExt::choose_multiple instead")] +#[deprecated(since="0.6.0", note="use SliceRandom::choose_multiple instead")] pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng + ?Sized { @@ -517,7 +517,7 @@ fn sample_indices_cache( #[cfg(test)] mod test { use super::*; - use super::IteratorExt; + use super::IteratorRandom; #[cfg(feature = "alloc")] use {XorShiftRng, Rng, SeedableRng}; #[cfg(all(feature="alloc", not(feature="std")))] From 83130e845bc6bfd9ac76143ead647efb66dec0b3 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 12 Jun 2018 14:02:24 +0100 Subject: [PATCH 13/15] Fix build for Rustc 1.22 --- src/seq.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/seq.rs b/src/seq.rs index d53ea9b24b5..cefd2aeb2e7 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -258,7 +258,8 @@ impl SliceRandom for [T] { if self.is_empty() { None } else { - Some(&mut self[rng.gen_range(0, self.len())]) + let len = self.len(); + Some(&mut self[rng.gen_range(0, len)]) } } From 27e323ac8f693106bb9cad96a5bfe949dd9fd37f Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 12 Jun 2018 14:05:29 +0100 Subject: [PATCH 14/15] =?UTF-8?q?Rename=20choose=5Fmultiple=20=E2=86=92=20?= =?UTF-8?q?sample?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- benches/seq.rs | 12 ++++++------ src/seq.rs | 44 ++++++++++++++++++++++---------------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/benches/seq.rs b/benches/seq.rs index 150ce87dc3e..3e5be54375d 100644 --- a/benches/seq.rs +++ b/benches/seq.rs @@ -19,12 +19,12 @@ fn seq_shuffle_100(b: &mut Bencher) { } #[bench] -fn seq_slice_choose_multiple_10_of_100(b: &mut Bencher) { +fn 
seq_slice_sample_10_of_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let x : &[usize] = &[1; 100]; let mut buf = [0; 10]; b.iter(|| { - for (v, slot) in x.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { + for (v, slot) in x.sample(&mut rng, buf.len()).zip(buf.iter_mut()) { *slot = *v; } buf @@ -41,21 +41,21 @@ fn seq_iter_choose_from_100(b: &mut Bencher) { } #[bench] -fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) { +fn seq_iter_sample_10_of_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let x : &[usize] = &[1; 100]; b.iter(|| { - x.iter().cloned().choose_multiple(&mut rng, 10) /*.unwrap_or_else(|e| e)*/ + x.iter().cloned().sample(&mut rng, 10) /*.unwrap_or_else(|e| e)*/ }) } #[bench] -fn seq_iter_choose_multiple_fill_10_of_100(b: &mut Bencher) { +fn seq_iter_sample_fill_10_of_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let x : &[usize] = &[1; 100]; let mut buf = [0; 10]; b.iter(|| { - x.iter().cloned().choose_multiple_fill(&mut rng, &mut buf) + x.iter().cloned().sample_fill(&mut rng, &mut buf) }) } diff --git a/src/seq.rs b/src/seq.rs index cefd2aeb2e7..67e113f48a3 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -76,19 +76,19 @@ pub trait SliceRandom { /// use rand::seq::SliceRandom; /// /// let mut rng = &mut rand::thread_rng(); - /// let sample = "Hello, audience!".as_bytes(); + /// let sequence = "Hello, audience!".as_bytes(); /// /// // collect the results into a vector: - /// let v: Vec = sample.choose_multiple(&mut rng, 3).cloned().collect(); + /// let v: Vec = sequence.sample(&mut rng, 3).cloned().collect(); /// /// // store in a buffer: /// let mut buf = [0u8; 5]; - /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { + /// for (b, slot) in sequence.sample(&mut rng, buf.len()).zip(buf.iter_mut()) { /// *slot = *b; /// } /// ``` #[cfg(feature = "alloc")] - fn choose_multiple(&self, rng: &mut R, amount: 
usize) -> SliceChooseIter + fn sample(&self, rng: &mut R, amount: usize) -> SliceChooseIter where R: Rng + ?Sized; /// Shuffle a mutable slice in place. @@ -173,7 +173,7 @@ pub trait IteratorRandom: Iterator + Sized { /// equals the number of elements available. /// /// Complexity is `O(n)` where `n` is the length of the iterator. - fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) + fn sample_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize where R: Rng + ?Sized { let amount = buf.len(); @@ -200,7 +200,7 @@ pub trait IteratorRandom: Iterator + Sized { /// Collects `amount` values at random from the iterator into a vector. /// - /// This is equivalent to `choose_multiple_fill` except for the result type. + /// This is equivalent to `sample_fill` except for the result type. /// /// Although the elements are selected randomly, the order of elements in /// the buffer is neither stable nor fully random. If random ordering is @@ -212,7 +212,7 @@ pub trait IteratorRandom: Iterator + Sized { /// /// Complexity is `O(n)` where `n` is the length of the iterator. #[cfg(feature = "alloc")] - fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec + fn sample(mut self, rng: &mut R, amount: usize) -> Vec where R: Rng + ?Sized { let mut reservoir = Vec::with_capacity(amount); @@ -264,7 +264,7 @@ impl SliceRandom for [T] { } #[cfg(feature = "alloc")] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter + fn sample(&self, rng: &mut R, amount: usize) -> SliceChooseIter where R: Rng + ?Sized { let amount = ::core::cmp::min(amount, self.len()); @@ -306,8 +306,8 @@ impl SliceRandom for [T] { impl IteratorRandom for I where I: Iterator + Sized {} -/// Iterator over multiple choices, as returned by [`SliceRandom::choose_multiple]( -/// trait.SliceRandom.html#method.choose_multiple). +/// Iterator over multiple choices, as returned by [`SliceRandom::sample`]( +/// trait.SliceRandom.html#method.sample).
#[cfg(feature = "alloc")] #[derive(Debug)] pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { @@ -346,18 +346,18 @@ impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator /// Randomly sample `amount` elements from a finite iterator. /// -/// Deprecated: use [`IteratorRandom::choose_multiple`] instead. +/// Deprecated: use [`IteratorRandom::sample`] instead. /// -/// [`IteratorRandom::choose_multiple`]: trait.IteratorRandom.html#method.choose_multiple +/// [`IteratorRandom::sample`]: trait.IteratorRandom.html#method.sample #[cfg(feature = "alloc")] -#[deprecated(since="0.6.0", note="use IteratorRandom::choose_multiple instead")] +#[deprecated(since="0.6.0", note="use IteratorRandom::sample instead")] pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result, Vec> where I: IntoIterator, R: Rng + ?Sized, { use seq::IteratorRandom; let iter = iterable.into_iter(); - let result = iter.choose_multiple(rng, amount); + let result = iter.sample(rng, amount); if result.len() == amount { Ok(result) } else { @@ -373,11 +373,11 @@ pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result slice.len()` /// -/// Deprecated: use [`SliceRandom::choose_multiple`] instead. +/// Deprecated: use [`SliceRandom::sample`] instead. /// -/// [`SliceRandom::choose_multiple`]: trait.SliceRandom.html#method.choose_multiple +/// [`SliceRandom::sample`]: trait.SliceRandom.html#method.sample #[cfg(feature = "alloc")] -#[deprecated(since="0.6.0", note="use SliceRandom::choose_multiple instead")] +#[deprecated(since="0.6.0", note="use SliceRandom::sample instead")] pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec where R: Rng + ?Sized, T: Clone @@ -397,11 +397,11 @@ pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec /// /// Panics if `amount > slice.len()` /// -/// Deprecated: use [`SliceRandom::choose_multiple`] instead. +/// Deprecated: use [`SliceRandom::sample`] instead. 
/// -/// [`SliceRandom::choose_multiple`]: trait.SliceRandom.html#method.choose_multiple +/// [`SliceRandom::sample`]: trait.SliceRandom.html#method.sample #[cfg(feature = "alloc")] -#[deprecated(since="0.6.0", note="use SliceRandom::choose_multiple instead")] +#[deprecated(since="0.6.0", note="use SliceRandom::sample instead")] pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng + ?Sized { @@ -584,8 +584,8 @@ mod test { let mut r = ::test::rng(401); let vals = (min_val..max_val).collect::>(); - let small_sample = vals.iter().choose_multiple(&mut r, 5); - let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); + let small_sample = vals.iter().sample(&mut r, 5); + let large_sample = vals.iter().sample(&mut r, vals.len() + 5); assert_eq!(small_sample.len(), 5); assert_eq!(large_sample.len(), vals.len()); From 15f91f56351b866ae46513378439ab3677e7ba6c Mon Sep 17 00:00:00 2001 From: Jonas Sicking Date: Tue, 12 Jun 2018 15:29:33 -0700 Subject: [PATCH 15/15] Implement choose_weighted and choose_weighted_with_total --- src/seq.rs | 234 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 234 insertions(+) diff --git a/src/seq.rs b/src/seq.rs index 67e113f48a3..1a3a448d0d3 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -22,6 +22,7 @@ #[cfg(all(feature="alloc", not(feature="std")))] use alloc::btree_map::BTreeMap; use super::Rng; +use distributions::uniform::SampleUniform; /// Extension trait on slices, providing random mutation and sampling methods. /// @@ -515,6 +516,133 @@ fn sample_indices_cache( out } +/// Return a random element from `items` where the chance of a given item +/// being picked is proportional to the corresponding value in `weights`. +/// `weights` and `items` must return exactly the same number of values. +/// +/// All values returned by `weights` must be `>= 0`. +/// +/// This function iterates over `weights` twice. 
Once to get the total +/// weight, and once while choosing the random value. If you know the total +/// weight, or plan to call this function multiple times, you should +/// consider using [`choose_weighted_with_total`] instead. +/// +/// Return `None` if `items` and `weights` are empty. +/// +/// # Panics +/// +/// If any value in `weights` is `< 0`. +/// If `weights` and `items` return a different number of values. +/// +/// # Example +/// +/// ``` +/// use rand::thread_rng; +/// use rand::seq::choose_weighted; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let mut rng = thread_rng(); +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choose_weighted(&mut rng, choices.iter(), weights.iter().cloned()).unwrap()); +/// ``` +/// [`choose_weighted_with_total`]: fn.choose_weighted_with_total.html +pub fn choose_weighted(rng: &mut R, + items: IterItems, + weights: IterWeights) -> Option + where R: Rng + ?Sized, + IterItems: Iterator, + IterWeights: Iterator + Clone, + IterWeights::Item: SampleUniform + + Default + + ::core::ops::Add + + ::core::cmp::PartialOrd + + Clone { // Clone is only needed for debug assertions + let total_weight: IterWeights::Item = + weights.clone().fold(Default::default(), |acc, w| { + assert!(w >= Default::default(), "Weight must be larger than zero"); + acc + w + }); + choose_weighted_with_total(rng, items, weights, total_weight) +} + +/// Return a random element from `items`, where the chance of a given item +/// being picked is proportional to the corresponding value in `weights`. +/// `weights` and `items` must return exactly the same number of values. +/// +/// All values returned by `weights` must be `>= 0`. +/// +/// `total_weight` must be exactly the sum of all values returned by +/// `weights`. Builds with debug_assertions turned on will assert that this +/// equality holds.
Simply storing the result of `weights.sum()` and using +/// that as `total_weight` should work. +/// +/// Return `None` if `items` and `weights` are empty. +/// +/// # Example +/// +/// ``` +/// use rand::thread_rng; +/// use rand::seq::choose_weighted_with_total; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let mut rng = thread_rng(); +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choose_weighted_with_total(&mut rng, choices.iter(), weights.iter().cloned(), 4).unwrap()); +/// ``` +pub fn choose_weighted_with_total(rng: &mut R, + mut items: IterItems, + mut weights: IterWeights, + total_weight: IterWeights::Item) -> Option + where R: Rng + ?Sized, + IterItems: Iterator, + IterWeights: Iterator, + IterWeights::Item: SampleUniform + + Default + + ::core::ops::Add + + ::core::cmp::PartialOrd + + Clone { // Clone is only needed for debug assertions + + if total_weight == Default::default() { + debug_assert!(items.next().is_none()); + return None; + } + + // Only used when debug_assertions are turned on + let mut debug_result = None; + let debug_total_weight = if cfg!(debug_assertions) { Some(total_weight.clone()) } else { None }; + + let chosen_weight = rng.gen_range(Default::default(), total_weight); + let mut cumulative_weight: IterWeights::Item = Default::default(); + + for item in items { + let weight_opt = weights.next(); + assert!(weight_opt.is_some(), "`weights` returned fewer items than `items` did"); + let weight = weight_opt.unwrap(); + assert!(weight >= Default::default(), "Weight must be larger than zero"); + + cumulative_weight = cumulative_weight + weight; + + if cumulative_weight > chosen_weight { + if !cfg!(debug_assertions) { + return Some(item); + } + if debug_result.is_none() { + debug_result = Some(item); + } + } + } + + assert!(weights.next().is_none(), "`weights` returned more items than `items` did"); + debug_assert!(debug_total_weight.unwrap() ==
cumulative_weight); + if cfg!(debug_assertions) && debug_result.is_some() { + return debug_result; + } + + panic!("total_weight did not match up with sum of weights"); +} + #[cfg(test)] mod test { use super::*; @@ -523,6 +651,8 @@ mod test { use {XorShiftRng, Rng, SeedableRng}; #[cfg(all(feature="alloc", not(feature="std")))] use alloc::Vec; + #[cfg(feature="std")] + use core::panic::catch_unwind; #[test] fn test_choose() { @@ -684,4 +814,108 @@ mod test { } } } + + #[test] + fn test_choose_weighted() { + let mut r = ::test::rng(406); + let chars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum(); + assert_eq!(chars.len(), weights.len()); + + let mut chosen = [0i32; 14]; + for _ in 0..1000 { + let picked = *choose_weighted(&mut r, + chars.iter(), + weights.iter().cloned()).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + + // Mutable items + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + *choose_weighted(&mut r, + chosen.iter_mut(), + weights.iter().cloned()).unwrap() += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + + // choose_weighted_with_total + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + let picked = *choose_weighted_with_total(&mut r, + chars.iter(), + weights.iter().cloned(), + total_weight).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + } + + #[test] + #[cfg(all(feature="std", + not(target_arch = "wasm32"), + not(target_arch = "asmjs")))] + fn 
test_choose_weighted_assertions() { + fn inner_delta(delta: i32) { + let items = vec![1, 2, 3]; + let mut r = ::test::rng(407); + if cfg!(debug_assertions) || delta == 0 { + choose_weighted_with_total(&mut r, + items.iter(), + items.iter().cloned(), + 6+delta); + } else { + loop { + choose_weighted_with_total(&mut r, + items.iter(), + items.iter().cloned(), + 6+delta); + } + } + } + + assert!(catch_unwind(|| inner_delta(0)).is_ok()); + assert!(catch_unwind(|| inner_delta(1)).is_err()); + assert!(catch_unwind(|| inner_delta(1000)).is_err()); + if cfg!(debug_assertions) { + // The non-debug-assertions code can't detect too small total_weight + assert!(catch_unwind(|| inner_delta(-1)).is_err()); + assert!(catch_unwind(|| inner_delta(-1000)).is_err()); + } + + fn inner_size(items: usize, weights: usize, with_total: bool) { + let mut r = ::test::rng(408); + if with_total { + choose_weighted_with_total(&mut r, + ::core::iter::repeat(1usize).take(items), + ::core::iter::repeat(1usize).take(weights), + weights); + } else { + choose_weighted(&mut r, + ::core::iter::repeat(1usize).take(items), + ::core::iter::repeat(1usize).take(weights)); + } + } + + assert!(catch_unwind(|| inner_size(2, 2, true)).is_ok()); + assert!(catch_unwind(|| inner_size(2, 2, false)).is_ok()); + assert!(catch_unwind(|| inner_size(2, 1, true)).is_err()); + assert!(catch_unwind(|| inner_size(2, 1, false)).is_err()); + if cfg!(debug_assertions) { + // The non-debug-assertions code can't detect too many weights + assert!(catch_unwind(|| inner_size(2, 3, true)).is_err()); + assert!(catch_unwind(|| inner_size(2, 3, false)).is_err()); + } + } }