diff --git a/benches/misc.rs b/benches/misc.rs index d3c60c3e950..2a2dfe7faec 100644 --- a/benches/misc.rs +++ b/benches/misc.rs @@ -36,6 +36,30 @@ fn misc_gen_bool_var(b: &mut Bencher) { }) } +#[bench] +fn misc_gen_ratio_const(b: &mut Bencher) { + let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap(); + b.iter(|| { + let mut accum = true; + for _ in 0..::RAND_BENCH_N { + accum ^= rng.gen_ratio(2, 3); + } + accum + }) +} + +#[bench] +fn misc_gen_ratio_var(b: &mut Bencher) { + let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap(); + b.iter(|| { + let mut accum = true; + for i in 2..(::RAND_BENCH_N as u32 + 2) { + accum ^= rng.gen_ratio(i, i + 1); + } + accum + }) +} + #[bench] fn misc_bernoulli_const(b: &mut Bencher) { let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap(); diff --git a/src/distributions/bernoulli.rs b/src/distributions/bernoulli.rs index 2361fac0c21..8bf70c82558 100644 --- a/src/distributions/bernoulli.rs +++ b/src/distributions/bernoulli.rs @@ -67,6 +67,28 @@ impl Bernoulli { }; Bernoulli { p_int } } + + /// Construct a new `Bernoulli` with the probability of success of + /// `numerator`-in-`denominator`. I.e. `new_ratio(2, 3)` will return + /// a `Bernoulli` with a 2-in-3 chance, or about 67%, of returning `true`. + /// + /// If `numerator == denominator` then the returned `Bernoulli` will always + /// return `true`. If `numerator == 0` it will always return `false`. + /// + /// # Panics + /// + /// If `denominator == 0` or `numerator > denominator`. + /// + #[inline] + pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli { + assert!(numerator <= denominator); + if numerator == denominator { + return Bernoulli { p_int: ::core::u64::MAX } + } + const SCALE: f64 = 2.0 * (1u64 << 63) as f64; + let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64; + Bernoulli { p_int } + } } impl Distribution for Bernoulli { @@ -103,18 +125,27 @@ mod test { #[test] fn test_average() { const P: f64 = 0.3; - let d = Bernoulli::new(P); - const N: u32 = 10_000_000; + const NUM: u32 = 3; + const DENOM: u32 = 10; + let d1 = Bernoulli::new(P); + let d2 = Bernoulli::from_ratio(NUM, DENOM); + const N: u32 = 100_000; - let mut sum: u32 = 0; + let mut sum1: u32 = 0; + let mut sum2: u32 = 0; let mut rng = ::test::rng(2); for _ in 0..N { - if d.sample(&mut rng) { - sum += 1; + if d1.sample(&mut rng) { + sum1 += 1; + } + if d2.sample(&mut rng) { + sum2 += 1; } } - let avg = (sum as f64) / (N as f64); + let avg1 = (sum1 as f64) / (N as f64); + assert!((avg1 - P).abs() < 5e-3); - assert!((avg - P).abs() < 1e-3); + let avg2 = (sum2 as f64) / (N as f64); + assert!((avg2 - (NUM as f64)/(DENOM as f64)).abs() < 5e-3); } } diff --git a/src/lib.rs b/src/lib.rs index 3723754cf7b..fb85cbbf89d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -401,7 +401,7 @@ pub trait Rng: RngCore { /// ``` /// /// [`Uniform`]: distributions/uniform/struct.Uniform.html - fn gen_range(&mut self, low: T, high: T) -> T { + fn gen_range(&mut self, low: T, high: T) -> T { T::Sampler::sample_single(low, high, self) } @@ -523,8 +523,6 @@ pub trait Rng: RngCore { /// Return a bool with a probability `p` of being true. /// - /// This is a wrapper around [`distributions::Bernoulli`]. - /// /// # Example /// /// ``` @@ -536,15 +534,38 @@ pub trait Rng: RngCore { /// /// # Panics /// - /// If `p` < 0 or `p` > 1. - /// - /// [`distributions::Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html + /// If `p < 0` or `p > 1`. #[inline] fn gen_bool(&mut self, p: f64) -> bool { let d = distributions::Bernoulli::new(p); self.sample(d) } + /// Return a bool with a probability of `numerator/denominator` of being + /// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of + /// returning true. If `numerator == denominator`, then the returned value + /// is guaranteed to be `true`. If `numerator == 0`, then the returned + /// value is guaranteed to be `false`. + /// + /// # Panics + /// + /// If `denominator == 0` or `numerator > denominator`. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let mut rng = thread_rng(); + /// println!("{}", rng.gen_ratio(2, 3)); + /// ``` + /// + #[inline] + fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool { + let d = distributions::Bernoulli::from_ratio(numerator, denominator); + self.sample(d) + } + /// Return a random element from `values`. /// /// Return `None` if `values` is empty. @@ -1196,4 +1217,21 @@ mod test { (u8, i8, u16, i16, u32, i32, u64, i64), (f32, (f64, (f64,)))) = random(); } + + #[test] + fn test_gen_ratio_average() { + const NUM: u32 = 3; + const DENOM: u32 = 10; + const N: u32 = 100_000; + + let mut sum: u32 = 0; + let mut rng = rng(111); + for _ in 0..N { + if rng.gen_ratio(NUM, DENOM) { + sum += 1; + } + } + let avg = (sum as f64) / (N as f64); + assert!((avg - (NUM as f64)/(DENOM as f64)).abs() < 1e-3); + } }