Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API for getting a bool with chance of exactly 1-in-10 or 2-in-3 #491

Merged
merged 1 commit into from
Jun 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions benches/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,30 @@ fn misc_gen_bool_var(b: &mut Bencher) {
})
}

#[bench]
fn misc_gen_ratio_const(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.gen_ratio(2, 3);
}
accum
})
}

#[bench]
fn misc_gen_ratio_var(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let mut accum = true;
for i in 2..(::RAND_BENCH_N as u32 + 2) {
accum ^= rng.gen_ratio(i, i + 1);
}
accum
})
}

#[bench]
fn misc_bernoulli_const(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
Expand Down
45 changes: 38 additions & 7 deletions src/distributions/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,28 @@ impl Bernoulli {
};
Bernoulli { p_int }
}

/// Construct a new `Bernoulli` with the probability of success of
/// `numerator`-in-`denominator`. I.e. `new_ratio(2, 3)` will return
/// a `Bernoulli` with a 2-in-3 chance, or about 67%, of returning `true`.
///
/// If `numerator == denominator` then the returned `Bernoulli` will always
/// return `true`. If `numerator == 0` it will always return `false`.
///
/// # Panics
///
/// If `denominator == 0` or `numerator > denominator`.
///
#[inline]
pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli {
assert!(numerator <= denominator);
if numerator == denominator {
return Bernoulli { p_int: ::core::u64::MAX }
}
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64;
Bernoulli { p_int }
}
}

impl Distribution<bool> for Bernoulli {
Expand Down Expand Up @@ -103,18 +125,27 @@ mod test {
#[test]
fn test_average() {
const P: f64 = 0.3;
let d = Bernoulli::new(P);
const N: u32 = 10_000_000;
const NUM: u32 = 3;
const DENOM: u32 = 10;
let d1 = Bernoulli::new(P);
let d2 = Bernoulli::from_ratio(NUM, DENOM);
const N: u32 = 100_000;

let mut sum: u32 = 0;
let mut sum1: u32 = 0;
let mut sum2: u32 = 0;
let mut rng = ::test::rng(2);
for _ in 0..N {
if d.sample(&mut rng) {
sum += 1;
if d1.sample(&mut rng) {
sum1 += 1;
}
if d2.sample(&mut rng) {
sum2 += 1;
}
}
let avg = (sum as f64) / (N as f64);
let avg1 = (sum1 as f64) / (N as f64);
assert!((avg1 - P).abs() < 5e-3);

assert!((avg - P).abs() < 1e-3);
let avg2 = (sum2 as f64) / (N as f64);
assert!((avg2 - (NUM as f64)/(DENOM as f64)).abs() < 5e-3);
}
}
55 changes: 51 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ pub trait Rng: RngCore {
/// ```
///
/// [`Uniform`]: distributions/uniform/struct.Uniform.html
fn gen_range<T: PartialOrd + SampleUniform>(&mut self, low: T, high: T) -> T {
fn gen_range<T: SampleUniform>(&mut self, low: T, high: T) -> T {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

T::Sampler::sample_single(low, high, self)
}

Expand Down Expand Up @@ -509,7 +509,8 @@ pub trait Rng: RngCore {

/// Return a bool with a probability `p` of being true.
///
/// This is a wrapper around [`distributions::Bernoulli`].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please at least mention this in the doc; same with gen_ratio below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a comment, so I'll add this back if really desired. However I don't think it's a good idea to promise a particular implementation strategy. That just seems like complicate future arguments regarding if changing the implementation is a breaking change or not.

If we want a pointer to Bernoulli, maybe do what we do for gen_range and say: See also the [Bernoulli] distribution type which may be faster if sampling with the same probability repeatedly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, just a pointer to Bernoulli is sufficient. You're right that we don't need to specify this constraint in the doc, though I don't see it as a big deal.

/// See also the [`Bernoulli`] distribution, which may be faster if
/// sampling from the same probability repeatedly.
///
/// # Example
///
Expand All @@ -522,15 +523,44 @@ pub trait Rng: RngCore {
///
/// # Panics
///
/// If `p` < 0 or `p` > 1.
/// If `p < 0` or `p > 1`.
///
/// [`distributions::Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
/// [`Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
#[inline]
fn gen_bool(&mut self, p: f64) -> bool {
let d = distributions::Bernoulli::new(p);
self.sample(d)
}

/// Return a bool with a probability of `numerator/denominator` of being
/// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of
/// returning true. If `numerator == denominator`, then the returned value
/// is guaranteed to be `true`. If `numerator == 0`, then the returned
/// value is guaranteed to be `false`.
///
/// See also the [`Bernoulli`] distribution, which may be faster if
/// sampling from the same `numerator` and `denominator` repeatedly.
///
/// # Panics
///
/// If `denominator == 0` or `numerator > denominator`.
///
/// # Example
///
/// ```
/// use rand::{thread_rng, Rng};
///
/// let mut rng = thread_rng();
/// println!("{}", rng.gen_ratio(2, 3));
/// ```
///
/// [`Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
#[inline]
fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool {
let d = distributions::Bernoulli::from_ratio(numerator, denominator);
self.sample(d)
}

/// Return a random element from `values`.
///
/// Return `None` if `values` is empty.
Expand Down Expand Up @@ -1017,4 +1047,21 @@ mod test {
(u8, i8, u16, i16, u32, i32, u64, i64),
(f32, (f64, (f64,)))) = random();
}

#[test]
fn test_gen_ratio_average() {
const NUM: u32 = 3;
const DENOM: u32 = 10;
const N: u32 = 100_000;

let mut sum: u32 = 0;
let mut rng = rng(111);
for _ in 0..N {
if rng.gen_ratio(NUM, DENOM) {
sum += 1;
}
}
let avg = (sum as f64) / (N as f64);
assert!((avg - (NUM as f64)/(DENOM as f64)).abs() < 1e-3);
}
}