-
Notifications
You must be signed in to change notification settings - Fork 430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SliceRandom::choose_multiple_weighted, implementing weighted sampling without replacement #976
Changes from all commits
aeeae49
45fa908
efbab50
f255bb9
4501873
4f76aff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,6 +50,7 @@ | |
#![doc(test(attr(allow(unused_variables), deny(warnings))))] | ||
#![cfg_attr(not(feature = "std"), no_std)] | ||
#![cfg_attr(all(feature = "simd_support", feature = "nightly"), feature(stdsimd))] | ||
#![cfg_attr(all(feature = "partition_at_index", feature = "nightly"), feature(slice_partition_at_index))] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This however makes separate use of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still don't understand what the advantage of a separate |
||
#![allow( | ||
clippy::excessive_precision, | ||
clippy::unreadable_literal, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,18 +8,21 @@ | |
|
||
//! Low-level API for sampling indices | ||
|
||
#[cfg(feature = "alloc")] use core::slice; | ||
#[cfg(feature = "alloc")] | ||
use core::slice; | ||
|
||
#[cfg(all(feature = "alloc", not(feature = "std")))] | ||
use crate::alloc::vec::{self, Vec}; | ||
#[cfg(feature = "std")] use std::vec; | ||
#[cfg(feature = "std")] | ||
use std::vec; | ||
// BTreeMap is not as fast in tests, but better than nothing. | ||
#[cfg(all(feature = "alloc", not(feature = "std")))] | ||
use crate::alloc::collections::BTreeSet; | ||
#[cfg(feature = "std")] use std::collections::HashSet; | ||
#[cfg(feature = "std")] | ||
use std::collections::HashSet; | ||
Comment on lines
-11
to
+22
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leave the reformatting out please. |
||
|
||
#[cfg(feature = "alloc")] | ||
use crate::distributions::{uniform::SampleUniform, Distribution, Uniform}; | ||
use crate::distributions::{uniform::SampleUniform, Distribution, Uniform, WeightedError}; | ||
use crate::Rng; | ||
|
||
/// A vector of indices. | ||
|
@@ -249,6 +252,117 @@ where R: Rng + ?Sized { | |
} | ||
} | ||
|
||
/// Randomly sample exactly `amount` distinct indices from `0..length`, and | ||
/// return them in an arbitrary order (there is no guarantee of shuffling or | ||
/// ordering). The weights are to be provided by the input function `weights`, | ||
/// which will be called once for each index. | ||
/// | ||
/// This method is used internally by the slice sampling methods, but it can | ||
/// sometimes be useful to have the indices themselves so this is provided as | ||
/// an alternative. | ||
/// | ||
/// This implementation uses `O(length + amount)` space and `O(length)` time | ||
/// if the "partition_at_index" feature is enabled, or `O(length)` space and | ||
/// `O(length + amount * log length)` time otherwise. | ||
/// | ||
/// Panics if `amount > length`. | ||
pub fn sample_weighted<R, F, X>( | ||
rng: &mut R, length: usize, weight: F, amount: usize, | ||
) -> Result<IndexVec, WeightedError> | ||
where | ||
R: Rng + ?Sized, | ||
F: Fn(usize) -> X, | ||
X: Into<f64>, | ||
{ | ||
Comment on lines
+269
to
+276
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implementation looks correct, but it may use a lot of memory without good cause. Suggestions: If we know An implementation of the A-Res variant should compare well on performance I think (though still |
||
if amount > length { | ||
panic!("`amount` of samples must be less than or equal to `length`"); | ||
} | ||
|
||
// This implementation uses the algorithm described by Efraimidis and Spirakis | ||
// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 | ||
|
||
struct Element { | ||
index: usize, | ||
key: f64, | ||
} | ||
impl PartialOrd for Element { | ||
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> { | ||
self.key | ||
.partial_cmp(&other.key) | ||
.or(Some(core::cmp::Ordering::Less)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why include this line? To avoid panic in |
||
} | ||
} | ||
impl Ord for Element { | ||
fn cmp(&self, other: &Self) -> core::cmp::Ordering { | ||
self.partial_cmp(other).unwrap() // partial_cmp will always produce a value | ||
} | ||
} | ||
impl PartialEq for Element { | ||
fn eq(&self, other: &Self) -> bool { | ||
self.key == other.key | ||
} | ||
} | ||
impl Eq for Element {} | ||
|
||
#[cfg(feature = "partition_at_index")] | ||
{ | ||
if length == 0 { | ||
return Ok(IndexVec::USize(Vec::new())); | ||
} | ||
Comment on lines
+309
to
+311
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better, if |
||
|
||
let mut candidates = Vec::with_capacity(length); | ||
for index in 0..length { | ||
let weight = weight(index).into(); | ||
if weight < 0.0 || weight.is_nan() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer |
||
return Err(WeightedError::InvalidWeight); | ||
} | ||
|
||
let key = rng.gen::<f64>().powf(1.0 / weight); | ||
candidates.push(Element { index, key }) | ||
} | ||
|
||
// Partially sort the array to find the `amount` elements with the greatest | ||
// keys. Do this by using `partition_at_index` to put the elements with | ||
// the *smallest* keys at the beginning of the list in `O(n)` time, which | ||
// provides equivalent information about the elements with the *greatest* keys. | ||
let (_, mid, greater) = candidates.partition_at_index(length - amount); | ||
|
||
let mut result = Vec::with_capacity(amount); | ||
result.push(mid.index); | ||
for element in greater { | ||
result.push(element.index); | ||
} | ||
Ok(IndexVec::USize(result)) | ||
} | ||
|
||
#[cfg(not(feature = "partition_at_index"))] | ||
{ | ||
#[cfg(all(feature = "alloc", not(feature = "std")))] | ||
use crate::alloc::collections::BinaryHeap; | ||
#[cfg(feature = "std")] | ||
use std::collections::BinaryHeap; | ||
|
||
// Partially sort the array such that the `amount` elements with the largest | ||
// keys are first using a binary max heap. | ||
let mut candidates = BinaryHeap::with_capacity(length); | ||
for index in 0..length { | ||
let weight = weight(index).into(); | ||
if weight < 0.0 || weight.is_nan() { | ||
return Err(WeightedError::InvalidWeight); | ||
} | ||
|
||
let key = rng.gen::<f64>().powf(1.0 / weight); | ||
candidates.push(Element { index, key }); | ||
} | ||
|
||
let mut result = Vec::with_capacity(amount); | ||
while result.len() < amount { | ||
result.push(candidates.pop().unwrap().index); | ||
} | ||
Ok(IndexVec::USize(result)) | ||
} | ||
} | ||
|
||
/// Randomly sample exactly `amount` indices from `0..length`, using Floyd's | ||
/// combination algorithm. | ||
/// | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A simpler approach might be to just use the
nightly
feature instead of introducing a new one. This would have the advantage that we don't have to worry about dropping the feature in the future without breaking code.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My reasoning here is that this allows leaving the feature flag in place even once "slice_partition_at_index" stabilizes, which would allow supporting older rustc versions. I also see your point about dropping the feature being a breaking change, though...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair enough. I think we will likely just raise the minimum Rust version. @dhardy What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, it is not possible to reliably support old nightlies, and not very useful either. So removing a feature flag only usable on nightly compilers once that feature has stabilised is not an issue.