diff --git a/src/seq/index.rs b/src/seq/index.rs index 427e92641be..55053e3505e 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -280,9 +280,10 @@ where F: Fn(usize) -> X, X: Into, { - if length > (::core::u32::MAX as usize) { + if length > (core::u32::MAX as usize) { sample_efraimidis_spirakis(rng, length, weight, amount) } else { + assert!(amount <= core::u32::MAX as usize); let amount = amount as u32; let length = length as u32; sample_efraimidis_spirakis(rng, length, weight, amount) @@ -310,6 +311,7 @@ where F: Fn(usize) -> X, X: Into, N: UInt, + IndexVec: From>, { if amount == N::zero() { return Ok(IndexVec::U32(Vec::new())); @@ -345,14 +347,17 @@ where #[cfg(feature = "nightly")] { let mut candidates = Vec::with_capacity(length.as_usize()); - for index in 0..length.as_usize() { - let weight = weight(index).into(); + let mut index = N::zero(); + while index < length { + let weight = weight(index.as_usize()).into(); if !(weight >= 0.) { return Err(WeightedError::InvalidWeight); } let key = rng.gen::().powf(1.0 / weight); - candidates.push(Element { index, key }) + candidates.push(Element { index, key }); + + index += N::one(); } // Partially sort the array to find the `amount` elements with the greatest @@ -362,7 +367,7 @@ where let (_, mid, greater) = candidates.partition_at_index(length.as_usize() - amount.as_usize()); - let mut result = Vec::with_capacity(amount.as_usize()); + let mut result: Vec = Vec::with_capacity(amount.as_usize()); result.push(mid.index); for element in greater { result.push(element.index); @@ -380,17 +385,20 @@ where // Partially sort the array such that the `amount` elements with the largest // keys are first using a binary max heap. let mut candidates = BinaryHeap::with_capacity(length.as_usize()); - for index in 0..length.as_usize() { - let weight = weight(index).into(); - if weight < 0.0 || weight.is_nan() { + let mut index = N::zero(); + while index < length { + let weight = weight(index.as_usize()).into(); + if !(weight >= 0.) { return Err(WeightedError::InvalidWeight); } let key = rng.gen::().powf(1.0 / weight); candidates.push(Element { index, key }); + + index += N::one(); } - let mut result = Vec::with_capacity(amount.as_usize()); + let mut result: Vec = Vec::with_capacity(amount.as_usize()); while result.len() < amount.as_usize() { result.push(candidates.pop().unwrap().index); } @@ -462,8 +470,10 @@ where R: Rng + ?Sized { IndexVec::from(indices) } -trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash { +trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + + core::hash::Hash + core::ops::AddAssign { fn zero() -> Self; + fn one() -> Self; fn as_usize(self) -> usize; } impl UInt for u32 { @@ -472,6 +482,11 @@ impl UInt for u32 { 0 } + #[inline] + fn one() -> Self { + 1 + } + #[inline] fn as_usize(self) -> usize { self as usize @@ -483,6 +498,11 @@ impl UInt for usize { 0 } + #[inline] + fn one() -> Self { + 1 + } + #[inline] fn as_usize(self) -> usize { self @@ -602,6 +622,26 @@ mod test { assert_eq!(v1, v2); } + #[test] + fn test_sample_weighted() { + let seed_rng = crate::test::rng; + for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] { + let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap(); + match v { + IndexVec::U32(mut indices) => { + assert_eq!(indices.len(), amount); + indices.sort(); + indices.dedup(); + assert_eq!(indices.len(), amount); + for &i in &indices { + assert!((i as usize) < len); + } + }, + IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), + } + } + } + #[test] fn value_stability_sample() { let do_test = |length, amount, values: &[u32]| {