From be1158de0acd51d008e3178cd6fbb97e7972d7ae Mon Sep 17 00:00:00 2001 From: Jonas Sicking Date: Tue, 21 Aug 2018 00:34:16 -0700 Subject: [PATCH] Use Iterator::size_hint() to speed up IteratorRandom::choose --- benches/seq.rs | 73 +++++++++++++++++++++++-- src/seq/mod.rs | 142 ++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 190 insertions(+), 25 deletions(-) diff --git a/benches/seq.rs b/benches/seq.rs index f143131763b..04a76caf41e 100644 --- a/benches/seq.rs +++ b/benches/seq.rs @@ -8,6 +8,9 @@ use test::Bencher; use rand::prelude::*; use rand::seq::*; +use std::mem::size_of; + +const RAND_BENCH_N: u64 = 1000; #[bench] fn seq_shuffle_100(b: &mut Bencher) { @@ -22,10 +25,18 @@ fn seq_shuffle_100(b: &mut Bencher) { #[bench] fn seq_slice_choose_1_of_1000(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x : &[usize] = &[1; 1000]; + let x : &mut [usize] = &mut [1; 1000]; + for i in 0..1000 { + x[i] = i; + } b.iter(|| { - x.choose(&mut rng) - }) + let mut s = 0; + for _ in 0..RAND_BENCH_N { + s += x.choose(&mut rng).unwrap(); + } + s + }); + b.bytes = size_of::() as u64 * ::RAND_BENCH_N; } macro_rules! seq_slice_choose_multiple { @@ -54,11 +65,63 @@ seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100); seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100); #[bench] -fn seq_iter_choose_from_100(b: &mut Bencher) { +fn seq_iter_choose_from_1000(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &mut [usize] = &mut [1; 1000]; + for i in 0..1000 { + x[i] = i; + } + b.iter(|| { + let mut s = 0; + for _ in 0..RAND_BENCH_N { + s += x.iter().choose(&mut rng).unwrap(); + } + s + }); + b.bytes = size_of::() as u64 * ::RAND_BENCH_N; +} + +#[derive(Clone)] +struct UnhintedIterator { + iter: I, +} +impl Iterator for UnhintedIterator { + type Item = I::Item; + fn next(&mut self) -> Option { + self.iter.next() + } +} + +#[derive(Clone)] +struct WindowHintedIterator { + iter: I, + window_size: usize, +} +impl Iterator for WindowHintedIterator { + type Item = I::Item; + fn next(&mut self) -> Option { + self.iter.next() + } + fn size_hint(&self) -> (usize, Option) { + (std::cmp::min(self.iter.len(), self.window_size), None) + } +} + +#[bench] +fn seq_iter_unhinted_choose_from_100(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + let x : &[usize] = &[1; 1000]; + b.iter(|| { + UnhintedIterator { iter: x.iter() }.choose(&mut rng).unwrap() + }) +} + +#[bench] +fn seq_iter_window_hinted_choose_from_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let x : &[usize] = &[1; 100]; b.iter(|| { - x.iter().cloned().choose(&mut rng) + WindowHintedIterator { iter: x.iter(), window_size: 7 }.choose(&mut rng) }) } diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 4e06bac2863..bea2d1bf863 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -188,20 +188,57 @@ pub trait IteratorRandom: Iterator + Sized { fn choose(mut self, rng: &mut R) -> Option where R: Rng + ?Sized { - if let Some(elem) = self.next() { - let mut result = elem; - - // Continue until the iterator is exhausted - for (i, elem) in self.enumerate() { - let denom = (i + 2) as f64; // accurate to 2^53 elements + let (mut lower, mut upper) = self.size_hint(); + let mut consumed = 0; + let mut result = None; + + if upper == Some(lower) { + // Remove this once we can specialize on ExactSizeIterator + return if lower == 0 { None } else { self.nth(rng.gen_range(0, lower)) }; + } else if lower <= 1 { + result = self.next(); + if result.is_none() { + return result; + } + consumed = 1; + let hint = self.size_hint(); + lower = hint.0; + upper = hint.1; + } + + // Continue until the iterator is exhausted + loop { + if lower > 1 { + let ix = rng.gen_range(0, lower + consumed); + let skip; + if ix < lower { + result = self.nth(ix); + skip = lower - (ix + 1); + } else { + skip = lower; + } + if upper == Some(lower) { + return result; + } + consumed += lower; + if skip > 0 { + self.nth(skip - 1); + } + } else { + let elem = self.next(); + if elem.is_none() { + return result; + } + consumed += 1; + let denom = consumed as f64; // accurate to 2^53 elements if rng.gen_bool(1.0 / denom) { result = elem; } } - - Some(result) - } else { - None + + let hint = self.size_hint(); + lower = hint.0; + upper = hint.1; } } @@ -519,20 +556,85 @@ mod test { assert_eq!(v.choose_mut(&mut r), None); } + #[derive(Clone)] + struct UnhintedIterator { + iter: I, + } + impl Iterator for UnhintedIterator { + type Item = I::Item; + fn next(&mut self) -> Option { + self.iter.next() + } + } + + #[derive(Clone)] + struct ChunkHintedIterator { + iter: I, + chunk_remaining: usize, + chunk_size: usize, + hint_total_size: bool, + } + impl Iterator for ChunkHintedIterator { + type Item = I::Item; + fn next(&mut self) -> Option { + if self.chunk_remaining == 0 { + self.chunk_remaining = ::core::cmp::min(self.chunk_size, + self.iter.len()); + } + self.chunk_remaining = self.chunk_remaining.saturating_sub(1); + + self.iter.next() + } + fn size_hint(&self) -> (usize, Option) { + (self.chunk_remaining, + if self.hint_total_size { Some(self.iter.len()) } else { None }) + } + } + + #[derive(Clone)] + struct WindowHintedIterator { + iter: I, + window_size: usize, + hint_total_size: bool, + } + impl Iterator for WindowHintedIterator { + type Item = I::Item; + fn next(&mut self) -> Option { + self.iter.next() + } + fn size_hint(&self) -> (usize, Option) { + (::core::cmp::min(self.iter.len(), self.window_size), + if self.hint_total_size { Some(self.iter.len()) } else { None }) + } + } + #[test] fn test_iterator_choose() { - let mut r = ::test::rng(109); - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = (0..9).choose(&mut r).unwrap(); - chosen[picked] += 1; - } - for count in chosen.iter() { - let err = *count - 1000 / 9; - assert!(-25 <= err && err <= 25); + let r = &mut ::test::rng(109); + fn test_iter + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + let err = *count - 1000 / 9; + assert!(-25 <= err && err <= 25); + } } - assert_eq!((0..0).choose(&mut r), None); + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter(r, ChunkHintedIterator { iter: 0..9, chunk_size: 4, chunk_remaining: 4, hint_total_size: false }); + test_iter(r, ChunkHintedIterator { iter: 0..9, chunk_size: 4, chunk_remaining: 4, hint_total_size: true }); + test_iter(r, WindowHintedIterator { iter: 0..9, window_size: 2, hint_total_size: false }); + test_iter(r, WindowHintedIterator { iter: 0..9, window_size: 2, hint_total_size: true }); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator{ iter: 0..0 }.choose(r), None); } #[test]