diff --git a/benches/misc.rs b/benches/misc.rs
index 55a9805fcb6..4e9cbda37ae 100644
--- a/benches/misc.rs
+++ b/benches/misc.rs
@@ -63,9 +63,8 @@ fn misc_gen_ratio_var(b: &mut Bencher) {
#[bench]
fn misc_bernoulli_const(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
- let d = rand::distributions::Bernoulli::new(0.18);
b.iter(|| {
- // Can be evaluated at compile time.
+ let d = rand::distributions::Bernoulli::new(0.18);
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.sample(d);
diff --git a/src/distributions/bernoulli.rs b/src/distributions/bernoulli.rs
index 8bf70c82558..76bf924730d 100644
--- a/src/distributions/bernoulli.rs
+++ b/src/distributions/bernoulli.rs
@@ -37,6 +37,31 @@ pub struct Bernoulli {
p_int: u64,
}
+// To sample from the Bernoulli distribution we use a method that compares a
+// random `u64` value `v < (p * 2^64)`.
+//
+// If `p == 1.0`, the integer `v` to compare against can not represented as a
+// `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64).
+// Note that value of `p < 1.0` can never result in `u64::MAX`, because an
+// `f64` only has 53 bits of precision, and the next largest value of `p` will
+// result in `2^64 - 2048`.
+//
+// Also there is a 100% theoretical concern: if someone consistenly wants to
+// generate `true` using the Bernoulli distribution (i.e. by using a probability
+// of `1.0`), just using `u64::MAX` is not enough. On average it would return
+// false once every 2^64 iterations. Some people apparently care about this
+// case.
+//
+// That is why we special-case `u64::MAX` to always return `true`, without using
+// the RNG, and pay the performance price for all uses that *are* reasonable.
+// Luckily, if `new()` and `sample` are close, the compiler can optimize out the
+// extra check.
+const ALWAYS_TRUE: u64 = ::core::u64::MAX;
+
+// This is just `2.0.powi(64)`, but written this way because it is not available
+// in `no_std` mode.
+const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
+
impl Bernoulli {
/// Construct a new `Bernoulli` with the given probability of success `p`.
///
@@ -54,18 +79,11 @@ impl Bernoulli {
/// 2-64 in `[0, 1]` can be represented as a `f64`.)
#[inline]
pub fn new(p: f64) -> Bernoulli {
- assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 0");
- // Technically, this should be 2^64 or `u64::MAX + 1` because we compare
- // using `<` when sampling. However, `u64::MAX` rounds to an `f64`
- // larger than `u64::MAX` anyway.
- const MAX_P_INT: f64 = ::core::u64::MAX as f64;
- let p_int = if p < 1.0 {
- (p * MAX_P_INT) as u64
- } else {
- // Avoid overflow: `MAX_P_INT` cannot be represented as u64.
- ::core::u64::MAX
- };
- Bernoulli { p_int }
+ if p < 0.0 || p >= 1.0 {
+ if p == 1.0 { return Bernoulli { p_int: ALWAYS_TRUE } }
+ panic!("Bernoulli::new not called with 0.0 <= p <= 1.0");
+ }
+ Bernoulli { p_int: (p * SCALE) as u64 }
}
/// Construct a new `Bernoulli` with the probability of success of
@@ -85,7 +103,6 @@ impl Bernoulli {
if numerator == denominator {
return Bernoulli { p_int: ::core::u64::MAX }
}
- const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64;
Bernoulli { p_int }
}
@@ -95,11 +112,9 @@ impl Distribution for Bernoulli {
#[inline]
fn sample(&self, rng: &mut R) -> bool {
// Make sure to always return true for p = 1.0.
- if self.p_int == ::core::u64::MAX {
- return true;
- }
- let r: u64 = rng.gen();
- r < self.p_int
+ if self.p_int == ALWAYS_TRUE { return true; }
+ let v: u64 = rng.gen();
+ v < self.p_int
}
}