From 1e381d13ee060f678c9a4ca54124ab8af30b4402 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 16 Jul 2024 11:32:06 +0100 Subject: [PATCH] UniformFloat: allow inclusion of high in all cases (#1462) Fix #1299 by removing logic specific to ensuring that we emulate a closed range by excluding `high` from the result. --- CHANGELOG.md | 1 + src/distributions/uniform.rs | 165 ++++++++++++----------------------- src/distributions/utils.rs | 22 ----- 3 files changed, 59 insertions(+), 129 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18ce72a533..39e3c5c727 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. - Move all benchmarks to new `benches` crate (#1439) - Annotate panicking methods with `#[track_caller]` (#1442, #1447) - Enable feature `small_rng` by default (#1455) +- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462) ## [0.9.0-alpha.1] - 2024-03-18 - Add the `Slice::num_choices` method to the Slice distribution (#1402) diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 34a6b252f4..5540b74e46 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -51,7 +51,8 @@ //! Those methods should include an assertion to check the range is valid (i.e. //! `low < high`). The example below merely wraps another back-end. //! -//! The `new`, `new_inclusive` and `sample_single` functions use arguments of +//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive` +//! functions use arguments of //! type `SampleBorrow` to support passing in values by reference or //! by value. In the implementation of these functions, you can choose to //! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose @@ -207,6 +208,11 @@ impl Uniform { /// Create a new `Uniform` instance, which samples uniformly from the half /// open range `[low, high)` (excluding `high`). /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// /// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is /// non-finite. In release mode, only the range is checked. pub fn new(low: B1, high: B2) -> Result, Error> @@ -265,6 +271,11 @@ pub trait UniformSampler: Sized { /// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`. /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// /// Usually users should not call this directly but prefer to use /// [`Uniform::new`]. fn new(low: B1, high: B2) -> Result @@ -287,6 +298,11 @@ pub trait UniformSampler: Sized { /// Sample a single value uniformly from a range with inclusive lower bound /// and exclusive upper bound `[low, high)`. /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// /// By default this is implemented using /// `UniformSampler::new(low, high).sample(rng)`. However, for some types /// more optimal implementations for single usage may be provided via this @@ -908,6 +924,33 @@ pub struct UniformFloat { macro_rules! uniform_float_impl { ($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { + $(#[cfg($meta)])? + impl UniformFloat<$ty> { + /// Construct, reducing `scale` as required to ensure that rounding + /// can never yield values greater than `high`. + /// + /// Note: though it may be tempting to use a variant of this method + /// to ensure that samples from `[low, high)` are always strictly + /// less than `high`, this approach may be very slow where + /// `scale.abs()` is much smaller than `high.abs()` + /// (example: `low=0.99999999997819644, high=1.`). + fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self { + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + + loop { + let mask = (scale * max_rand + low).gt_mask(high); + if !mask.any() { + break; + } + scale = scale.decrease_masked(mask); + } + + debug_assert!(<$ty>::splat(0.0).all_le(scale)); + + UniformFloat { low, scale } + } + } + $(#[cfg($meta)])? impl SampleUniform for $ty { type Sampler = UniformFloat<$ty>; @@ -931,26 +974,13 @@ macro_rules! uniform_float_impl { if !(low.all_lt(high)) { return Err(Error::EmptyRange); } - let max_rand = <$ty>::splat( - ($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); - let mut scale = high - low; + let scale = high - low; if !(scale.all_finite()) { return Err(Error::NonFinite); } - loop { - let mask = (scale * max_rand + low).ge_mask(high); - if !mask.any() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - Ok(UniformFloat { low, scale }) + Ok(Self::new_bounded(low, high, scale)) } fn new_inclusive(low_b: B1, high_b: B2) -> Result @@ -967,26 +997,14 @@ macro_rules! uniform_float_impl { if !low.all_le(high) { return Err(Error::EmptyRange); } - let max_rand = <$ty>::splat( - ($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); - let mut scale = (high - low) / max_rand; + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + let scale = (high - low) / max_rand; if !scale.all_finite() { return Err(Error::NonFinite); } - loop { - let mask = (scale * max_rand + low).gt_mask(high); - if !mask.any() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - Ok(UniformFloat { low, scale }) + Ok(Self::new_bounded(low, high, scale)) } fn sample(&self, rng: &mut R) -> Self::X { @@ -1010,72 +1028,7 @@ macro_rules! uniform_float_impl { B1: SampleBorrow + Sized, B2: SampleBorrow + Sized, { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - #[cfg(debug_assertions)] - if !low.all_finite() || !high.all_finite() { - return Err(Error::NonFinite); - } - if !low.all_lt(high) { - return Err(Error::EmptyRange); - } - let mut scale = high - low; - if !scale.all_finite() { - return Err(Error::NonFinite); - } - - loop { - // Generate a value in the range [1, 2) - let value1_2 = - (rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); - - // Get a value in the range [0, 1) to avoid overflow when multiplying by scale - let value0_1 = value1_2 - <$ty>::splat(1.0); - - // Doing multiply before addition allows some architectures - // to use a single instruction. - let res = value0_1 * scale + low; - - debug_assert!(low.all_le(res) || !scale.all_finite()); - if res.all_lt(high) { - return Ok(res); - } - - // This handles a number of edge cases. - // * `low` or `high` is NaN. In this case `scale` and - // `res` are going to end up as NaN. - // * `low` is negative infinity and `high` is finite. - // `scale` is going to be infinite and `res` will be - // NaN. - // * `high` is positive infinity and `low` is finite. - // `scale` is going to be infinite and `res` will - // be infinite or NaN (if value0_1 is 0). - // * `low` is negative infinity and `high` is positive - // infinity. `scale` will be infinite and `res` will - // be NaN. - // * `low` and `high` are finite, but `high - low` - // overflows to infinite. `scale` will be infinite - // and `res` will be infinite or NaN (if value0_1 is 0). - // So if `high` or `low` are non-finite, we are guaranteed - // to fail the `res < high` check above and end up here. - // - // While we technically should check for non-finite `low` - // and `high` before entering the loop, by doing the checks - // here instead, we allow the common case to avoid these - // checks. But we are still guaranteed that if `low` or - // `high` are non-finite we'll end up here and can do the - // appropriate checks. - // - // Likewise, `high - low` overflowing to infinity is also - // rare, so handle it here after the common case. - let mask = !scale.finite_mask(); - if mask.any() { - if !(low.all_finite() && high.all_finite()) { - return Err(Error::NonFinite); - } - scale = scale.decrease_masked(mask); - } - } + Self::sample_single_inclusive(low_b, high_b, rng) } #[inline] @@ -1465,14 +1418,14 @@ mod tests { let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap(); for _ in 0..100 { let v = rng.sample(my_uniform).extract(lane); - assert!(low_scalar <= v && v < high_scalar); + assert!(low_scalar <= v && v <= high_scalar); let v = rng.sample(my_incl_uniform).extract(lane); assert!(low_scalar <= v && v <= high_scalar); let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng) .unwrap() .extract(lane); - assert!(low_scalar <= v && v < high_scalar); + assert!(low_scalar <= v && v <= high_scalar); let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive( low, high, &mut rng, ) @@ -1510,12 +1463,12 @@ mod tests { low_scalar ); - assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar); + assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar); assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); // sample_single cannot cope with max_rng: // assert!(<$ty as SampleUniform>::Sampler // ::sample_single(low, high, &mut max_rng).unwrap() - // .extract(lane) < high_scalar); + // .extract(lane) <= high_scalar); assert!( <$ty as SampleUniform>::Sampler::sample_single_inclusive( low, @@ -1543,7 +1496,7 @@ mod tests { ) .unwrap() .extract(lane) - < high_scalar + <= high_scalar ); } } @@ -1590,10 +1543,9 @@ mod tests { #[cfg(all(feature = "std", panic = "unwind"))] fn test_float_assertions() { use super::SampleUniform; - use std::panic::catch_unwind; - fn range(low: T, high: T) { + fn range(low: T, high: T) -> Result { let mut rng = crate::test::rng(253); - T::Sampler::sample_single(low, high, &mut rng).unwrap(); + T::Sampler::sample_single(low, high, &mut rng) } macro_rules! t { @@ -1616,10 +1568,9 @@ mod tests { for lane in 0..<$ty>::LEN { let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); - assert!(catch_unwind(|| range(low, high)).is_err()); + assert!(range(low, high).is_err()); assert!(Uniform::new(low, high).is_err()); assert!(Uniform::new_inclusive(low, high).is_err()); - assert!(catch_unwind(|| range(low, low)).is_err()); assert!(Uniform::new(low, low).is_err()); } } diff --git a/src/distributions/utils.rs b/src/distributions/utils.rs index aee92b6790..7e84665ec4 100644 --- a/src/distributions/utils.rs +++ b/src/distributions/utils.rs @@ -218,9 +218,7 @@ pub(crate) trait FloatSIMDUtils { fn all_finite(self) -> bool; type Mask; - fn finite_mask(self) -> Self::Mask; fn gt_mask(self, other: Self) -> Self::Mask; - fn ge_mask(self, other: Self) -> Self::Mask; // Decrease all lanes where the mask is `true` to the next lower value // representable by the floating-point type. At least one of the lanes @@ -292,21 +290,11 @@ macro_rules! scalar_float_impl { self.is_finite() } - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - self.is_finite() - } - #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { self > other } - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self >= other - } - #[inline(always)] fn decrease_masked(self, mask: Self::Mask) -> Self { debug_assert!(mask, "At least one lane must be set"); @@ -368,21 +356,11 @@ macro_rules! simd_impl { self.is_finite().all() } - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - self.is_finite() - } - #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { self.simd_gt(other) } - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self.simd_ge(other) - } - #[inline(always)] fn decrease_masked(self, mask: Self::Mask) -> Self { // Casting a mask into ints will produce all bits set for