diff --git a/src/lib.rs b/src/lib.rs index 6d6d762b3cb..16be7518521 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -269,6 +269,7 @@ pub mod prelude; pub mod prng; pub mod rngs; #[cfg(feature = "alloc")] pub mod seq; +#[cfg(feature = "alloc")] pub use seq::{SliceRandom, IteratorRandom}; //////////////////////////////////////////////////////////////////////////////// // Compatibility re-exports. Documentation is hidden; will be removed eventually. @@ -595,6 +596,161 @@ pub trait Rng: RngCore { } } + /// Returns one random element of the `Iterator`, or `None` if the + /// `Iterator` returns no items. If you have a slice, it's significantly + /// faster to call the [`choose`] or [`choose_mut`] functions using the + /// slice instead. However it expected to be faster than dumping the + /// Iterator into a slice and then calling [`choose`]/[`choose_mut`] on + /// the slice. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let choices = std::iter::repeat(0) + /// .scan((1, 1), |state, _| { let (a, b) = *state; *state = (b, a+b); Some(a) }) + /// .take(40); + /// let mut rng = thread_rng(); + /// // Randomly choose one of the first 40 fibonacci numbers + /// println!("{}", rng.choose_from_iterator(choices).unwrap()); + /// assert_eq!(rng.choose_from_iterator(std::iter::empty::()), None); + /// ``` + /// [`choose`]: trait.Rng.html#method.choose + /// [`choose_mut`]: trait.Rng.html#method.choose_mut + fn choose_from_iterator(&mut self, mut iterable: I) -> Option { + let mut val = iterable.next(); + if val.is_none() { + return val; + } + + for (i, elem) in iterable.enumerate() { + if self.gen_range(0, i + 2) == 0 { + val = Some(elem); + } + } + val + } + + /// Return a random element from `items` where. The chance of a given item + /// being picked, is proportional to the corresponding value in `weights`. + /// `weights` and `items` must return exactly the same number of values. + /// + /// All values returned by `weights` must be `>= 0`. + /// + /// This function iterates over `weights` twice. Once to get the total + /// weight, and once while choosing the random value. If you know the total + /// weight, or plan to call this function multiple times, you should + /// consider using [`choose_weighted_with_total`] instead. + /// + /// Return `None` if `items` and `weights` is empty. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let choices = ['a', 'b', 'c']; + /// let weights = [2, 1, 1]; + /// let mut rng = thread_rng(); + /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' + /// println!("{}", rng.choose_weighted(choices.iter(), weights.iter().cloned()).unwrap()); + /// ``` + /// [`choose_weighted_with_total`]: trait.Rng.html#method.choose_weighted_with_total + fn choose_weighted(&mut self, + items: IterItems, + weights: IterWeights) -> Option + where IterItems: Iterator, + IterWeights: Iterator+Clone, + IterWeights::Item: SampleUniform + + Default + + core::ops::Add + + core::cmp::PartialOrd + + Clone { // Clone is only needed for debug assertions + let total_weight: IterWeights::Item = + weights.clone().fold(Default::default(), |acc, w| { + assert!(w >= Default::default(), "Weight must be larger than zero"); + acc + w + }); + self.choose_weighted_with_total(items, weights, total_weight) + } + + /// Return a random element from `items` where. The chance of a given item + /// being picked, is proportional to the corresponding value in `weights`. + /// `weights` and `items` must return exactly the same number of values. + /// + /// All values returned by `weights` must be `>= 0`. + /// + /// `total_weight` must be exactly the sum of all values returned by + /// `weights`. Builds with debug_assertions turned on will assert that this + /// equality holds. Simply storing the result of `weights.sum()` and using + /// that as `total_weight` should work. + /// + /// Return `None` if `items` and `weights` is empty. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let choices = ['a', 'b', 'c']; + /// let weights = [2, 1, 1]; + /// let mut rng = thread_rng(); + /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' + /// println!("{}", rng.choose_weighted_with_total(choices.iter(), weights.iter().cloned(), 4).unwrap()); + /// ``` + /// [`choose_weighted_with_total`]: trait.Rng.html#method.choose_weighted_with_total + fn choose_weighted_with_total(&mut self, + mut items: IterItems, + mut weights: IterWeights, + total_weight: IterWeights::Item) -> Option + where IterItems: Iterator, + IterWeights: Iterator, + IterWeights::Item: SampleUniform + + Default + + core::ops::Add + + core::cmp::PartialOrd + + Clone { // Clone is only needed for debug assertions + + if total_weight == Default::default() { + debug_assert!(items.next().is_none()); + return None; + } + + // Only used when debug_assertions are turned on + let mut debug_result = None; + let debug_total_weight = if cfg!(debug_assertions) { Some(total_weight.clone()) } else { None }; + + let chosen_weight = self.gen_range(Default::default(), total_weight); + let mut cumulative_weight: IterWeights::Item = Default::default(); + + for item in items { + let weight_opt = weights.next(); + assert!(weight_opt.is_some(), "`weights` returned fewer items than `items` did"); + let weight = weight_opt.unwrap(); + assert!(weight >= Default::default(), "Weight must be larger than zero"); + + cumulative_weight = cumulative_weight + weight; + + if cumulative_weight > chosen_weight { + if !cfg!(debug_assertions) { + return Some(item); + } + if debug_result.is_none() { + debug_result = Some(item); + } + } + } + + assert!(weights.next().is_none(), "`weights` returned more items than `items` did"); + debug_assert!(debug_total_weight.unwrap() == cumulative_weight); + if cfg!(debug_assertions) && debug_result.is_some() { + return debug_result; + } + + panic!("total_weight did not match up with sum of weights"); + } + /// Shuffle a mutable slice in place. /// /// This applies Durstenfeld's algorithm for the [Fisher–Yates shuffle]( @@ -846,6 +1002,7 @@ pub fn random() -> T where Standard: Distribution { #[cfg(test)] mod test { use rngs::mock::StepRng; + #[cfg(feature="std")] use core::panic::catch_unwind; use super::*; #[cfg(all(not(feature="std"), feature="alloc"))] use alloc::boxed::Box; @@ -976,15 +1133,50 @@ mod test { #[test] fn test_choose() { let mut r = rng(107); - assert_eq!(r.choose(&[1, 1, 1]).map(|&x|x), Some(1)); + let chars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']; + let mut chosen = [0i32; 14]; + for _ in 0..1000 { + let picked = *r.choose(&chars).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for count in chosen.iter() { + let err = *count - (1000 / (chars.len() as i32)); + assert!(-20 <= err && err <= 20); + } - let v: &[isize] = &[]; - assert_eq!(r.choose(v), None); + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + *r.choose_mut(&mut chosen).unwrap() += 1; + } + for count in chosen.iter() { + let err = *count - (1000 / (chosen.len() as i32)); + assert!(-20 <= err && err <= 20); + } + + let mut v: [isize; 0] = []; + assert_eq!(r.choose(&v), None); + assert_eq!(r.choose_mut(&mut v), None); } #[test] - fn test_shuffle() { + fn test_choose_from_iterator() { let mut r = rng(108); + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = r.choose_from_iterator(0..9).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + let err = *count - 1000 / 9; + assert!(-25 <= err && err <= 25); + } + + assert_eq!(r.choose_from_iterator(0..0), None); + } + + #[test] + fn test_shuffle() { + let mut r = rng(109); let empty: &mut [isize] = &mut []; r.shuffle(empty); let mut one = [1]; @@ -1005,7 +1197,7 @@ mod test { #[test] fn test_rng_trait_object() { use distributions::{Distribution, Standard}; - let mut rng = rng(109); + let mut rng = rng(110); let mut r = &mut rng as &mut RngCore; r.next_u32(); r.gen::(); @@ -1021,7 +1213,7 @@ mod test { #[cfg(feature="alloc")] fn test_rng_boxed_trait() { use distributions::{Distribution, Standard}; - let rng = rng(110); + let rng = rng(111); let mut r = Box::new(rng) as Box; r.next_u32(); r.gen::(); @@ -1049,6 +1241,7 @@ mod test { } #[test] +<<<<<<< HEAD fn test_gen_ratio_average() { const NUM: u32 = 3; const DENOM: u32 = 10; @@ -1063,5 +1256,101 @@ mod test { } let avg = (sum as f64) / (N as f64); assert!((avg - (NUM as f64)/(DENOM as f64)).abs() < 1e-3); +======= + fn test_choose_weighted() { + let mut r = rng(112); + let chars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum(); + assert_eq!(chars.len(), weights.len()); + + let mut chosen = [0i32; 14]; + for _ in 0..1000 { + let picked = *r.choose_weighted(chars.iter(), + weights.iter().cloned()).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + + // Mutable items + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + *r.choose_weighted(chosen.iter_mut(), + weights.iter().cloned()).unwrap() += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + + // choose_weighted_with_total + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + let picked = *r.choose_weighted_with_total(chars.iter(), + weights.iter().cloned(), + total_weight).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for (i, count) in chosen.iter().enumerate() { + let err = *count - ((weights[i] * 1000 / total_weight) as i32); + assert!(-25 <= err && err <= 25); + } + } + + #[test] + #[cfg(all(feature="std", + not(target_arch = "wasm32"), + not(target_arch = "asmjs")))] + fn test_choose_weighted_assertions() { + fn inner_delta(delta: i32) { + let items = vec![1, 2, 3]; + let mut r = rng(113); + if cfg!(debug_assertions) || delta == 0 { + r.choose_weighted_with_total(items.iter(), + items.iter().cloned(), + 6+delta); + } else { + loop { + r.choose_weighted_with_total(items.iter(), + items.iter().cloned(), + 6+delta); + } + } + } + + assert!(catch_unwind(|| inner_delta(0)).is_ok()); + assert!(catch_unwind(|| inner_delta(1)).is_err()); + assert!(catch_unwind(|| inner_delta(1000)).is_err()); + if cfg!(debug_assertions) { + // The non-debug-assertions code can't detect too small total_weight + assert!(catch_unwind(|| inner_delta(-1)).is_err()); + assert!(catch_unwind(|| inner_delta(-1000)).is_err()); + } + + fn inner_size(items: usize, weights: usize, with_total: bool) { + let mut r = rng(114); + if with_total { + r.choose_weighted_with_total(core::iter::repeat(1usize).take(items), + core::iter::repeat(1usize).take(weights), + weights); + } else { + r.choose_weighted(core::iter::repeat(1usize).take(items), + core::iter::repeat(1usize).take(weights)); + } + } + + assert!(catch_unwind(|| inner_size(2, 2, true)).is_ok()); + assert!(catch_unwind(|| inner_size(2, 2, false)).is_ok()); + assert!(catch_unwind(|| inner_size(2, 1, true)).is_err()); + assert!(catch_unwind(|| inner_size(2, 1, false)).is_err()); + if cfg!(debug_assertions) { + // The non-debug-assertions code can't detect too many weights + assert!(catch_unwind(|| inner_size(2, 3, true)).is_err()); + assert!(catch_unwind(|| inner_size(2, 3, false)).is_err()); + } +>>>>>>> Implement choose/choose_mut/choose_from_iterator on both Rng and on slice/Iterator } } diff --git a/src/seq.rs b/src/seq.rs index 68f7ab08edc..2d3cedae09d 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -19,6 +19,60 @@ use super::Rng; #[cfg(not(feature="std"))] use alloc::Vec; +/// Trait to provide convenience function directly on slices rather than +/// through [`Rng`]. Allows for example calling `vec.choose(&mut rng)` rather +/// than `rng.choose(&vec)`. +/// [`Rng`]: trait.Rng.html +pub trait SliceRandom { + #[doc(hidden)] + type Item; + + /// Same as [`Rng.choose`]. + /// [`Rng.choose`]: trait.Rng.html#method.choose + fn choose(&self, rng: &mut R) -> Option<&Self::Item> + where R: Rng + ?Sized; + + /// Same as [`Rng.choose_mut`]. + /// [`Rng.choose_mut`]: trait.Rng.html#method.choose_mut + fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> + where R: Rng + ?Sized; +} + +impl SliceRandom for [T] { + type Item = T; + + fn choose<'a, R>(&'a self, rng: &mut R) -> Option<&'a Self::Item> + where R: Rng + ?Sized + { + rng.choose(self) + } + + fn choose_mut<'a, R>(&'a mut self, rng: &mut R) -> Option<&'a mut Self::Item> + where R: Rng + ?Sized + { + rng.choose_mut(self) + } +} + +/// XXX Any way to avoid '+ Sized' here? +/// Trait to provide convenience function directly on [`Iterator`]s rather +/// than through [`Rng`]. Allows for example calling `iter.choose(&mut rng)` +/// rather than `rng.choose_from_iterator(iter)`. +/// [`Rng`]: trait.Rng.html +/// ['Iterator']: trait.Iterator.html +pub trait IteratorRandom: Iterator + Sized { + /// Same as [`Rng.choose_from_iterator`]. + /// [`Rng.choose_from_iterator`]: trait.Rng.html#method.choose_from_iterator + fn choose(self, rng: &mut R) -> Option + where R: Rng + ?Sized + { + rng.choose_from_iterator(self) + } +} + +impl IteratorRandom for I where I: Iterator + Sized {} + + /// Randomly sample `amount` elements from a finite iterator. /// /// The following can be returned: @@ -332,4 +386,51 @@ mod test { } } } + + #[test] + fn test_choose() { + let mut r = ::test::rng(404); + let chars = "abcdefghijklmn".chars().collect::>(); + let mut chosen = Vec::new(); + chosen.resize(chars.len(), 0i32); + for _ in 0..1000 { + let picked = *chars.choose(&mut r).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for count in chosen.iter() { + let err = *count - (1000 / (chars.len() as i32)); + assert!(-25 <= err && err <= 25); + } + + chosen.truncate(0); + chosen.resize(8, 0i32); + for _ in 0..1000 { + *chosen.choose_mut(&mut r).unwrap() += 1; + } + for count in chosen.iter() { + let err = *count - (1000 / (chosen.len() as i32)); + assert!(-25 <= err && err <= 25); + } + + let mut v: [isize; 0] = []; + assert_eq!(r.choose(&v), None); + assert_eq!(r.choose_mut(&mut v), None); + } + + #[test] + fn test_iterator_choose() { + let mut r = ::test::rng(405); + let mut chosen = Vec::new(); + chosen.resize(9, 0i32); + for _ in 0..1000 { + let picked = (0..9).choose(&mut r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + let err = *count - 1000 / 9; + assert!(-20 <= err && err <= 20); + } + + assert_eq!((0..0).choose(&mut r), None); + } }