From 6044cc85dd3be8cdce0639703c37b9786d711758 Mon Sep 17 00:00:00 2001 From: Paul Dicker Date: Thu, 7 Jun 2018 08:25:21 +0200 Subject: [PATCH] Optimize Bernoulli::new --- benches/misc.rs | 3 +- src/distributions/bernoulli.rs | 51 ++++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/benches/misc.rs b/benches/misc.rs index 55a9805fcb6..4e9cbda37ae 100644 --- a/benches/misc.rs +++ b/benches/misc.rs @@ -63,9 +63,8 @@ fn misc_gen_ratio_var(b: &mut Bencher) { #[bench] fn misc_bernoulli_const(b: &mut Bencher) { let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap(); - let d = rand::distributions::Bernoulli::new(0.18); b.iter(|| { - // Can be evaluated at compile time. + let d = rand::distributions::Bernoulli::new(0.18); let mut accum = true; for _ in 0..::RAND_BENCH_N { accum ^= rng.sample(d); diff --git a/src/distributions/bernoulli.rs b/src/distributions/bernoulli.rs index 8bf70c82558..76bf924730d 100644 --- a/src/distributions/bernoulli.rs +++ b/src/distributions/bernoulli.rs @@ -37,6 +37,31 @@ pub struct Bernoulli { p_int: u64, } +// To sample from the Bernoulli distribution we use a method that compares a +// random `u64` value `v < (p * 2^64)`. +// +// If `p == 1.0`, the integer `v` to compare against can not represented as a +// `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64). +// Note that value of `p < 1.0` can never result in `u64::MAX`, because an +// `f64` only has 53 bits of precision, and the next largest value of `p` will +// result in `2^64 - 2048`. +// +// Also there is a 100% theoretical concern: if someone consistenly wants to +// generate `true` using the Bernoulli distribution (i.e. by using a probability +// of `1.0`), just using `u64::MAX` is not enough. On average it would return +// false once every 2^64 iterations. Some people apparently care about this +// case. +// +// That is why we special-case `u64::MAX` to always return `true`, without using +// the RNG, and pay the performance price for all uses that *are* reasonable. +// Luckily, if `new()` and `sample` are close, the compiler can optimize out the +// extra check. +const ALWAYS_TRUE: u64 = ::core::u64::MAX; + +// This is just `2.0.powi(64)`, but written this way because it is not available +// in `no_std` mode. +const SCALE: f64 = 2.0 * (1u64 << 63) as f64; + impl Bernoulli { /// Construct a new `Bernoulli` with the given probability of success `p`. /// @@ -54,18 +79,11 @@ impl Bernoulli { /// 2-64 in `[0, 1]` can be represented as a `f64`.) #[inline] pub fn new(p: f64) -> Bernoulli { - assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 0"); - // Technically, this should be 2^64 or `u64::MAX + 1` because we compare - // using `<` when sampling. However, `u64::MAX` rounds to an `f64` - // larger than `u64::MAX` anyway. - const MAX_P_INT: f64 = ::core::u64::MAX as f64; - let p_int = if p < 1.0 { - (p * MAX_P_INT) as u64 - } else { - // Avoid overflow: `MAX_P_INT` cannot be represented as u64. - ::core::u64::MAX - }; - Bernoulli { p_int } + if p < 0.0 || p >= 1.0 { + if p == 1.0 { return Bernoulli { p_int: ALWAYS_TRUE } } + panic!("Bernoulli::new not called with 0.0 <= p <= 1.0"); + } + Bernoulli { p_int: (p * SCALE) as u64 } } /// Construct a new `Bernoulli` with the probability of success of @@ -85,7 +103,6 @@ impl Bernoulli { if numerator == denominator { return Bernoulli { p_int: ::core::u64::MAX } } - const SCALE: f64 = 2.0 * (1u64 << 63) as f64; let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64; Bernoulli { p_int } } @@ -95,11 +112,9 @@ impl Distribution for Bernoulli { #[inline] fn sample(&self, rng: &mut R) -> bool { // Make sure to always return true for p = 1.0. - if self.p_int == ::core::u64::MAX { - return true; - } - let r: u64 = rng.gen(); - r < self.p_int + if self.p_int == ALWAYS_TRUE { return true; } + let v: u64 = rng.gen(); + v < self.p_int } }