-
Notifications
You must be signed in to change notification settings - Fork 432
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Poisson and binomial distributions
- Loading branch information
Showing
4 changed files
with
347 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// https://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
//! The binomial distribution. | ||
|
||
use Rng; | ||
use distributions::Distribution; | ||
use distributions::log_gamma::log_gamma; | ||
use std::f64::consts::PI; | ||
|
||
/// The binomial distribution `Binomial(n, p)`. | ||
/// | ||
/// This distribution has density function: `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. | ||
/// | ||
/// # Example | ||
/// | ||
/// ```rust | ||
/// use rand::distributions::{Binomial, Distribution}; | ||
/// | ||
/// let bin = Binomial::new(20, 0.3); | ||
/// let v = bin.sample(&mut rand::thread_rng()); | ||
/// println!("{} is from a binomial distribution", v); | ||
/// ``` | ||
#[derive(Clone, Copy, Debug)] | ||
pub struct Binomial { | ||
n: u64, // number of trials | ||
p: f64, // probability of success | ||
} | ||
|
||
impl Binomial { | ||
/// Construct a new `Binomial` with the given shape parameters | ||
/// `n`, `p`. Panics if `p <= 0` or `p >= 1`. | ||
pub fn new(n: u64, p: f64) -> Binomial { | ||
assert!(p > 0.0, "Binomial::new called with `p` <= 0"); | ||
assert!(p < 1.0, "Binomial::new called with `p` >= 1"); | ||
Binomial { n: n, p: p } | ||
} | ||
} | ||
|
||
impl Distribution<u64> for Binomial { | ||
fn sample<R: Rng>(&self, rng: &mut R) -> u64 { | ||
// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k | ||
// switch p so that it is less than 0.5 - this allows for lower expected values | ||
// we will just invert the result at the end | ||
let p = if self.p <= 0.5 { | ||
self.p | ||
} else { | ||
1.0 - self.p | ||
}; | ||
|
||
// expected value of the sample | ||
let expected = self.n as f64 * p; | ||
|
||
let result = | ||
// for low expected values we just simulate n drawings | ||
if expected < 25.0 { | ||
let mut lresult = 0.0; | ||
for _ in 0 .. self.n { | ||
if rng.gen::<f64>() < p { | ||
lresult += 1.0; | ||
} | ||
} | ||
lresult | ||
} | ||
// high expected value - do the rejection method | ||
else { | ||
// prepare some cached values | ||
let float_n = self.n as f64; | ||
let ln_fact_n = log_gamma(float_n + 1.0); | ||
let pc = 1.0 - p; | ||
let log_p = p.ln(); | ||
let log_pc = pc.ln(); | ||
let sq = (expected * (2.0 * pc)).sqrt(); | ||
|
||
let mut lresult; | ||
|
||
loop { | ||
let mut comp_dev: f64; | ||
// we use the lorentzian distribution as the comparison distribution | ||
// f(x) ~ 1/(1+x/^2) | ||
loop { | ||
// draw from the lorentzian distribution | ||
comp_dev = (PI*rng.gen::<f64>()).tan(); | ||
// shift the peak of the comparison ditribution | ||
lresult = expected + sq * comp_dev; | ||
// repeat the drawing until we are in the range of possible values | ||
if lresult >= 0.0 && lresult < float_n + 1.0 { | ||
break; | ||
} | ||
} | ||
|
||
// the result should be discrete | ||
lresult = lresult.floor(); | ||
|
||
let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) - | ||
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc; | ||
// this is the binomial probability divided by the comparison probability | ||
// we will generate a uniform random value and if it is larger than this, | ||
// we interpret it as a value falling out of the distribution and repeat | ||
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev)); | ||
|
||
if comparison_coeff >= rng.gen() { | ||
break; | ||
} | ||
} | ||
|
||
lresult | ||
}; | ||
|
||
// invert the result for p < 0.5 | ||
if p != self.p { | ||
self.n - result as u64 | ||
} else { | ||
result as u64 | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use distributions::Distribution; | ||
use super::Binomial; | ||
|
||
#[test] | ||
fn test_binomial() { | ||
let binomial = Binomial::new(150, 0.1); | ||
let mut rng = ::test::rng(123); | ||
let mut sum = 0; | ||
for _ in 0..1000 { | ||
sum += binomial.sample(&mut rng); | ||
} | ||
let avg = (sum as f64) / 1000.0; | ||
println!("Binomial average: {}", avg); | ||
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
|
||
#[test] | ||
#[should_panic] | ||
#[cfg_attr(target_env = "msvc", ignore)] | ||
fn test_binomial_invalid_lambda_zero() { | ||
Binomial::new(20, 0.0); | ||
} | ||
#[test] | ||
#[should_panic] | ||
#[cfg_attr(target_env = "msvc", ignore)] | ||
fn test_binomial_invalid_lambda_neg() { | ||
Binomial::new(20, -10.0); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// https://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
/// Calculates ln(gamma(x)) (natural logarithm of the gamma | ||
/// function) using the Lanczos approximation with g=5 | ||
pub fn log_gamma(x: f64) -> f64 { | ||
// precalculated 6 coefficients for the first 6 terms of the series | ||
let coefficients: [f64; 6] = [ | ||
76.18009172947146, | ||
-86.50532032941677, | ||
24.01409824083091, | ||
-1.231739572450155, | ||
0.1208650973866179e-2, | ||
-0.5395239384953e-5, | ||
]; | ||
|
||
// ln((x+g+0.5)^(x+0.5)*exp(-(x+g+0.5))) | ||
let tmp = x + 5.5; | ||
let log = (x + 0.5) * tmp.ln() - tmp; | ||
|
||
// the first few terms of the series | ||
let mut a = 1.000000000190015; | ||
let mut denom = x; | ||
for j in 0..6 { | ||
denom += 1.0; | ||
a += coefficients[j] / denom; | ||
} | ||
|
||
// get everything together | ||
// division by x is because the series is actually for gamma(x+1) = x*gamma(x) | ||
return log + (2.5066282746310005 * a / x).ln(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// https://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
//! The Poisson distribution. | ||
|
||
use Rng; | ||
use distributions::Distribution; | ||
use distributions::log_gamma::log_gamma; | ||
use std::f64::consts::PI; | ||
|
||
/// The Poisson distribution `Poisson(lambda)`. | ||
/// | ||
/// This distribution has density function: `f(k) = lambda^k * | ||
/// exp(-lambda) / k!` for `k >= 0`. | ||
/// | ||
/// # Example | ||
/// | ||
/// ```rust | ||
/// use rand::distributions::{Poisson, Distribution}; | ||
/// | ||
/// let poi = Poisson::new(2.0); | ||
/// let v = poi.sample(&mut rand::thread_rng()); | ||
/// println!("{} is from a Poisson(2) distribution", v); | ||
/// ``` | ||
#[derive(Clone, Copy, Debug)] | ||
pub struct Poisson { | ||
lambda: f64, | ||
// precalculated values | ||
exp_lambda: f64, | ||
log_lambda: f64, | ||
magic_val: f64, | ||
} | ||
|
||
impl Poisson { | ||
/// Construct a new `Poisson` with the given shape parameter | ||
/// `lambda`. Panics if `lambda <= 0`. | ||
pub fn new(lambda: f64) -> Poisson { | ||
assert!(lambda > 0.0, "Poisson::new called with `lambda` <= 0"); | ||
Poisson { | ||
lambda: lambda, | ||
exp_lambda: (-lambda).exp(), | ||
log_lambda: lambda.ln(), | ||
magic_val: lambda * lambda.ln() - log_gamma(1.0 + lambda), | ||
} | ||
} | ||
} | ||
|
||
impl Distribution<u64> for Poisson { | ||
fn sample<R: Rng>(&self, rng: &mut R) -> u64 { | ||
// using the algorithm from Numerical Recipes in C | ||
|
||
// for low expected values use the Knuth method | ||
if self.lambda < 12.0 { | ||
let mut result = 0; | ||
let mut p = 1.0; | ||
while p > self.exp_lambda { | ||
p *= rng.gen::<f64>(); | ||
result += 1; | ||
} | ||
result - 1 | ||
} | ||
// high expected values - rejection method | ||
else { | ||
// some cached values | ||
let tmp = (2.0 * self.lambda).sqrt(); | ||
let mut int_result: u64; | ||
|
||
loop { | ||
let mut result; | ||
let mut comp_dev; | ||
|
||
// we use the lorentzian distribution as the comparison distribution | ||
// f(x) ~ 1/(1+x/^2) | ||
loop { | ||
// draw from the lorentzian distribution | ||
comp_dev = (PI * rng.gen::<f64>()).tan(); | ||
// shift the peak of the comparison ditribution | ||
result = tmp * comp_dev + self.lambda; | ||
// repeat the drawing until we are in the range of possible values | ||
if result >= 0.0 { | ||
break; | ||
} | ||
} | ||
// now the result is a random variable greater than 0 with Lorentzian distribution | ||
// the result should be an integer value | ||
result = result.floor(); | ||
int_result = result as u64; | ||
|
||
// this is the ratio of the Poisson distribution to the comparison distribution | ||
// the magic value scales the distribution function to a range of approximately 0-1 | ||
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 | ||
// this doesn't change the resulting distribution, only increases the rate of failed drawings | ||
let check = 0.9 * (1.0 + comp_dev * comp_dev) | ||
* (result * self.log_lambda - log_gamma(1.0 + result) - self.magic_val).exp(); | ||
|
||
// check with uniform random value - if below the threshold, we are within the target distribution | ||
if rng.gen::<f64>() <= check { | ||
break; | ||
} | ||
} | ||
int_result | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use distributions::Distribution; | ||
use super::Poisson; | ||
|
||
#[test] | ||
fn test_poisson() { | ||
let poisson = Poisson::new(10.0); | ||
let mut rng = ::test::rng(123); | ||
let mut sum = 0; | ||
for _ in 0..1000 { | ||
sum += poisson.sample(&mut rng); | ||
} | ||
let avg = (sum as f64) / 1000.0; | ||
println!("Poisson average: {}", avg); | ||
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
|
||
#[test] | ||
#[should_panic] | ||
#[cfg_attr(target_env = "msvc", ignore)] | ||
fn test_poisson_invalid_lambda_zero() { | ||
Poisson::new(0.0); | ||
} | ||
#[test] | ||
#[should_panic] | ||
#[cfg_attr(target_env = "msvc", ignore)] | ||
fn test_poisson_invalid_lambda_neg() { | ||
Poisson::new(-10.0); | ||
} | ||
} |