Skip to content

Commit

Permalink
UniformFloat: allow inclusion of high in all cases (#1462)
Browse files Browse the repository at this point in the history
Fix #1299 by removing logic specific to ensuring that we emulate a
closed range by excluding `high` from the result.
  • Loading branch information
dhardy committed Jul 16, 2024
1 parent 2584f48 commit 1e381d1
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 129 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
- Move all benchmarks to new `benches` crate (#1439)
- Annotate panicking methods with `#[track_caller]` (#1442, #1447)
- Enable feature `small_rng` by default (#1455)
- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462)

## [0.9.0-alpha.1] - 2024-03-18
- Add the `Slice::num_choices` method to the Slice distribution (#1402)
Expand Down
165 changes: 58 additions & 107 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
//! Those methods should include an assertion to check the range is valid (i.e.
//! `low < high`). The example below merely wraps another back-end.
//!
//! The `new`, `new_inclusive` and `sample_single` functions use arguments of
//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive`
//! functions use arguments of
//! type `SampleBorrow<X>` to support passing in values by reference or
//! by value. In the implementation of these functions, you can choose to
//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose
Expand Down Expand Up @@ -207,6 +208,11 @@ impl<X: SampleUniform> Uniform<X> {
/// Create a new `Uniform` instance, which samples uniformly from the half
/// open range `[low, high)` (excluding `high`).
///
/// For discrete types (e.g. integers), samples will always be strictly less
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
/// samples may equal `high` due to loss of precision but may not be
/// greater than `high`.
///
/// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is
/// non-finite. In release mode, only the range is checked.
pub fn new<B1, B2>(low: B1, high: B2) -> Result<Uniform<X>, Error>
Expand Down Expand Up @@ -265,6 +271,11 @@ pub trait UniformSampler: Sized {

/// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`.
///
/// For discrete types (e.g. integers), samples will always be strictly less
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
/// samples may equal `high` due to loss of precision but may not be
/// greater than `high`.
///
/// Usually users should not call this directly but prefer to use
/// [`Uniform::new`].
fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
Expand All @@ -287,6 +298,11 @@ pub trait UniformSampler: Sized {
/// Sample a single value uniformly from a range with inclusive lower bound
/// and exclusive upper bound `[low, high)`.
///
/// For discrete types (e.g. integers), samples will always be strictly less
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
/// samples may equal `high` due to loss of precision but may not be
/// greater than `high`.
///
/// By default this is implemented using
/// `UniformSampler::new(low, high).sample(rng)`. However, for some types
/// more optimal implementations for single usage may be provided via this
Expand Down Expand Up @@ -908,6 +924,33 @@ pub struct UniformFloat<X> {

macro_rules! uniform_float_impl {
($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => {
$(#[cfg($meta)])?
impl UniformFloat<$ty> {
/// Construct, reducing `scale` as required to ensure that rounding
/// can never yield values greater than `high`.
///
/// Note: though it may be tempting to use a variant of this method
/// to ensure that samples from `[low, high)` are always strictly
/// less than `high`, this approach may be very slow where
/// `scale.abs()` is much smaller than `high.abs()`
/// (example: `low=0.99999999997819644, high=1.`).
fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self {
let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);

loop {
let mask = (scale * max_rand + low).gt_mask(high);
if !mask.any() {
break;
}
scale = scale.decrease_masked(mask);
}

debug_assert!(<$ty>::splat(0.0).all_le(scale));

UniformFloat { low, scale }
}
}

$(#[cfg($meta)])?
impl SampleUniform for $ty {
type Sampler = UniformFloat<$ty>;
Expand All @@ -931,26 +974,13 @@ macro_rules! uniform_float_impl {
if !(low.all_lt(high)) {
return Err(Error::EmptyRange);
}
let max_rand = <$ty>::splat(
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
);

let mut scale = high - low;
let scale = high - low;
if !(scale.all_finite()) {
return Err(Error::NonFinite);
}

loop {
let mask = (scale * max_rand + low).ge_mask(high);
if !mask.any() {
break;
}
scale = scale.decrease_masked(mask);
}

debug_assert!(<$ty>::splat(0.0).all_le(scale));

Ok(UniformFloat { low, scale })
Ok(Self::new_bounded(low, high, scale))
}

fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
Expand All @@ -967,26 +997,14 @@ macro_rules! uniform_float_impl {
if !low.all_le(high) {
return Err(Error::EmptyRange);
}
let max_rand = <$ty>::splat(
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
);

let mut scale = (high - low) / max_rand;
let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);
let scale = (high - low) / max_rand;
if !scale.all_finite() {
return Err(Error::NonFinite);
}

loop {
let mask = (scale * max_rand + low).gt_mask(high);
if !mask.any() {
break;
}
scale = scale.decrease_masked(mask);
}

debug_assert!(<$ty>::splat(0.0).all_le(scale));

Ok(UniformFloat { low, scale })
Ok(Self::new_bounded(low, high, scale))
}

fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
Expand All @@ -1010,72 +1028,7 @@ macro_rules! uniform_float_impl {
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = *low_b.borrow();
let high = *high_b.borrow();
#[cfg(debug_assertions)]
if !low.all_finite() || !high.all_finite() {
return Err(Error::NonFinite);
}
if !low.all_lt(high) {
return Err(Error::EmptyRange);
}
let mut scale = high - low;
if !scale.all_finite() {
return Err(Error::NonFinite);
}

loop {
// Generate a value in the range [1, 2)
let value1_2 =
(rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0);

// Get a value in the range [0, 1) to avoid overflow when multiplying by scale
let value0_1 = value1_2 - <$ty>::splat(1.0);

// Doing multiply before addition allows some architectures
// to use a single instruction.
let res = value0_1 * scale + low;

debug_assert!(low.all_le(res) || !scale.all_finite());
if res.all_lt(high) {
return Ok(res);
}

// This handles a number of edge cases.
// * `low` or `high` is NaN. In this case `scale` and
// `res` are going to end up as NaN.
// * `low` is negative infinity and `high` is finite.
// `scale` is going to be infinite and `res` will be
// NaN.
// * `high` is positive infinity and `low` is finite.
// `scale` is going to be infinite and `res` will
// be infinite or NaN (if value0_1 is 0).
// * `low` is negative infinity and `high` is positive
// infinity. `scale` will be infinite and `res` will
// be NaN.
// * `low` and `high` are finite, but `high - low`
// overflows to infinite. `scale` will be infinite
// and `res` will be infinite or NaN (if value0_1 is 0).
// So if `high` or `low` are non-finite, we are guaranteed
// to fail the `res < high` check above and end up here.
//
// While we technically should check for non-finite `low`
// and `high` before entering the loop, by doing the checks
// here instead, we allow the common case to avoid these
// checks. But we are still guaranteed that if `low` or
// `high` are non-finite we'll end up here and can do the
// appropriate checks.
//
// Likewise, `high - low` overflowing to infinity is also
// rare, so handle it here after the common case.
let mask = !scale.finite_mask();
if mask.any() {
if !(low.all_finite() && high.all_finite()) {
return Err(Error::NonFinite);
}
scale = scale.decrease_masked(mask);
}
}
Self::sample_single_inclusive(low_b, high_b, rng)
}

#[inline]
Expand Down Expand Up @@ -1465,14 +1418,14 @@ mod tests {
let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap();
for _ in 0..100 {
let v = rng.sample(my_uniform).extract(lane);
assert!(low_scalar <= v && v < high_scalar);
assert!(low_scalar <= v && v <= high_scalar);
let v = rng.sample(my_incl_uniform).extract(lane);
assert!(low_scalar <= v && v <= high_scalar);
let v =
<$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng)
.unwrap()
.extract(lane);
assert!(low_scalar <= v && v < high_scalar);
assert!(low_scalar <= v && v <= high_scalar);
let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(
low, high, &mut rng,
)
Expand Down Expand Up @@ -1510,12 +1463,12 @@ mod tests {
low_scalar
);

assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar);
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
// sample_single cannot cope with max_rng:
// assert!(<$ty as SampleUniform>::Sampler
// ::sample_single(low, high, &mut max_rng).unwrap()
// .extract(lane) < high_scalar);
// .extract(lane) <= high_scalar);
assert!(
<$ty as SampleUniform>::Sampler::sample_single_inclusive(
low,
Expand Down Expand Up @@ -1543,7 +1496,7 @@ mod tests {
)
.unwrap()
.extract(lane)
< high_scalar
<= high_scalar
);
}
}
Expand Down Expand Up @@ -1590,10 +1543,9 @@ mod tests {
#[cfg(all(feature = "std", panic = "unwind"))]
fn test_float_assertions() {
use super::SampleUniform;
use std::panic::catch_unwind;
fn range<T: SampleUniform>(low: T, high: T) {
fn range<T: SampleUniform>(low: T, high: T) -> Result<T, Error> {
let mut rng = crate::test::rng(253);
T::Sampler::sample_single(low, high, &mut rng).unwrap();
T::Sampler::sample_single(low, high, &mut rng)
}

macro_rules! t {
Expand All @@ -1616,10 +1568,9 @@ mod tests {
for lane in 0..<$ty>::LEN {
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
assert!(catch_unwind(|| range(low, high)).is_err());
assert!(range(low, high).is_err());
assert!(Uniform::new(low, high).is_err());
assert!(Uniform::new_inclusive(low, high).is_err());
assert!(catch_unwind(|| range(low, low)).is_err());
assert!(Uniform::new(low, low).is_err());
}
}
Expand Down
22 changes: 0 additions & 22 deletions src/distributions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ pub(crate) trait FloatSIMDUtils {
fn all_finite(self) -> bool;

type Mask;
fn finite_mask(self) -> Self::Mask;
fn gt_mask(self, other: Self) -> Self::Mask;
fn ge_mask(self, other: Self) -> Self::Mask;

// Decrease all lanes where the mask is `true` to the next lower value
// representable by the floating-point type. At least one of the lanes
Expand Down Expand Up @@ -292,21 +290,11 @@ macro_rules! scalar_float_impl {
self.is_finite()
}

#[inline(always)]
fn finite_mask(self) -> Self::Mask {
self.is_finite()
}

#[inline(always)]
fn gt_mask(self, other: Self) -> Self::Mask {
self > other
}

#[inline(always)]
fn ge_mask(self, other: Self) -> Self::Mask {
self >= other
}

#[inline(always)]
fn decrease_masked(self, mask: Self::Mask) -> Self {
debug_assert!(mask, "At least one lane must be set");
Expand Down Expand Up @@ -368,21 +356,11 @@ macro_rules! simd_impl {
self.is_finite().all()
}

#[inline(always)]
fn finite_mask(self) -> Self::Mask {
self.is_finite()
}

#[inline(always)]
fn gt_mask(self, other: Self) -> Self::Mask {
self.simd_gt(other)
}

#[inline(always)]
fn ge_mask(self, other: Self) -> Self::Mask {
self.simd_ge(other)
}

#[inline(always)]
fn decrease_masked(self, mask: Self::Mask) -> Self {
// Casting a mask into ints will produce all bits set for
Expand Down

0 comments on commit 1e381d1

Please sign in to comment.