From ca7de51826ffd05a0ebd06850ffde9f1bf09e9d5 Mon Sep 17 00:00:00 2001 From: Zachary Neely Date: Sun, 10 May 2020 19:47:00 -0700 Subject: [PATCH 01/11] Add choose_multiple_weighted, tests, and benchmarks --- Cargo.toml | 5 +- benches/seq.rs | 21 ++++++ src/lib.rs | 1 + src/seq/index.rs | 119 +++++++++++++++++++++++++++++-- src/seq/mod.rs | 180 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 321 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c0d8cba184a..d3ea9cf2255 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ appveyor = { repository = "rust-random/rand" } [features] # Meta-features: default = ["std", "std_rng"] -nightly = ["simd_support"] # enables all features requiring nightly rust +nightly = ["simd_support", "partition_at_index"] # enables all features requiring nightly rust serde1 = ["serde"] # Option (enabled by default): without "std" rand uses libcore; this option @@ -45,6 +45,9 @@ std_rng = ["rand_chacha", "rand_hc"] # Option: enable SmallRng small_rng = ["rand_pcg"] +# Option (requires nightly): better performance of choose_multiple_weighted +partition_at_index = [] + [workspace] members = [ "rand_core", diff --git a/benches/seq.rs b/benches/seq.rs index 7da2ff8a0fd..5b6fccf51ee 100644 --- a/benches/seq.rs +++ b/benches/seq.rs @@ -177,3 +177,24 @@ sample_indices!(misc_sample_indices_100_of_1G, sample, 100, 1000_000_000); sample_indices!(misc_sample_indices_200_of_1G, sample, 200, 1000_000_000); sample_indices!(misc_sample_indices_400_of_1G, sample, 400, 1000_000_000); sample_indices!(misc_sample_indices_600_of_1G, sample, 600, 1000_000_000); + +macro_rules! sample_indices_rand_weights { + ($name:ident, $amount:expr, $length:expr) => { + #[bench] + fn $name(b: &mut Bencher) { + let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); + b.iter(|| { + index::sample_weighted(&mut rng, $length, |idx| (1 + (idx % 100)) as u32, $amount) + }) + } + }; +} + +sample_indices_rand_weights!(misc_sample_weighted_indices_1_of_1k, 1, 1000); +sample_indices_rand_weights!(misc_sample_weighted_indices_10_of_1k, 10, 1000); +sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1k, 100, 1000); +sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1M, 100, 1000_000); +sample_indices_rand_weights!(misc_sample_weighted_indices_200_of_1M, 200, 1000_000); +sample_indices_rand_weights!(misc_sample_weighted_indices_400_of_1M, 400, 1000_000); +sample_indices_rand_weights!(misc_sample_weighted_indices_600_of_1M, 600, 1000_000); +sample_indices_rand_weights!(misc_sample_weighted_indices_1k_of_1M, 1000, 1000_000); diff --git a/src/lib.rs b/src/lib.rs index 061be1b1ae7..02c96f53807 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,6 +50,7 @@ #![doc(test(attr(allow(unused_variables), deny(warnings))))] #![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(all(feature = "simd_support", feature = "nightly"), feature(stdsimd))] +#![cfg_attr(all(feature = "partition_at_index", feature = "nightly"), feature(slice_partition_at_index))] #![cfg_attr(doc_cfg, feature(doc_cfg))] #![allow( clippy::excessive_precision, diff --git a/src/seq/index.rs b/src/seq/index.rs index 0ab5aec20ef..5e76af19dca 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -8,18 +8,21 @@ //! Low-level API for sampling indices -#[cfg(feature = "alloc")] use core::slice; +#[cfg(feature = "alloc")] +use core::slice; #[cfg(all(feature = "alloc", not(feature = "std")))] use crate::alloc::vec::{self, Vec}; -#[cfg(feature = "std")] use std::vec; +#[cfg(feature = "std")] +use std::vec; // BTreeMap is not as fast in tests, but better than nothing. #[cfg(all(feature = "alloc", not(feature = "std")))] use crate::alloc::collections::BTreeSet; -#[cfg(feature = "std")] use std::collections::HashSet; +#[cfg(feature = "std")] +use std::collections::HashSet; #[cfg(feature = "alloc")] -use crate::distributions::{uniform::SampleUniform, Distribution, Uniform}; +use crate::distributions::{uniform::SampleUniform, Distribution, Uniform, WeightedError}; use crate::Rng; #[cfg(feature = "serde1")] @@ -258,6 +261,114 @@ where R: Rng + ?Sized { } } +/// Randomly sample exactly `amount` distinct indices from `0..length`, and +/// return them in an arbitrary order (there is no guarantee of shuffling or +/// ordering). The weights are to be provided by the input function `weights`, +/// which will be called once for each index. +/// +/// This method is used internally by the slice sampling methods, but it can +/// sometimes be useful to have the indices themselves so this is provided as +/// an alternative. +/// +/// This implementation uses `O(length)` space and `O(length)` time if the +/// "partition_at_index" feature is enabled, or `O(length)` space and +/// `O(length * log amount)` time otherwise. +/// +/// Panics if `amount > length`. +pub fn sample_weighted( + rng: &mut R, length: usize, weight: F, amount: usize, +) -> Result +where + R: Rng + ?Sized, + F: Fn(usize) -> X, + X: Into, +{ + if amount > length { + panic!("`amount` of samples must be less than or equal to `length`"); + } + + // This implementation uses the algorithm described by Efraimidis and Spirakis + // in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 + + struct Element { + index: usize, + key: f64, + } + impl PartialOrd for Element { + fn partial_cmp(&self, other: &Self) -> Option { + self.key + .partial_cmp(&other.key) + .or(Some(std::cmp::Ordering::Less)) + } + } + impl Ord for Element { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).unwrap() // partial_cmp will always produce a value + } + } + impl PartialEq for Element { + fn eq(&self, other: &Self) -> bool { + self.key == other.key + } + } + impl Eq for Element {} + + #[cfg(feature = "partition_at_index")] + { + if length == 0 { + return Ok(IndexVecIntoIter::USize(Vec::new().into_iter())); + } + + let mut candidates = Vec::with_capacity(length); + for index in 0..length { + let weight = weight(index).into(); + if weight < 0.0 || weight.is_nan() { + return Err(WeightedError::InvalidWeight); + } + + let key = rng.gen::().powf(1.0 / weight); + candidates.push(Element { index, key }) + } + + // Partially sort the array to find the `amount` elements with the greatest + // keys. Do this by using `partition_at_index` to put the elements with + // the *smallest* keys at the beginning of the list in `O(n)` time, which + // provides equivalent information about the elements with the *greatest* keys. + let (_, mid, greater) = candidates.partition_at_index(length - amount); + + let mut result = Vec::with_capacity(amount); + result.push(mid.index); + for element in greater { + result.push(element.index); + } + Ok(IndexVecIntoIter::USize(result.into_iter())) + } + + #[cfg(not(feature = "partition_at_index"))] + { + use std::collections::BinaryHeap; + + // Partially sort the array such that the `amount` elements with the largest + // keys are first using a binary max heap. + let mut candidates = BinaryHeap::with_capacity(length); + for index in 0..length { + let weight = weight(index).into(); + if weight < 0.0 || weight.is_nan() { + return Err(WeightedError::InvalidWeight); + } + + let key = rng.gen::().powf(1.0 / weight); + candidates.push(Element { index, key }); + } + + let mut result = Vec::with_capacity(amount); + while result.len() < amount { + result.push(candidates.pop().unwrap().index); + } + Ok(IndexVecIntoIter::USize(result.into_iter())) + } +} + /// Randomly sample exactly `amount` indices from `0..length`, using Floyd's /// combination algorithm. /// diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 234a49d9c2b..7eb3c7ec54d 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -181,6 +181,46 @@ pub trait SliceRandom { + Clone + Default; + /// Similar to [`choose_multiple`], but where the likelihood of each element's + /// inclusion in the output may be specified. The elements are returned in an + /// arbitrary, unspecified order. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// If all of the weights are equal, even if they are all zero, each element has + /// an equal likelihood of being selected. + /// + /// The complexity of this method depends on the feature `partition_at_index`. + /// If the feature is enabled, then for slices of length `n`, the complexity + /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and + /// `O(n * log amount)` time. + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1)]; + /// let mut rng = thread_rng(); + /// // First Draw * Second Draw = total odds + /// // ----------------------- + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. + /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. + /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); + /// ``` + /// [`choose_multiple`]: SliceRandom::choose_multiple + #[cfg(feature = "alloc")] + fn choose_multiple_weighted( + &self, rng: &mut R, amount: usize, weight: F, + ) -> Result, WeightedError> + where + R: Rng + ?Sized, + F: Fn(&Self::Item) -> X, + X: Into; + /// Shuffle a mutable slice in place. /// /// For slices of length `n`, complexity is `O(n)`. @@ -452,6 +492,28 @@ impl SliceRandom for [T] { Ok(&mut self[distr.sample(rng)]) } + #[cfg(feature = "alloc")] + fn choose_multiple_weighted( + &self, rng: &mut R, amount: usize, weight: F, + ) -> Result, WeightedError> + where + R: Rng + ?Sized, + F: Fn(&Self::Item) -> X, + X: Into, + { + let amount = ::core::cmp::min(amount, self.len()); + Ok(SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample_weighted( + rng, + self.len(), + |idx| weight(&self[idx]).into(), + amount, + )?, + }) + } + fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized { for i in (1..self.len()).rev() { @@ -956,4 +1018,122 @@ mod test { do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); } } + + #[test] + fn test_multiple_weighted_edge_cases() { + use super::*; + + let mut rng = crate::test::rng(413); + + // Case 1: One of the weights is 0 + let choices = [('a', 2), ('b', 1), ('c', 0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + + assert_eq!(result.len(), 2); + assert!(!result.iter().any(|val| val.0 == 'c')); + } + + // Case 2: All of the weights are 0 + let choices = [('a', 0), ('b', 0), ('c', 0)]; + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + assert_eq!(result.len(), 2); + + // Case 3: Negative weights + let choices = [('a', -1), ('b', 1), ('c', 1)]; + assert!(matches!( + choices.choose_multiple_weighted(&mut rng, 2, |item| item.1), + Err(WeightedError::InvalidWeight) + )); + + // Case 4: Empty list + let choices = []; + let result = choices + .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) + .unwrap() + .collect::>(); + assert_eq!(result.len(), 0); + + // Case 5: NaN weights + let choices = [('a', std::f64::NAN), ('b', 1.0), ('c', 1.0)]; + assert!(matches!( + choices.choose_multiple_weighted(&mut rng, 2, |item| item.1), + Err(WeightedError::InvalidWeight) + )); + + // Case 6: +infinity weights + let choices = [('a', std::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + assert_eq!(result.len(), 2); + assert!(result.iter().any(|val| val.0 == 'a')); + } + + // Case 7: -infinity weights + let choices = [('a', std::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; + assert!(matches!( + choices.choose_multiple_weighted(&mut rng, 2, |item| item.1), + Err(WeightedError::InvalidWeight) + )); + + // Case 8: -0 weights + let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; + assert!(choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .is_ok()); + } + + #[test] + fn test_multiple_weighted_distributions() { + use super::*; + + // The theoretical probabilities of the different outcomes are: + // AB: 0.5 * 0.5 = 0.250 + // AC: 0.5 * 0.5 = 0.250 + // BA: 0.25 * 0.67 = 0.167 + // BC: 0.25 * 0.33 = 0.082 + // CA: 0.25 * 0.67 = 0.167 + // CB: 0.25 * 0.33 = 0.082 + let choices = [('a', 2), ('b', 1), ('c', 1)]; + let mut rng = crate::test::rng(414); + + let mut results = [0i32; 3]; + let expected_results = [4167, 4167, 1666]; + for _ in 0..10000 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + + assert_eq!(result.len(), 2); + + match (result[0].0, result[1].0) { + ('a', 'b') | ('b', 'a') => { + results[0] += 1; + } + ('a', 'c') | ('c', 'a') => { + results[1] += 1; + } + ('b', 'c') | ('c', 'b') => { + results[2] += 1; + } + (_, _) => panic!("unexpected result"), + } + } + + let mut diffs = results + .iter() + .zip(&expected_results) + .map(|(a, b)| (a - b).abs()); + assert!(!diffs.any(|deviation| deviation > 100)); + } } From 5144b23076a4b3522cf4b16ebebd8b5252a56621 Mon Sep 17 00:00:00 2001 From: Zachary Neely Date: Sun, 10 May 2020 22:20:40 -0700 Subject: [PATCH 02/11] Mark new tests with the correct feature flags --- src/seq/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 7eb3c7ec54d..f9407f8e7ee 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -1020,6 +1020,7 @@ mod test { } #[test] + #[cfg(feature = "alloc")] fn test_multiple_weighted_edge_cases() { use super::*; @@ -1093,6 +1094,7 @@ mod test { } #[test] + #[cfg(feature = "alloc")] fn test_multiple_weighted_distributions() { use super::*; From 512e2dda18efba1c121943e016a12842341882b6 Mon Sep 17 00:00:00 2001 From: Zachary Neely Date: Sun, 10 May 2020 22:32:51 -0700 Subject: [PATCH 03/11] Fix feature flags and use core for nostd compatibility --- src/seq/index.rs | 9 ++++++--- src/seq/mod.rs | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/seq/index.rs b/src/seq/index.rs index 5e76af19dca..6a226418241 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -295,14 +295,14 @@ where key: f64, } impl PartialOrd for Element { - fn partial_cmp(&self, other: &Self) -> Option { + fn partial_cmp(&self, other: &Self) -> Option { self.key .partial_cmp(&other.key) - .or(Some(std::cmp::Ordering::Less)) + .or(Some(core::cmp::Ordering::Less)) } } impl Ord for Element { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { self.partial_cmp(other).unwrap() // partial_cmp will always produce a value } } @@ -346,6 +346,9 @@ where #[cfg(not(feature = "partition_at_index"))] { + #[cfg(all(feature = "alloc", not(feature = "std")))] + use crate::alloc::collections::BinaryHeap; + #[cfg(feature = "std")] use std::collections::BinaryHeap; // Partially sort the array such that the `amount` elements with the largest diff --git a/src/seq/mod.rs b/src/seq/mod.rs index f9407f8e7ee..bdaa19a3111 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -1062,14 +1062,14 @@ mod test { assert_eq!(result.len(), 0); // Case 5: NaN weights - let choices = [('a', std::f64::NAN), ('b', 1.0), ('c', 1.0)]; + let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)]; assert!(matches!( choices.choose_multiple_weighted(&mut rng, 2, |item| item.1), Err(WeightedError::InvalidWeight) )); // Case 6: +infinity weights - let choices = [('a', std::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; + let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; for _ in 0..100 { let result = choices .choose_multiple_weighted(&mut rng, 2, |item| item.1) @@ -1080,7 +1080,7 @@ mod test { } // Case 7: -infinity weights - let choices = [('a', std::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; + let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; assert!(matches!( choices.choose_multiple_weighted(&mut rng, 2, |item| item.1), Err(WeightedError::InvalidWeight) From 33bb79ac2e6aa4e59261f56620046ec2b5fcf10b Mon Sep 17 00:00:00 2001 From: Zachary Neely Date: Sun, 10 May 2020 23:18:19 -0700 Subject: [PATCH 04/11] Don't use rustc features newer than 1.32 --- src/seq/mod.rs | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/seq/mod.rs b/src/seq/mod.rs index bdaa19a3111..20b064be477 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -1048,10 +1048,12 @@ mod test { // Case 3: Negative weights let choices = [('a', -1), ('b', 1), ('c', 1)]; - assert!(matches!( - choices.choose_multiple_weighted(&mut rng, 2, |item| item.1), - Err(WeightedError::InvalidWeight) - )); + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap_err(), + WeightedError::InvalidWeight + ); // Case 4: Empty list let choices = []; @@ -1063,10 +1065,12 @@ mod test { // Case 5: NaN weights let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)]; - assert!(matches!( - choices.choose_multiple_weighted(&mut rng, 2, |item| item.1), - Err(WeightedError::InvalidWeight) - )); + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap_err(), + WeightedError::InvalidWeight + ); // Case 6: +infinity weights let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; @@ -1081,10 +1085,12 @@ mod test { // Case 7: -infinity weights let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; - assert!(matches!( - choices.choose_multiple_weighted(&mut rng, 2, |item| item.1), - Err(WeightedError::InvalidWeight) - )); + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap_err(), + WeightedError::InvalidWeight + ); // Case 8: -0 weights let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; From 6751c9b80bc655da1b5ea1c2bdd0ce3e5b08ecca Mon Sep 17 00:00:00 2001 From: Zachary Neely Date: Mon, 11 May 2020 09:11:20 -0700 Subject: [PATCH 05/11] Don't use IntoIter when unnecessary and fix complexity documentation --- src/seq/index.rs | 14 +++++++------- src/seq/mod.rs | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/seq/index.rs b/src/seq/index.rs index 6a226418241..1c9af023789 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -270,14 +270,14 @@ where R: Rng + ?Sized { /// sometimes be useful to have the indices themselves so this is provided as /// an alternative. /// -/// This implementation uses `O(length)` space and `O(length)` time if the -/// "partition_at_index" feature is enabled, or `O(length)` space and -/// `O(length * log amount)` time otherwise. +/// This implementation uses `O(length + amount)` space and `O(length)` time +/// if the "partition_at_index" feature is enabled, or `O(length)` space and +/// `O(length + amount * log length)` time otherwise. /// /// Panics if `amount > length`. pub fn sample_weighted( rng: &mut R, length: usize, weight: F, amount: usize, -) -> Result +) -> Result where R: Rng + ?Sized, F: Fn(usize) -> X, @@ -316,7 +316,7 @@ where #[cfg(feature = "partition_at_index")] { if length == 0 { - return Ok(IndexVecIntoIter::USize(Vec::new().into_iter())); + return Ok(IndexVec::USize(Vec::new())); } let mut candidates = Vec::with_capacity(length); @@ -341,7 +341,7 @@ where for element in greater { result.push(element.index); } - Ok(IndexVecIntoIter::USize(result.into_iter())) + Ok(IndexVec::USize(result)) } #[cfg(not(feature = "partition_at_index"))] @@ -368,7 +368,7 @@ where while result.len() < amount { result.push(candidates.pop().unwrap().index); } - Ok(IndexVecIntoIter::USize(result.into_iter())) + Ok(IndexVec::USize(result)) } } diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 20b064be477..6fac283caf9 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -510,7 +510,8 @@ impl SliceRandom for [T] { self.len(), |idx| weight(&self[idx]).into(), amount, - )?, + )? + .into_iter(), }) } From dadaae68f0501c8326173584c34a81f400fa1e3f Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Tue, 4 Aug 2020 20:14:10 +0200 Subject: [PATCH 06/11] Undo reformatting --- src/seq/index.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/seq/index.rs b/src/seq/index.rs index 1c9af023789..a8e83f7bbb4 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -8,18 +8,15 @@ //! Low-level API for sampling indices -#[cfg(feature = "alloc")] -use core::slice; +#[cfg(feature = "alloc")] use core::slice; #[cfg(all(feature = "alloc", not(feature = "std")))] use crate::alloc::vec::{self, Vec}; -#[cfg(feature = "std")] -use std::vec; +#[cfg(feature = "std")] use std::vec; // BTreeMap is not as fast in tests, but better than nothing. #[cfg(all(feature = "alloc", not(feature = "std")))] use crate::alloc::collections::BTreeSet; -#[cfg(feature = "std")] -use std::collections::HashSet; +#[cfg(feature = "std")] use std::collections::HashSet; #[cfg(feature = "alloc")] use crate::distributions::{uniform::SampleUniform, Distribution, Uniform, WeightedError}; From 51f28618c88f8ba597e224a8f56fa28b1689f916 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Tue, 4 Aug 2020 20:16:59 +0200 Subject: [PATCH 07/11] sample_weighted: Return early if amount is 0 --- src/seq/index.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/seq/index.rs b/src/seq/index.rs index a8e83f7bbb4..39f5ba5f90c 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -280,6 +280,10 @@ where F: Fn(usize) -> X, X: Into, { + if amount == 0 { + return Ok(IndexVec::USize(Vec::new())); + } + if amount > length { panic!("`amount` of samples must be less than or equal to `length`"); } @@ -312,10 +316,6 @@ where #[cfg(feature = "partition_at_index")] { - if length == 0 { - return Ok(IndexVec::USize(Vec::new())); - } - let mut candidates = Vec::with_capacity(length); for index in 0..length { let weight = weight(index).into(); From 6d922a0fff72d7a70f7f1e90dd431b5a629ae238 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Tue, 4 Aug 2020 20:22:05 +0200 Subject: [PATCH 08/11] Address review feedback --- src/seq/index.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/seq/index.rs b/src/seq/index.rs index 39f5ba5f90c..45878d45e0a 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -297,14 +297,14 @@ where } impl PartialOrd for Element { fn partial_cmp(&self, other: &Self) -> Option { - self.key - .partial_cmp(&other.key) - .or(Some(core::cmp::Ordering::Less)) + self.key.partial_cmp(&other.key) } } impl Ord for Element { fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.partial_cmp(other).unwrap() // partial_cmp will always produce a value + // partial_cmp will always produce a value, + // because we check that the weights are not nan + self.partial_cmp(other).unwrap() } } impl PartialEq for Element { @@ -319,7 +319,7 @@ where let mut candidates = Vec::with_capacity(length); for index in 0..length { let weight = weight(index).into(); - if weight < 0.0 || weight.is_nan() { + if !(weight >= 0.) { return Err(WeightedError::InvalidWeight); } From cb021c55ab4527c56b490343cd97cfb84c26a2b5 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Tue, 4 Aug 2020 20:49:44 +0200 Subject: [PATCH 09/11] sample_weighted: Use less memory for `length <= u32::MAX` --- src/seq/index.rs | 71 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/src/seq/index.rs b/src/seq/index.rs index 45878d45e0a..5e8a7c7d6b4 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -280,44 +280,72 @@ where F: Fn(usize) -> X, X: Into, { - if amount == 0 { - return Ok(IndexVec::USize(Vec::new())); + if length > (::core::u32::MAX as usize) { + sample_efraimidis_spirakis(rng, length, weight, amount) + } else { + let amount = amount as u32; + let length = length as u32; + sample_efraimidis_spirakis(rng, length, weight, amount) + } +} + + +/// Randomly sample exactly `amount` distinct indices from `0..length`, and +/// return them in an arbitrary order (there is no guarantee of shuffling or +/// ordering). The weights are to be provided by the input function `weights`, +/// which will be called once for each index. +/// +/// This implementation uses the algorithm described by Efraimidis and Spirakis +/// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 +/// It uses `O(length + amount)` space and `O(length)` time if the +/// "partition_at_index" feature is enabled, or `O(length)` space and `O(length +/// + amount * log length)` time otherwise. +/// +/// Panics if `amount > length`. +fn sample_efraimidis_spirakis( + rng: &mut R, length: N, weight: F, amount: N, +) -> Result +where + R: Rng + ?Sized, + F: Fn(usize) -> X, + X: Into, + N: UInt, +{ + if amount == N::zero() { + return Ok(IndexVec::U32(Vec::new())); } if amount > length { panic!("`amount` of samples must be less than or equal to `length`"); } - // This implementation uses the algorithm described by Efraimidis and Spirakis - // in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 - - struct Element { - index: usize, + struct Element { + index: N, key: f64, } - impl PartialOrd for Element { + impl PartialOrd for Element { fn partial_cmp(&self, other: &Self) -> Option { self.key.partial_cmp(&other.key) } } - impl Ord for Element { + impl Ord for Element { fn cmp(&self, other: &Self) -> core::cmp::Ordering { // partial_cmp will always produce a value, // because we check that the weights are not nan self.partial_cmp(other).unwrap() } } - impl PartialEq for Element { + impl PartialEq for Element { fn eq(&self, other: &Self) -> bool { self.key == other.key } } - impl Eq for Element {} + impl Eq for Element {} #[cfg(feature = "partition_at_index")] { - let mut candidates = Vec::with_capacity(length); - for index in 0..length { + let mut candidates = Vec::with_capacity(length.as_usize()); + for index in 0..length.as_usize() { let weight = weight(index).into(); if !(weight >= 0.) { return Err(WeightedError::InvalidWeight); @@ -331,14 +359,15 @@ where // keys. Do this by using `partition_at_index` to put the elements with // the *smallest* keys at the beginning of the list in `O(n)` time, which // provides equivalent information about the elements with the *greatest* keys. - let (_, mid, greater) = candidates.partition_at_index(length - amount); + let (_, mid, greater) + = candidates.partition_at_index(length.as_usize() - amount.as_usize()); - let mut result = Vec::with_capacity(amount); + let mut result = Vec::with_capacity(amount.as_usize()); result.push(mid.index); for element in greater { result.push(element.index); } - Ok(IndexVec::USize(result)) + Ok(IndexVec::from(result)) } #[cfg(not(feature = "partition_at_index"))] @@ -350,8 +379,8 @@ where // Partially sort the array such that the `amount` elements with the largest // keys are first using a binary max heap. - let mut candidates = BinaryHeap::with_capacity(length); - for index in 0..length { + let mut candidates = BinaryHeap::with_capacity(length.as_usize()); + for index in 0..length.as_usize() { let weight = weight(index).into(); if weight < 0.0 || weight.is_nan() { return Err(WeightedError::InvalidWeight); @@ -361,11 +390,11 @@ where candidates.push(Element { index, key }); } - let mut result = Vec::with_capacity(amount); - while result.len() < amount { + let mut result = Vec::with_capacity(amount.as_usize()); + while result.len() < amount.as_usize() { result.push(candidates.pop().unwrap().index); } - Ok(IndexVec::USize(result)) + Ok(IndexVec::from(result)) } } From 03bd82a982df05af36550fcc770e5ba75fe84346 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Tue, 4 Aug 2020 20:53:11 +0200 Subject: [PATCH 10/11] Get rid of "partition_add_index" feature It offers nothing over the "nightly" feature and makes testing more complicated. --- Cargo.toml | 5 +---- src/lib.rs | 2 +- src/seq/index.rs | 8 ++++---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d3ea9cf2255..c0d8cba184a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ appveyor = { repository = "rust-random/rand" } [features] # Meta-features: default = ["std", "std_rng"] -nightly = ["simd_support", "partition_at_index"] # enables all features requiring nightly rust +nightly = ["simd_support"] # enables all features requiring nightly rust serde1 = ["serde"] # Option (enabled by default): without "std" rand uses libcore; this option @@ -45,9 +45,6 @@ std_rng = ["rand_chacha", "rand_hc"] # Option: enable SmallRng small_rng = ["rand_pcg"] -# Option (requires nightly): better performance of choose_multiple_weighted -partition_at_index = [] - [workspace] members = [ "rand_core", diff --git a/src/lib.rs b/src/lib.rs index 02c96f53807..ceb501f4b56 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,7 +50,7 @@ #![doc(test(attr(allow(unused_variables), deny(warnings))))] #![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(all(feature = "simd_support", feature = "nightly"), feature(stdsimd))] -#![cfg_attr(all(feature = "partition_at_index", feature = "nightly"), feature(slice_partition_at_index))] +#![cfg_attr(feature = "nightly", feature(slice_partition_at_index))] #![cfg_attr(doc_cfg, feature(doc_cfg))] #![allow( clippy::excessive_precision, diff --git a/src/seq/index.rs b/src/seq/index.rs index 5e8a7c7d6b4..427e92641be 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -268,7 +268,7 @@ where R: Rng + ?Sized { /// an alternative. /// /// This implementation uses `O(length + amount)` space and `O(length)` time -/// if the "partition_at_index" feature is enabled, or `O(length)` space and +/// if the "nightly" feature is enabled, or `O(length)` space and /// `O(length + amount * log length)` time otherwise. /// /// Panics if `amount > length`. @@ -298,7 +298,7 @@ where /// This implementation uses the algorithm described by Efraimidis and Spirakis /// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 /// It uses `O(length + amount)` space and `O(length)` time if the -/// "partition_at_index" feature is enabled, or `O(length)` space and `O(length +/// "nightly" feature is enabled, or `O(length)` space and `O(length /// + amount * log length)` time otherwise. /// /// Panics if `amount > length`. @@ -342,7 +342,7 @@ where } impl Eq for Element {} - #[cfg(feature = "partition_at_index")] + #[cfg(feature = "nightly")] { let mut candidates = Vec::with_capacity(length.as_usize()); for index in 0..length.as_usize() { @@ -370,7 +370,7 @@ where Ok(IndexVec::from(result)) } - #[cfg(not(feature = "partition_at_index"))] + #[cfg(not(feature = "nightly"))] { #[cfg(all(feature = "alloc", not(feature = "std")))] use crate::alloc::collections::BinaryHeap; From d73eac541ac1aa754f1de503cd44c60c55e908e1 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Wed, 5 Aug 2020 09:09:46 +0200 Subject: [PATCH 11/11] sample_weighted: Make sure the correct `IndexVec` is generated Also add some tests. --- src/seq/index.rs | 60 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/src/seq/index.rs b/src/seq/index.rs index 427e92641be..55053e3505e 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -280,9 +280,10 @@ where F: Fn(usize) -> X, X: Into, { - if length > (::core::u32::MAX as usize) { + if length > (core::u32::MAX as usize) { sample_efraimidis_spirakis(rng, length, weight, amount) } else { + assert!(amount <= core::u32::MAX as usize); let amount = amount as u32; let length = length as u32; sample_efraimidis_spirakis(rng, length, weight, amount) @@ -310,6 +311,7 @@ where F: Fn(usize) -> X, X: Into, N: UInt, + IndexVec: From>, { if amount == N::zero() { return Ok(IndexVec::U32(Vec::new())); @@ -345,14 +347,17 @@ where #[cfg(feature = "nightly")] { let mut candidates = Vec::with_capacity(length.as_usize()); - for index in 0..length.as_usize() { - let weight = weight(index).into(); + let mut index = N::zero(); + while index < length { + let weight = weight(index.as_usize()).into(); if !(weight >= 0.) { return Err(WeightedError::InvalidWeight); } let key = rng.gen::().powf(1.0 / weight); - candidates.push(Element { index, key }) + candidates.push(Element { index, key }); + + index += N::one(); } // Partially sort the array to find the `amount` elements with the greatest @@ -362,7 +367,7 @@ where let (_, mid, greater) = candidates.partition_at_index(length.as_usize() - amount.as_usize()); - let mut result = Vec::with_capacity(amount.as_usize()); + let mut result: Vec = Vec::with_capacity(amount.as_usize()); result.push(mid.index); for element in greater { result.push(element.index); @@ -380,17 +385,20 @@ where // Partially sort the array such that the `amount` elements with the largest // keys are first using a binary max heap. let mut candidates = BinaryHeap::with_capacity(length.as_usize()); - for index in 0..length.as_usize() { - let weight = weight(index).into(); - if weight < 0.0 || weight.is_nan() { + let mut index = N::zero(); + while index < length { + let weight = weight(index.as_usize()).into(); + if !(weight >= 0.) { return Err(WeightedError::InvalidWeight); } let key = rng.gen::().powf(1.0 / weight); candidates.push(Element { index, key }); + + index += N::one(); } - let mut result = Vec::with_capacity(amount.as_usize()); + let mut result: Vec = Vec::with_capacity(amount.as_usize()); while result.len() < amount.as_usize() { result.push(candidates.pop().unwrap().index); } @@ -462,8 +470,10 @@ where R: Rng + ?Sized { IndexVec::from(indices) } -trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash { +trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + + core::hash::Hash + core::ops::AddAssign { fn zero() -> Self; + fn one() -> Self; fn as_usize(self) -> usize; } impl UInt for u32 { @@ -472,6 +482,11 @@ impl UInt for u32 { 0 } + #[inline] + fn one() -> Self { + 1 + } + #[inline] fn as_usize(self) -> usize { self as usize @@ -483,6 +498,11 @@ impl UInt for usize { 0 } + #[inline] + fn one() -> Self { + 1 + } + #[inline] fn as_usize(self) -> usize { self @@ -602,6 +622,26 @@ mod test { assert_eq!(v1, v2); } + #[test] + fn test_sample_weighted() { + let seed_rng = crate::test::rng; + for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] { + let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap(); + match v { + IndexVec::U32(mut indices) => { + assert_eq!(indices.len(), amount); + indices.sort(); + indices.dedup(); + assert_eq!(indices.len(), amount); + for &i in &indices { + assert!((i as usize) < len); + } + }, + IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), + } + } + } + #[test] fn value_stability_sample() { let do_test = |length, amount, values: &[u32]| {