From fa1f9320564d7407cbf806cb4ccfd535b187e767 Mon Sep 17 00:00:00 2001 From: Jonas Sicking Date: Fri, 29 Jun 2018 00:02:33 -0700 Subject: [PATCH] Improve tests, fix comments, and simplify HighPrecision01 code a little bit --- src/distributions/float.rs | 362 ++++++++++++++++++++++--------------- 1 file changed, 217 insertions(+), 145 deletions(-) diff --git a/src/distributions/float.rs b/src/distributions/float.rs index dfd4770da2..821d72f8c6 100644 --- a/src/distributions/float.rs +++ b/src/distributions/float.rs @@ -72,9 +72,35 @@ pub struct OpenClosed01; pub struct Open01; /// A distribution to do high precision sampling of floating point numbers -/// uniformly in a given range. This is similar to Uniform, but samples -/// numbers with the full precision of the floating-point type used, at the -/// cost of slower performance. +/// uniformly in the half-open interval `[low, high)`. This is similar to +/// [`Uniform`], but generates numbers with as much precision as the +/// floating-point type can represent, including sub-normals, at the cost of +/// slower performance. +/// +/// This distribution effectively picks a point in the continuous range between +/// `low` and `high`. It then rounds *down* to the nearest value which is +/// representable as an `f32` or `f64` and returns this value. +/// +/// This means that the probability of a number in the range [a, b) being +/// returned is exactly `(b - a) / (high - low)` (where `a` and `b` are in the +/// `[low, high]` range). The chance of getting exactly the value `x` is +/// `(x' - x) / (high - low)` where `x'` is the smallest value larger than `x` which is +/// representable as an `f32`/`f64`. +/// +/// Due to the extra logic there is significant performance overhead relative +/// to [`Uniform`]. 
+/// +/// # Example +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::HighPrecision; +/// +/// let mut rng = thread_rng(); +/// let val = HighPrecision::new(-0.1f32, 20.0).sample(&mut rng); +/// println!("f32 from [-0.1, 20.0): {}", val); +/// ``` +/// +/// [`Uniform`]: struct.Uniform.html #[derive(Clone, Copy, Debug)] pub struct HighPrecision where F: HPFloatHelper { low_as_int: F::SignedInt, @@ -97,30 +123,47 @@ impl HighPrecision { } } -/// Generate a floating point number in the half-open interval `[0, 1)` with a -/// uniform distribution, with as much precision as the floating-point type -/// can represent, including sub-normals. +/// A distribution to do high precision sampling of floating point numbers +/// uniformly in the half-open interval `[0, 1)`. This is similar to +/// [`Standard`], but generates numbers with as much precision as the +/// floating-point type can represent, including sub-normals, at the cost of +/// slower performance. +/// +/// This distribution effectively picks a point in the continuous range between +/// `0` and `1`. It then rounds *down* to the nearest value which is +/// representable as an `f32` or `f64` and returns this value. /// -/// Technically 0 is representable, but the probability of occurrence is -/// remote (1 in 2^149 for `f32` or 1 in 2^1074 for `f64`). +/// This means that the probability of a number in the range [a, b) being +/// returned is exactly `b - a` (where `a` and `b` are in the `[0, 1]` range). +/// The chance of getting exactly the value `x` is `x' - x` where `x'` is the +/// smallest value larger than `x` which is representable as an `f32`/`f64`. /// -/// This is different from `Uniform` in that it uses as many random bits as -/// required to get high precision close to 0. 
Normally only a single call to -/// the source RNG is required (32 bits for `f32` or 64 bits for `f64`); 1 in -/// 2^9 (`f32`) or 2^12 (`f64`) samples need an extra call; of these 1 in 2^32 -/// or 1 in 2^64 require a third call, etc.; i.e. even for `f32` a third call is -/// almost impossible to observe with an unbiased RNG. Due to the extra logic -/// there is some performance overhead relative to `Uniform`; this is more -/// significant for `f32` than for `f64`. +/// So while 0 technically can be returned, the chance of getting exactly 0 is +/// extremely remote, since `f32`/`f64` is able to represent very small values. +/// The chance of getting exactly 0 is 1 in 2^149 for `f32` and 1 in 2^1074 for +/// `f64`. +/// +/// Due to the extra logic there is some performance overhead relative to +/// [`Standard`]; this is more significant for `f32` than for `f64`. +/// +/// `HighPrecision01` uses as many random bits as required to use the full +/// precision of `f32`/`f64`. Normally only a single call to the source RNG is +/// required (32 bits for `f32` or 64 bits for `f64`); 1 in 2^9 (`f32`) or 2^12 +/// (`f64`) samples need an extra call; of these 1 in 2^32 or 1 in 2^64 require +/// a third call, etc.; i.e. even for `f32` a third call is almost impossible +/// to observe with an unbiased RNG. /// /// # Example -/// ```rust -/// use rand::{NewRng, SmallRng, Rng}; +/// ``` +/// use rand::prelude::*; /// use rand::distributions::HighPrecision01; /// -/// let val: f32 = SmallRng::new().sample(HighPrecision01); +/// let mut rng = thread_rng(); +/// let val: f32 = HighPrecision01.sample(&mut rng); /// println!("f32 from [0,1): {}", val); /// ``` +/// +/// [`Standard`]: struct.Standard.html #[derive(Clone, Copy, Debug)] pub struct HighPrecision01; @@ -256,15 +299,22 @@ macro_rules! high_precision_float_impls { fn sample(&self, rng: &mut R) -> $ty { // Unusual case. Separate function to allow inlining of rest. 
#[inline(never)] - fn fallback(mut exp: i32, fraction: $uty, rng: &mut R) -> $ty { + fn fallback(fraction: $uty, rng: &mut R) -> $ty { // Performance impact of code here is negligible. - let bits = rng.gen::<$uty>(); - exp += bits.trailing_zeros() as i32; - // If RNG were guaranteed unbiased we could skip the - // check against exp; unfortunately it may be. - // Worst case ("zeros" RNG) has recursion depth 16. - if bits == 0 && exp < $exponent_bias { - return fallback(exp, fraction, rng); + + let size_bits = (mem::size_of::<$ty>() * 8) as i32; + let mut exp = (size_bits - $fraction_bits as i32) + 1; + + loop { + let bits = rng.gen::<$uty>(); + exp += bits.trailing_zeros() as i32; + // The chance of hitting $exponent_bias here is effectively + // zero with any decent RNG, since it requires generating + // very many consecutive 0s. But testing code will hit this + // edge case. + if exp >= $exponent_bias || bits != 0 { + break; + } } exp = cmp::min(exp, $exponent_bias); fraction.into_float_with_exponent(-exp) @@ -276,10 +326,7 @@ macro_rules! high_precision_float_impls { let fraction = value & fraction_mask; let remaining = value >> $fraction_bits; if remaining == 0 { - // exp is compile-time constant so this reduces to a function call: - let size_bits = (mem::size_of::<$ty>() * 8) as i32; - let exp = (size_bits - $fraction_bits as i32) + 1; - return fallback(exp, fraction, rng); + return fallback(fraction, rng); } // Usual case: exponent from -1 to -9 (f32) or -12 (f64) @@ -545,107 +592,111 @@ mod tests { let mut r = ::test::rng(601); macro_rules! 
float_test { - ($ty:ty, $uty:ty, $ity:ty, $extra:expr, $test_vals:expr) => { + ($ty:ty, $uty:ty, $ity:ty, $extra:expr, $test_vals:expr) => {{ // Create a closure to make loop labels local - (|| { - let mut vals: Vec<$ty> = - $test_vals.iter().cloned() - .flat_map(|x| $extra.iter().map(move |y| x + y)) - .map(|x| <$ty>::from_bits(x as $uty)) - .flat_map(|x| vec![x, -x].into_iter()) - .filter(|x| x.is_finite()) - .collect(); - vals.sort_by(|a, b| a.partial_cmp(b).unwrap()); - vals.dedup(); - - for a in vals.iter().cloned() { - for b in vals.iter().cloned().filter(|&b| b > a) { - fn to_signed_bits(val: $ty) -> $ity { - if val >= 0.0 { - val.to_bits() as $ity - } else { - -((-val).to_bits() as $ity) - } + let mut vals: Vec<$ty> = + $test_vals.iter().cloned() + .flat_map(|x| $extra.iter().map(move |y| x + y)) + .map(|x| <$ty>::from_bits(x as $uty)) + .flat_map(|x| vec![x, -x].into_iter()) + .filter(|x| x.is_finite()) + .collect(); + vals.sort_by(|a, b| a.partial_cmp(b).unwrap()); + vals.dedup(); + + for a in vals.iter().cloned() { + for b in vals.iter().cloned().filter(|&b| b > a) { + fn to_signed_bits(val: $ty) -> $ity { + if val >= 0.0 { + val.to_bits() as $ity + } else { + -((-val).to_bits() as $ity) } - fn from_signed_bits(val: $ity) -> $ty { - if val >= 0 { - <$ty>::from_bits(val as $uty) - } else { - -<$ty>::from_bits(-val as $uty) - } + } + fn from_signed_bits(val: $ity) -> $ty { + if val >= 0 { + <$ty>::from_bits(val as $uty) + } else { + -<$ty>::from_bits(-val as $uty) } + } - let hp = HighPrecision::new(a, b); - let a_bits = to_signed_bits(a); - let b_bits = to_signed_bits(b); - - const N_RUNS: usize = 10; - const N_REPS_PER_RUN: usize = 1000; - - if (b_bits.wrapping_sub(a_bits) as $uty) < 100 { - // If a and b are "close enough", we can verify the full distribution - let mut counts = Vec::::with_capacity((b_bits - a_bits) as usize); - counts.resize((b_bits - a_bits) as usize, 0); - 'test_loop_exact: for test_run in 1..(N_RUNS+1) { - for _ in 0..N_REPS_PER_RUN 
{ - let res = hp.sample(&mut r); - counts[(to_signed_bits(res) - a_bits) as usize] += 1; - } - for (count, i) in counts.iter().zip(0 as $ity..) { - let expected = (test_run * N_REPS_PER_RUN) as $ty * - ((from_signed_bits(a_bits + i + 1) - - from_signed_bits(a_bits + i)) / (b - a)); - let err = (*count as $ty - expected) / expected; - if err.abs() > 0.2 { - if test_run < N_RUNS { - continue 'test_loop_exact; - } - panic!(format!("Failed {}-bit exact test: a: 0x{:x}, b: 0x{:x}, err: {:.2}", - ::core::mem::size_of::<$ty>() * 8, - a.to_bits(), - b.to_bits(), - err.abs())); - } + let hp = HighPrecision::new(a, b); + let a_bits = to_signed_bits(a); + let b_bits = to_signed_bits(b); + + const N_RUNS: usize = 10; + const N_REPS_PER_RUN: usize = 1000; + + if (b_bits.wrapping_sub(a_bits) as $uty) < 100 { + // If a and b are "close enough", we can verify the full distribution + let mut counts = Vec::::with_capacity((b_bits - a_bits) as usize); + counts.resize((b_bits - a_bits) as usize, 0); + for test_run in 1..(N_RUNS+1) { + for _ in 0..N_REPS_PER_RUN { + let res = hp.sample(&mut r); + counts[(to_signed_bits(res) - a_bits) as usize] += 1; + } + let mut success = true; + for (count, i) in counts.iter().zip(0 as $ity..) 
{ + let expected = (test_run * N_REPS_PER_RUN) as $ty * + ((from_signed_bits(a_bits + i + 1) - + from_signed_bits(a_bits + i)) / (b - a)); + let err = (*count as $ty - expected) / expected; + if err.abs() > 0.2 { + success = false; + assert!(test_run != N_RUNS, + format!("Failed {}-bit exact test: a: 0x{:x}, b: 0x{:x}, err: {:.2}", + ::core::mem::size_of::<$ty>() * 8, + a.to_bits(), + b.to_bits(), + err.abs())); } } + if success { + break; + } + } + } else { + // Otherwise divide range into 10 sections + let step = if (b - a).is_finite() { + (b - a) / 10.0 } else { - // Otherwise divide range into 10 sections - let step = if (b - a).is_finite() { - (b - a) / 10.0 - } else { - b / 10.0 - a / 10.0 - }; - assert!(step.is_finite()); - let mut counts = Vec::::with_capacity(10); - counts.resize(10, 0); - - 'test_loop_rough: for test_run in 1..(N_RUNS+1) { - for _ in 0..N_REPS_PER_RUN { - let res = hp.sample(&mut r); - assert!(a <= res && res < b); - let index = (res / step - a / step) as usize; - counts[::core::cmp::min(index, 9)] += 1; - } - for count in &counts { - let expected = (test_run * N_REPS_PER_RUN) as $ty / 10.0; - let err = (*count as $ty - expected) / expected; - if err.abs() > 0.2 { - if test_run < N_RUNS { - continue 'test_loop_rough; - } - panic!(format!("Failed {}-bit rough test: a: 0x{:x}, b: 0x{:x}, err: {:.2}", - ::core::mem::size_of::<$ty>() * 8, - a.to_bits(), - b.to_bits(), - err.abs())); - } + b / 10.0 - a / 10.0 + }; + assert!(step.is_finite()); + let mut counts = Vec::::with_capacity(10); + counts.resize(10, 0); + + for test_run in 1..(N_RUNS+1) { + for _ in 0..N_REPS_PER_RUN { + let res = hp.sample(&mut r); + assert!(a <= res && res < b); + let index = (res / step - a / step) as usize; + counts[::core::cmp::min(index, 9)] += 1; + } + let mut success = true; + for count in &counts { + let expected = (test_run * N_REPS_PER_RUN) as $ty / 10.0; + let err = (*count as $ty - expected) / expected; + if err.abs() > 0.2 { + success = false; + 
assert!(test_run != N_RUNS, + format!("Failed {}-bit rough test: a: 0x{:x}, b: 0x{:x}, err: {:.2}", + ::core::mem::size_of::<$ty>() * 8, + a.to_bits(), + b.to_bits(), + err.abs())); } } + if success { + break; + } } } } - })() - } + } + }} } const SLOW_TESTS: bool = false; @@ -726,34 +777,55 @@ mod tests { assert_eq!(zeros.sample::(HighPrecision01), 0.0); let mut ones = StepRng::new(0xffff_ffff_ffff_ffff, 0); - assert_eq!(ones.sample::(HighPrecision01), 0.99999994); - assert_eq!(ones.sample::(HighPrecision01), 0.9999999999999999); + assert_eq!(ones.sample::(HighPrecision01).to_bits(), (1.0f32).to_bits() - 1); + assert_eq!(ones.sample::(HighPrecision01).to_bits(), (1.0f64).to_bits() - 1); } - #[cfg(feature="std")] mod mean { - use Rng; - use distributions::{Standard, HighPrecision01}; - - macro_rules! test_mean { - ($name:ident, $ty:ty, $distr:expr) => { - #[test] - fn $name() { - // TODO: no need to &mut here: - let mut r = ::test::rng(602); - let mut total: $ty = 0.0; - const N: u32 = 1_000_000; - for _ in 0..N { - total += r.sample::<$ty, _>($distr); - } - let avg = total / (N as $ty); - //println!("average over {} samples: {}", N, avg); - assert!(0.499 < avg && avg < 0.501); + #[test] + fn test_distribution() { + + const N_SEGMENTS: usize = 10; + const N_REPS: usize = 2000; + + let mut r = ::test::rng(602); + + macro_rules! 
impl_test { + ($ty:ty, $dist:expr) => {{ + let dist = $dist; + let mut counts = [(0i32, 0.0 as $ty); N_SEGMENTS]; + let mut total_sum = 0.0 as $ty; + + for _ in 0..N_REPS { + let res: $ty = dist.sample(&mut r); + assert!(0.0 <= res && res < 1.0); + let index = (res * N_SEGMENTS as $ty) as usize; + counts[index].0 += 1; + counts[index].1 += res; + total_sum += res; + } + for (i, &(count, sum)) in counts.iter().enumerate() { + let count_expected = N_REPS as f32 / N_SEGMENTS as f32; + let count_err = (count as f32 - count_expected) / count_expected; + assert!(count_err.abs() < 0.2); + + let sum_expected = (i as $ty * 0.1 + 0.05) * count as $ty; + let sum_err = (sum - sum_expected) / sum_expected; + assert!(sum_err.abs() < 0.2); + } + + let total_expected = 0.5 * N_REPS as $ty; + let total_err = (total_sum - total_expected) / total_expected; + assert!(total_err.abs() < 0.05); + }} } - } } - test_mean!(test_mean_f32, f32, Standard); - test_mean!(test_mean_f64, f64, Standard); - test_mean!(test_mean_high_f32, f32, HighPrecision01); - test_mean!(test_mean_high_f64, f64, HighPrecision01); + impl_test!(f32, Standard); + impl_test!(f64, Standard); + impl_test!(f32, HighPrecision01); + impl_test!(f64, HighPrecision01); + impl_test!(f32, OpenClosed01); + impl_test!(f64, OpenClosed01); + impl_test!(f32, Open01); + impl_test!(f64, Open01); } }