diff --git a/benches/misc.rs b/benches/misc.rs index 4e9cbda37ae..93a5c506fdc 100644 --- a/benches/misc.rs +++ b/benches/misc.rs @@ -8,7 +8,6 @@ const RAND_BENCH_N: u64 = 1000; use test::Bencher; use rand::prelude::*; -use rand::seq::*; #[bench] fn misc_gen_bool_const(b: &mut Bencher) { @@ -108,59 +107,6 @@ sample_binomial!(misc_binomial_100, 100, 0.99); sample_binomial!(misc_binomial_1000, 1000, 0.01); sample_binomial!(misc_binomial_1e12, 1000_000_000_000, 0.2); -#[bench] -fn misc_shuffle_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &mut [usize] = &mut [1; 100]; - b.iter(|| { - rng.shuffle(x); - x[0] - }) -} - -#[bench] -fn misc_sample_iter_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &[usize] = &[1; 100]; - b.iter(|| { - sample_iter(&mut rng, x, 10).unwrap_or_else(|e| e) - }) -} - -#[bench] -fn misc_sample_slice_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &[usize] = &[1; 100]; - b.iter(|| { - sample_slice(&mut rng, x, 10) - }) -} - -#[bench] -fn misc_sample_slice_ref_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &[usize] = &[1; 100]; - b.iter(|| { - sample_slice_ref(&mut rng, x, 10) - }) -} - -macro_rules! 
sample_indices { - ($name:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - b.iter(|| { - sample_indices(&mut rng, $length, $amount) - }) - } - } -} - -sample_indices!(misc_sample_indices_10_of_1k, 10, 1000); -sample_indices!(misc_sample_indices_50_of_1k, 50, 1000); -sample_indices!(misc_sample_indices_100_of_1k, 100, 1000); - #[bench] fn gen_1k_iter_repeat(b: &mut Bencher) { use std::iter; diff --git a/benches/seq.rs b/benches/seq.rs new file mode 100644 index 00000000000..3e5be54375d --- /dev/null +++ b/benches/seq.rs @@ -0,0 +1,76 @@ +#![feature(test)] + +extern crate test; +extern crate rand; + +use test::Bencher; + +use rand::prelude::*; +use rand::seq::*; + +#[bench] +fn seq_shuffle_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &mut [usize] = &mut [1; 100]; + b.iter(|| { + x.shuffle(&mut rng); + x[0] + }) +} + +#[bench] +fn seq_slice_sample_10_of_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 100]; + let mut buf = [0; 10]; + b.iter(|| { + for (v, slot) in x.sample(&mut rng, buf.len()).zip(buf.iter_mut()) { + *slot = *v; + } + buf + }) +} + +#[bench] +fn seq_iter_choose_from_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 100]; + b.iter(|| { + x.iter().cloned().choose(&mut rng) + }) +} + +#[bench] +fn seq_iter_sample_10_of_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 100]; + b.iter(|| { + x.iter().cloned().sample(&mut rng, 10) /*.unwrap_or_else(|e| e)*/ + }) +} + +#[bench] +fn seq_iter_sample_fill_10_of_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 100]; + let mut buf = [0; 10]; + b.iter(|| { + x.iter().cloned().sample_fill(&mut rng, &mut buf) + }) +} + +macro_rules! 
sample_indices { + ($name:ident, $amount:expr, $length:expr) => { + #[bench] + fn $name(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + b.iter(|| { + sample_indices(&mut rng, $length, $amount) + }) + } + } +} + +sample_indices!(seq_sample_indices_10_of_1k, 10, 1000); +sample_indices!(seq_sample_indices_50_of_1k, 50, 1000); +sample_indices!(seq_sample_indices_100_of_1k, 100, 1000); diff --git a/examples/monty-hall.rs b/examples/monty-hall.rs index 3750f8fabe2..faf795cb873 100644 --- a/examples/monty-hall.rs +++ b/examples/monty-hall.rs @@ -63,8 +63,8 @@ fn simulate(random_door: &Uniform, rng: &mut R) // Returns the door the game host opens given our choice and knowledge of // where the car is. The game host will never open the door with the car. fn game_host_open(car: u32, choice: u32, rng: &mut R) -> u32 { - let choices = free_doors(&[car, choice]); - rand::seq::sample_slice(rng, &choices, 1)[0] + use rand::seq::SliceRandom; + *free_doors(&[car, choice]).choose(rng).unwrap() } // Returns the door we switch to, given our current choice and diff --git a/src/lib.rs b/src/lib.rs index 6d6d762b3cb..3639f70fc89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -268,7 +268,7 @@ pub mod distributions; pub mod prelude; pub mod prng; pub mod rngs; -#[cfg(feature = "alloc")] pub mod seq; +pub mod seq; //////////////////////////////////////////////////////////////////////////////// // Compatibility re-exports. Documentation is hidden; will be removed eventually. @@ -563,64 +563,35 @@ pub trait Rng: RngCore { /// Return a random element from `values`. /// - /// Return `None` if `values` is empty. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let choices = [1, 2, 4, 8, 16, 32]; - /// let mut rng = thread_rng(); - /// println!("{:?}", rng.choose(&choices)); - /// assert_eq!(rng.choose(&choices[..0]), None); - /// ``` + /// Deprecated: use [`SliceRandom::choose`] instead. 
+ /// + /// [`SliceRandom::choose`]: seq/trait.SliceRandom.html#method.choose + #[deprecated(since="0.6.0", note="use SliceRandom::choose instead")] fn choose<'a, T>(&mut self, values: &'a [T]) -> Option<&'a T> { - if values.is_empty() { - None - } else { - Some(&values[self.gen_range(0, values.len())]) - } + use seq::SliceRandom; + values.choose(self) } /// Return a mutable pointer to a random element from `values`. /// - /// Return `None` if `values` is empty. + /// Deprecated: use [`SliceRandom::choose_mut`] instead. + /// + /// [`SliceRandom::choose_mut`]: seq/trait.SliceRandom.html#method.choose_mut + #[deprecated(since="0.6.0", note="use SliceRandom::choose_mut instead")] fn choose_mut<'a, T>(&mut self, values: &'a mut [T]) -> Option<&'a mut T> { - if values.is_empty() { - None - } else { - let len = values.len(); - Some(&mut values[self.gen_range(0, len)]) - } + use seq::SliceRandom; + values.choose_mut(self) } /// Shuffle a mutable slice in place. /// - /// This applies Durstenfeld's algorithm for the [Fisher–Yates shuffle]( - /// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) - /// which produces an unbiased permutation. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// let mut y = [1, 2, 3]; - /// rng.shuffle(&mut y); - /// println!("{:?}", y); - /// rng.shuffle(&mut y); - /// println!("{:?}", y); - /// ``` + /// Deprecated: use [`SliceRandom::shuffle`] instead. + /// + /// [`SliceRandom::shuffle`]: seq/trait.SliceRandom.html#method.shuffle + #[deprecated(since="0.6.0", note="use SliceRandom::shuffle instead")] fn shuffle(&mut self, values: &mut [T]) { - let mut i = values.len(); - while i >= 2 { - // invariant: elements with index >= i have been locked in place. - i -= 1; - // lock element i in place. 
- values.swap(i, self.gen_range(0, i + 1)); - } + use seq::SliceRandom; + values.shuffle(self) } } @@ -973,35 +944,6 @@ mod test { } } - #[test] - fn test_choose() { - let mut r = rng(107); - assert_eq!(r.choose(&[1, 1, 1]).map(|&x|x), Some(1)); - - let v: &[isize] = &[]; - assert_eq!(r.choose(v), None); - } - - #[test] - fn test_shuffle() { - let mut r = rng(108); - let empty: &mut [isize] = &mut []; - r.shuffle(empty); - let mut one = [1]; - r.shuffle(&mut one); - let b: &[_] = &[1]; - assert_eq!(one, b); - - let mut two = [1, 2]; - r.shuffle(&mut two); - assert!(two == [1, 2] || two == [2, 1]); - - let mut x = [1, 1, 1]; - r.shuffle(&mut x); - let b: &[_] = &[1, 1, 1]; - assert_eq!(x, b); - } - #[test] fn test_rng_trait_object() { use distributions::{Distribution, Standard}; @@ -1009,10 +951,6 @@ mod test { let mut r = &mut rng as &mut RngCore; r.next_u32(); r.gen::(); - let mut v = [1, 1, 1]; - r.shuffle(&mut v); - let b: &[_] = &[1, 1, 1]; - assert_eq!(v, b); assert_eq!(r.gen_range(0, 1), 0); let _c: u8 = Standard.sample(&mut r); } @@ -1025,10 +963,6 @@ mod test { let mut r = Box::new(rng) as Box; r.next_u32(); r.gen::(); - let mut v = [1, 1, 1]; - r.shuffle(&mut v); - let b: &[_] = &[1, 1, 1]; - assert_eq!(v, b); assert_eq!(r.gen_range(0, 1), 0); let _c: u8 = Standard.sample(&mut r); } diff --git a/src/prelude.rs b/src/prelude.rs index 358c2370823..ace5c89d700 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -26,3 +26,4 @@ #[doc(no_inline)] #[cfg(feature="std")] pub use rngs::ThreadRng; #[doc(no_inline)] pub use {Rng, RngCore, CryptoRng, SeedableRng}; #[doc(no_inline)] #[cfg(feature="std")] pub use {FromEntropy, random, thread_rng}; +#[doc(no_inline)] pub use seq::{SliceRandom, IteratorRandom}; diff --git a/src/rngs/thread.rs b/src/rngs/thread.rs index 863b79d31c7..c6611db92ff 100644 --- a/src/rngs/thread.rs +++ b/src/rngs/thread.rs @@ -132,10 +132,6 @@ mod test { use Rng; let mut r = ::thread_rng(); r.gen::(); - let mut v = [1, 1, 1]; - r.shuffle(&mut v); 
- let b: &[_] = &[1, 1, 1]; - assert_eq!(v, b); assert_eq!(r.gen_range(0, 1), 0); } } diff --git a/src/seq.rs b/src/seq.rs index 68f7ab08edc..1a3a448d0d3 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -9,61 +9,360 @@ // except according to those terms. //! Functions for randomly accessing and sampling sequences. +//! +//! TODO: module doc -use super::Rng; +#[cfg(feature="alloc")] use core::ops::Index; + +#[cfg(feature="std")] use std::vec; +#[cfg(all(feature="alloc", not(feature="std")))] use alloc::{vec, Vec}; -// This crate is only enabled when either std or alloc is available. // BTreeMap is not as fast in tests, but better than nothing. #[cfg(feature="std")] use std::collections::HashMap; -#[cfg(not(feature="std"))] use alloc::btree_map::BTreeMap; +#[cfg(all(feature="alloc", not(feature="std")))] use alloc::btree_map::BTreeMap; + +use super::Rng; +use distributions::uniform::SampleUniform; + +/// Extension trait on slices, providing random mutation and sampling methods. +/// +/// An implementation is provided for slices. This may also be implementable for +/// other types. +pub trait SliceRandom { + /// The element type. + type Item; + + /// Returns a reference to one random element of the slice, or `None` if the + /// slice is empty. + /// + /// Depending on the implementation, complexity is expected to be `O(1)`. + /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::seq::SliceRandom; + /// + /// let choices = [1, 2, 4, 8, 16, 32]; + /// let mut rng = thread_rng(); + /// println!("{:?}", choices.choose(&mut rng)); + /// assert_eq!(choices[..0].choose(&mut rng), None); + /// ``` + fn choose(&self, rng: &mut R) -> Option<&Self::Item> + where R: Rng + ?Sized; + + /// Returns a mutable reference to one random element of the slice, or + /// `None` if the slice is empty. + /// + /// Depending on the implementation, complexity is expected to be `O(1)`. 
+ fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> + where R: Rng + ?Sized; + + /// Produces an iterator that chooses `amount` elements from the slice at + /// random without repeating any. + /// + /// In case this API is not sufficiently flexible, use `sample_indices` then + /// apply the indices to the slice. + /// + /// Although the elements are selected randomly, the order of returned + /// elements is neither stable nor fully random. If random ordering is + /// desired, either use `partial_shuffle` or use this method and shuffle + /// the result. If stable order is desired, use `sample_indices`, sort the + /// result, then apply to the slice. + /// + /// Complexity is expected to be the same as `sample_indices`. + /// + /// # Example + /// ``` + /// use rand::seq::SliceRandom; + /// + /// let mut rng = &mut rand::thread_rng(); + /// let sequence = "Hello, audience!".as_bytes(); + /// + /// // collect the results into a vector: + /// let v: Vec<u8> = sequence.sample(&mut rng, 3).cloned().collect(); + /// + /// // store in a buffer: + /// let mut buf = [0u8; 5]; + /// for (b, slot) in sequence.sample(&mut rng, buf.len()).zip(buf.iter_mut()) { + /// *slot = *b; + /// } + /// ``` + #[cfg(feature = "alloc")] + fn sample(&self, rng: &mut R, amount: usize) -> SliceChooseIter + where R: Rng + ?Sized; + + /// Shuffle a mutable slice in place. + /// + /// Depending on the implementation, complexity is expected to be `O(n)`. + /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::seq::SliceRandom; + /// + /// let mut rng = thread_rng(); + /// let mut y = [1, 2, 3, 4, 5]; + /// println!("Unshuffled: {:?}", y); + /// y.shuffle(&mut rng); + /// println!("Shuffled: {:?}", y); + /// ``` + fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized; + + /// Shuffle a slice in place, but exit early. + /// + /// Returns two mutable slices from the source slice. The first contains + /// `amount` elements randomly permuted. 
The second has the remaining + /// elements that are not fully shuffled. + /// + /// This is an efficient method to select `amount` elements at random from + /// the slice, provided the slice may be mutated. + /// + /// If you only need to choose elements randomly and `amount > self.len()/2` + /// then you may improve performance by taking + /// `amount = self.len() - amount` and using only the second slice. + /// + /// If `amount` is greater than the number of elements in the slice, this + /// will perform a full shuffle. + /// + /// Depending on the implementation, complexity is expected to be `O(m)`, + /// where `m = amount`. + fn partial_shuffle(&mut self, rng: &mut R, amount: usize) + -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized; +} + +/// Extension trait on iterators, providing random sampling methods. +pub trait IteratorRandom: Iterator + Sized { + /// Choose one element at random from the iterator. + /// + /// Returns `None` if and only if the iterator is empty. + /// + /// Complexity is `O(n)`, where `n` is the length of the iterator. + /// This likely consumes multiple random numbers, but the exact number + /// is unspecified. + fn choose(mut self, rng: &mut R) -> Option + where R: Rng + ?Sized + { + if let Some(elem) = self.next() { + let mut result = elem; + + // Continue until the iterator is exhausted + for (i, elem) in self.enumerate() { + let denom = (i + 2) as u32; + assert_eq!(denom as usize, i + 2); // check against overflow + if rng.gen_ratio(1, denom) { + result = elem; + } + } + + Some(result) + } else { + None + } + } + + /// Collects `amount` values at random from the iterator into a supplied + /// buffer. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// Returns the number of elements added to the buffer. 
This equals `amount` + /// unless the iterator contains insufficient elements, in which case this + /// equals the number of elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + fn sample_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) + -> usize where R: Rng + ?Sized + { + let amount = buf.len(); + let mut len = 0; + while len < amount { + if let Some(elem) = self.next() { + buf[len] = elem; + len += 1; + } else { + // Iterator exhausted; stop early + return len; + } + } + + // Continue, since the iterator was not exhausted + for (i, elem) in self.enumerate() { + let k = rng.gen_range(0, i + 1 + amount); + if k < amount { + buf[k] = elem; + } + } + len + } + + /// Collects `amount` values at random from the iterator into a vector. + /// + /// This is equivalent to `sample_fill` except for the result type. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// The length of the returned vector equals `amount` unless the iterator + /// contains insufficient elements, in which case it equals the number of + /// elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + #[cfg(feature = "alloc")] + fn sample(mut self, rng: &mut R, amount: usize) -> Vec + where R: Rng + ?Sized + { + let mut reservoir = Vec::with_capacity(amount); + reservoir.extend(self.by_ref().take(amount)); + + // Continue unless the iterator was exhausted + // + // note: this prevents iterators that "restart" from causing problems. + // If the iterator stops once, then so do we. + if reservoir.len() == amount { + for (i, elem) in self.enumerate() { + let k = rng.gen_range(0, i + 1 + amount); + if let Some(spot) = reservoir.get_mut(k) { + *spot = elem; + } + } + } else { + // Don't hang onto extra memory. 
There is a corner case where + // `amount` was much greater than `self.len()`. + reservoir.shrink_to_fit(); + } + reservoir + } +} + + +impl SliceRandom for [T] { + type Item = T; + + fn choose(&self, rng: &mut R) -> Option<&Self::Item> + where R: Rng + ?Sized + { + if self.is_empty() { + None + } else { + Some(&self[rng.gen_range(0, self.len())]) + } + } + + fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> + where R: Rng + ?Sized + { + if self.is_empty() { + None + } else { + let len = self.len(); + Some(&mut self[rng.gen_range(0, len)]) + } + } + + #[cfg(feature = "alloc")] + fn sample(&self, rng: &mut R, amount: usize) -> SliceChooseIter + where R: Rng + ?Sized + { + let amount = ::core::cmp::min(amount, self.len()); + SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: sample_indices(rng, self.len(), amount).into_iter(), + } + } + + fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized + { + for i in (1..self.len()).rev() { + // invariant: elements with index > i have been locked in place. + self.swap(i, rng.gen_range(0, i + 1)); + } + } + + fn partial_shuffle(&mut self, rng: &mut R, amount: usize) + -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized + { + // This applies Durstenfeld's algorithm for the + // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) + // for an unbiased permutation, but exits early after choosing `amount` + // elements. + + let len = self.len(); + let end = if amount >= len { 0 } else { len - amount }; + + for i in (end..len).rev() { + // invariant: elements with index > i have been locked in place. + self.swap(i, rng.gen_range(0, i + 1)); + } + let r = self.split_at_mut(end); + (r.1, r.0) + } +} + +impl IteratorRandom for I where I: Iterator + Sized {} + + +/// Iterator over multiple choices, as returned by [`SliceRandom::sample`]( +/// trait.SliceRandom.html#method.sample). 
+#[cfg(feature = "alloc")] +#[derive(Debug)] +pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { + slice: &'a S, + _phantom: ::core::marker::PhantomData, + indices: vec::IntoIter, +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { + type Item = &'a T; -#[cfg(not(feature="std"))] use alloc::Vec; + fn next(&mut self) -> Option { + // TODO: get_unchecked? + self.indices.next().map(|i| &(*self.slice)[i]) + } + + fn size_hint(&self) -> (usize, Option) { + (self.indices.len(), Some(self.indices.len())) + } +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator + for SliceChooseIter<'a, S, T> +{ + fn len(&self) -> usize { + self.indices.len() + } +} + + +// ——— +// TODO: also revise signature of `sample_indices`? +// ——— /// Randomly sample `amount` elements from a finite iterator. /// -/// The following can be returned: -/// -/// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random. -/// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the -/// length of `iterable` was less than `amount`. This is considered an error since exactly -/// `amount` elements is typically expected. -/// -/// This implementation uses `O(len(iterable))` time and `O(amount)` memory. -/// -/// # Example -/// -/// ``` -/// use rand::{thread_rng, seq}; -/// -/// let mut rng = thread_rng(); -/// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap(); -/// println!("{:?}", sample); -/// ``` +/// Deprecated: use [`IteratorRandom::sample`] instead. 
+/// +/// [`IteratorRandom::sample`]: trait.IteratorRandom.html#method.sample +#[cfg(feature = "alloc")] +#[deprecated(since="0.6.0", note="use IteratorRandom::sample instead")] pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result, Vec> where I: IntoIterator, R: Rng + ?Sized, { - let mut iter = iterable.into_iter(); - let mut reservoir = Vec::with_capacity(amount); - reservoir.extend(iter.by_ref().take(amount)); - - // Continue unless the iterator was exhausted - // - // note: this prevents iterators that "restart" from causing problems. - // If the iterator stops once, then so do we. - if reservoir.len() == amount { - for (i, elem) in iter.enumerate() { - let k = rng.gen_range(0, i + 1 + amount); - if let Some(spot) = reservoir.get_mut(k) { - *spot = elem; - } - } - Ok(reservoir) + use seq::IteratorRandom; + let iter = iterable.into_iter(); + let result = iter.sample(rng, amount); + if result.len() == amount { + Ok(result) } else { - // Don't hang onto extra memory. There is a corner case where - // `amount` was much less than `len(iterable)`. - reservoir.shrink_to_fit(); - Err(reservoir) + Err(result) } } @@ -75,15 +374,11 @@ pub fn sample_iter(rng: &mut R, iterable: I, amount: usize) -> Result slice.len()` /// -/// # Example -/// -/// ``` -/// use rand::{thread_rng, seq}; -/// -/// let mut rng = thread_rng(); -/// let values = vec![5, 6, 1, 3, 4, 6, 7]; -/// println!("{:?}", seq::sample_slice(&mut rng, &values, 3)); -/// ``` +/// Deprecated: use [`SliceRandom::sample`] instead. 
+/// +/// [`SliceRandom::sample`]: trait.SliceRandom.html#method.sample +#[cfg(feature = "alloc")] +#[deprecated(since="0.6.0", note="use SliceRandom::sample instead")] pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec where R: Rng + ?Sized, T: Clone @@ -103,15 +398,11 @@ pub fn sample_slice(rng: &mut R, slice: &[T], amount: usize) -> Vec /// /// Panics if `amount > slice.len()` /// -/// # Example -/// -/// ``` -/// use rand::{thread_rng, seq}; -/// -/// let mut rng = thread_rng(); -/// let values = vec![5, 6, 1, 3, 4, 6, 7]; -/// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3)); -/// ``` +/// Deprecated: use [`SliceRandom::sample`] instead. +/// +/// [`SliceRandom::sample`]: trait.SliceRandom.html#method.sample +#[cfg(feature = "alloc")] +#[deprecated(since="0.6.0", note="use SliceRandom::sample instead")] pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng + ?Sized { @@ -132,6 +423,7 @@ pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> /// have the indices themselves so this is provided as an alternative. /// /// Panics if `amount > length` +#[cfg(feature = "alloc")] pub fn sample_indices(rng: &mut R, length: usize, amount: usize) -> Vec where R: Rng + ?Sized, { @@ -165,6 +457,7 @@ pub fn sample_indices(rng: &mut R, length: usize, amount: usize) -> Vec= length / 2` /// since it does not require allocating an extra cache and is much faster. +#[cfg(feature = "alloc")] fn sample_indices_inplace(rng: &mut R, length: usize, amount: usize) -> Vec where R: Rng + ?Sized, { @@ -186,6 +479,7 @@ fn sample_indices_inplace(rng: &mut R, length: usize, amount: usize) -> Vec( rng: &mut R, length: usize, @@ -222,22 +516,206 @@ fn sample_indices_cache( out } +/// Return a random element from `items` where the chance of a given item +/// being picked is proportional to the corresponding value in `weights`. 
+/// `weights` and `items` must return exactly the same number of values. +/// +/// All values returned by `weights` must be `>= 0`. +/// +/// This function iterates over `weights` twice. Once to get the total +/// weight, and once while choosing the random value. If you know the total +/// weight, or plan to call this function multiple times, you should +/// consider using [`choose_weighted_with_total`] instead. +/// +/// Return `None` if `items` and `weights` are empty. +/// +/// # Panics +/// +/// If a value in `weights` is `< 0`. +/// If `weights` and `items` return a different number of items. +/// +/// # Example +/// +/// ``` +/// use rand::thread_rng; +/// use rand::seq::choose_weighted; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let mut rng = thread_rng(); +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choose_weighted(&mut rng, choices.iter(), weights.iter().cloned()).unwrap()); +/// ``` +/// [`choose_weighted_with_total`]: fn.choose_weighted_with_total.html +pub fn choose_weighted(rng: &mut R, + items: IterItems, + weights: IterWeights) -> Option + where R: Rng + ?Sized, + IterItems: Iterator, + IterWeights: Iterator + Clone, + IterWeights::Item: SampleUniform + + Default + + ::core::ops::Add + + ::core::cmp::PartialOrd + + Clone { // Clone is only needed for debug assertions + let total_weight: IterWeights::Item = + weights.clone().fold(Default::default(), |acc, w| { + assert!(w >= Default::default(), "Weight must be larger than zero"); + acc + w + }); + choose_weighted_with_total(rng, items, weights, total_weight) +} + +/// Return a random element from `items` where the chance of a given item +/// being picked is proportional to the corresponding value in `weights`. +/// `weights` and `items` must return exactly the same number of values. +/// +/// All values returned by `weights` must be `>= 0`. 
+/// +/// `total_weight` must be exactly the sum of all values returned by +/// `weights`. Builds with debug_assertions turned on will assert that this +/// equality holds. Simply storing the result of `weights.sum()` and using +/// that as `total_weight` should work. +/// +/// Return `None` if `items` and `weights` is empty. +/// +/// # Example +/// +/// ``` +/// use rand::thread_rng; +/// use rand::seq::choose_weighted_with_total; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let mut rng = thread_rng(); +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choose_weighted_with_total(&mut rng, choices.iter(), weights.iter().cloned(), 4).unwrap()); +/// ``` +pub fn choose_weighted_with_total(rng: &mut R, + mut items: IterItems, + mut weights: IterWeights, + total_weight: IterWeights::Item) -> Option + where R: Rng + ?Sized, + IterItems: Iterator, + IterWeights: Iterator, + IterWeights::Item: SampleUniform + + Default + + ::core::ops::Add + + ::core::cmp::PartialOrd + + Clone { // Clone is only needed for debug assertions + + if total_weight == Default::default() { + debug_assert!(items.next().is_none()); + return None; + } + + // Only used when debug_assertions are turned on + let mut debug_result = None; + let debug_total_weight = if cfg!(debug_assertions) { Some(total_weight.clone()) } else { None }; + + let chosen_weight = rng.gen_range(Default::default(), total_weight); + let mut cumulative_weight: IterWeights::Item = Default::default(); + + for item in items { + let weight_opt = weights.next(); + assert!(weight_opt.is_some(), "`weights` returned fewer items than `items` did"); + let weight = weight_opt.unwrap(); + assert!(weight >= Default::default(), "Weight must be larger than zero"); + + cumulative_weight = cumulative_weight + weight; + + if cumulative_weight > chosen_weight { + if !cfg!(debug_assertions) { + return Some(item); + } + if debug_result.is_none() { + debug_result = 
Some(item); + } + } + } + + assert!(weights.next().is_none(), "`weights` returned more items than `items` did"); + debug_assert!(debug_total_weight.unwrap() == cumulative_weight); + if cfg!(debug_assertions) && debug_result.is_some() { + return debug_result; + } + + panic!("total_weight did not match up with sum of weights"); +} + #[cfg(test)] mod test { use super::*; + use super::IteratorRandom; + #[cfg(feature = "alloc")] use {XorShiftRng, Rng, SeedableRng}; - #[cfg(not(feature="std"))] + #[cfg(all(feature="alloc", not(feature="std")))] use alloc::Vec; + #[cfg(feature="std")] + use core::panic::catch_unwind; + + #[test] + fn test_choose() { + let mut r = ::test::rng(107); + assert_eq!([1, 1, 1].choose(&mut r), Some(&1)); + + let mut v = [2]; + v.choose_mut(&mut r).map(|x| *x = 5); + assert_eq!(v[0], 5); + + let v = [3, 3, 3, 3]; + assert_eq!(v.iter().choose(&mut r), Some(&3)); + + let v: &[isize] = &[]; + assert_eq!(v.choose(&mut r), None); + } #[test] + fn test_shuffle() { + let mut r = ::test::rng(108); + let empty: &mut [isize] = &mut []; + empty.shuffle(&mut r); + let mut one = [1]; + one.shuffle(&mut r); + let b: &[_] = &[1]; + assert_eq!(one, b); + + let mut two = [1, 2]; + two.shuffle(&mut r); + assert!(two == [1, 2] || two == [2, 1]); + + let mut x = [1, 1, 1]; + x.shuffle(&mut r); + let b: &[_] = &[1, 1, 1]; + assert_eq!(x, b); + } + + #[test] + fn test_partial_shuffle() { + let mut r = ::test::rng(118); + + let mut empty: [u32; 0] = []; + let res = empty.partial_shuffle(&mut r, 10); + assert_eq!((res.0.len(), res.1.len()), (0, 0)); + + let mut v = [1, 2, 3, 4, 5]; + let res = v.partial_shuffle(&mut r, 2); + assert_eq!((res.0.len(), res.1.len()), (2, 3)); + assert!(res.0[0] != res.0[1]); + // First elements are only modified if selected, so at least one isn't modified: + assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); + } + + #[test] + #[cfg(feature = "alloc")] fn test_sample_iter() { let min_val = 1; let max_val = 100; let mut r = 
::test::rng(401); let vals = (min_val..max_val).collect::>(); - let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap(); - let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err(); + let small_sample = vals.iter().sample(&mut r, 5); + let large_sample = vals.iter().sample(&mut r, vals.len() + 5); assert_eq!(small_sample.len(), 5); assert_eq!(large_sample.len(), vals.len()); @@ -249,6 +727,8 @@ mod test { })); } #[test] + #[cfg(feature = "alloc")] + #[allow(deprecated)] fn test_sample_slice_boundaries() { let empty: &[u8] = &[]; @@ -293,6 +773,8 @@ mod test { } #[test] + #[cfg(feature = "alloc")] + #[allow(deprecated)] fn test_sample_slice() { let xor_rng = XorShiftRng::from_seed; @@ -332,4 +814,108 @@ mod test { } } } + + #[test] + fn test_choose_weighted() { + let mut r = ::test::rng(406); + let chars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum(); + assert_eq!(chars.len(), weights.len()); + + let mut chosen = [0i32; 14]; + for _ in 0..1000 { + let picked = *choose_weighted(&mut r, + chars.iter(), + weights.iter().cloned()).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + + // Mutable items + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + *choose_weighted(&mut r, + chosen.iter_mut(), + weights.iter().cloned()).unwrap() += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + + // choose_weighted_with_total + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + let picked = *choose_weighted_with_total(&mut r, + chars.iter(), + weights.iter().cloned(), + total_weight).unwrap(); + chosen[(picked as usize) 
- ('a' as usize)] += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + } + + #[test] + #[cfg(all(feature="std", + not(target_arch = "wasm32"), + not(target_arch = "asmjs")))] + fn test_choose_weighted_assertions() { + fn inner_delta(delta: i32) { + let items = vec![1, 2, 3]; + let mut r = ::test::rng(407); + if cfg!(debug_assertions) || delta == 0 { + choose_weighted_with_total(&mut r, + items.iter(), + items.iter().cloned(), + 6+delta); + } else { + loop { + choose_weighted_with_total(&mut r, + items.iter(), + items.iter().cloned(), + 6+delta); + } + } + } + + assert!(catch_unwind(|| inner_delta(0)).is_ok()); + assert!(catch_unwind(|| inner_delta(1)).is_err()); + assert!(catch_unwind(|| inner_delta(1000)).is_err()); + if cfg!(debug_assertions) { + // The non-debug-assertions code can't detect too small total_weight + assert!(catch_unwind(|| inner_delta(-1)).is_err()); + assert!(catch_unwind(|| inner_delta(-1000)).is_err()); + } + + fn inner_size(items: usize, weights: usize, with_total: bool) { + let mut r = ::test::rng(408); + if with_total { + choose_weighted_with_total(&mut r, + ::core::iter::repeat(1usize).take(items), + ::core::iter::repeat(1usize).take(weights), + weights); + } else { + choose_weighted(&mut r, + ::core::iter::repeat(1usize).take(items), + ::core::iter::repeat(1usize).take(weights)); + } + } + + assert!(catch_unwind(|| inner_size(2, 2, true)).is_ok()); + assert!(catch_unwind(|| inner_size(2, 2, false)).is_ok()); + assert!(catch_unwind(|| inner_size(2, 1, true)).is_err()); + assert!(catch_unwind(|| inner_size(2, 1, false)).is_err()); + if cfg!(debug_assertions) { + // The non-debug-assertions code can't detect too many weights + assert!(catch_unwind(|| inner_size(2, 3, true)).is_err()); + assert!(catch_unwind(|| inner_size(2, 3, false)).is_err()); + } + } }