Skip to content

Commit

Permalink
Merge pull request #96 from fizyk20/discrete
Browse files Browse the repository at this point in the history
Add binomial and Poisson distributions
  • Loading branch information
dhardy authored Mar 18, 2018
2 parents b146ee6 + 38ee0f8 commit 8558b22
Show file tree
Hide file tree
Showing 4 changed files with 360 additions and 0 deletions.
156 changes: 156 additions & 0 deletions src/distributions/binomial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The binomial distribution.

use Rng;
use distributions::Distribution;
use distributions::log_gamma::log_gamma;
use std::f64::consts::PI;

/// The binomial distribution `Binomial(n, p)`.
///
/// Describes the number of successes observed in `n` independent trials
/// that each succeed with probability `p`.
///
/// This distribution has probability mass function:
/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `0 <= k <= n`.
///
/// # Example
///
/// ```rust
/// use rand::distributions::{Binomial, Distribution};
///
/// let bin = Binomial::new(20, 0.3);
/// let v = bin.sample(&mut rand::thread_rng());
/// println!("{} is from a binomial distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
n: u64, // number of trials
p: f64, // probability of success of a single trial; `Binomial::new` requires 0 < p < 1
}

impl Binomial {
    /// Creates a binomial distribution over `n` trials, each of which
    /// succeeds with probability `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p <= 0` or `p >= 1`. A NaN `p` fails the first check and
    /// therefore panics as well.
    pub fn new(n: u64, p: f64) -> Binomial {
        // Lower bound first so that the reported message matches the
        // violated bound.
        assert!(p > 0.0, "Binomial::new called with p <= 0");
        assert!(p < 1.0, "Binomial::new called with p >= 1");
        Binomial { n, p }
    }
}

impl Distribution<u64> for Binomial {
    /// Draws one sample, i.e. a number of successes in `0..=n`.
    ///
    /// Two strategies are used, selected by the expected value `n * p`
    /// (after folding `p` into the lower half of its range):
    ///
    /// * below 25, every trial is simulated individually, which costs one
    ///   random draw per trial;
    /// * otherwise a rejection method is used, with a Lorentzian (Cauchy)
    ///   curve as the comparison function.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        // The binomial distribution is symmetric under p -> 1-p, k -> n-k.
        // Work with the smaller of the two probabilities (which gives the
        // smaller expected value) and mirror the outcome at the very end.
        let mirrored = self.p > 0.5;
        let p = if mirrored { 1.0 - self.p } else { self.p };

        // expected number of successes for the (possibly mirrored) p
        let mean = self.n as f64 * p;

        let successes = if mean < 25.0 {
            // Small expected value: simply simulate the n trials.
            let mut hits = 0.0;
            for _ in 0..self.n {
                if rng.gen::<f64>() < p {
                    hits += 1.0;
                }
            }
            hits
        } else {
            // Large expected value: rejection sampling.
            // Quantities that stay constant across rejection rounds.
            let trials = self.n as f64;
            let ln_n_factorial = log_gamma(trials + 1.0);
            let q = 1.0 - p;
            let ln_p = p.ln();
            let ln_q = q.ln();
            let width = (mean * (2.0 * q)).sqrt();

            loop {
                // Draw from the Lorentzian comparison distribution,
                // f(x) ~ 1/(1+x^2), shifted so its peak sits at the mean.
                // Redraw until the candidate lies within the range of
                // possible outcomes, then make it discrete.
                let (deviate, candidate) = loop {
                    let tangent = (PI * rng.gen::<f64>()).tan();
                    let shifted = mean + width * tangent;
                    if shifted >= 0.0 && shifted < trials + 1.0 {
                        break (tangent, shifted.floor());
                    }
                };

                // ln of the binomial probability of `candidate`
                let ln_probability = ln_fact_terms(ln_n_factorial, trials, candidate)
                    + candidate * ln_p
                    + (trials - candidate) * ln_q;

                // Binomial probability divided by the comparison function.
                // The candidate is accepted when a uniform random value
                // does not exceed this ratio; otherwise it is treated as
                // falling outside the target distribution and we retry.
                let ratio = (ln_probability.exp() * width) * (1.2 * (1.0 + deviate * deviate));

                if ratio >= rng.gen::<f64>() {
                    break candidate;
                }
            }
        };

        // Undo the mirroring applied when the original p was above 0.5.
        if mirrored {
            self.n - successes as u64
        } else {
            successes as u64
        }
    }
}

/// Computes `ln(n! / (k! (n-k)!))` from a cached `ln(n!)`, subtracting the
/// two factorial terms in the order `ln(k!)` then `ln((n-k)!)`.
fn ln_fact_terms(ln_n_factorial: f64, trials: f64, candidate: f64) -> f64 {
    ln_n_factorial - log_gamma(candidate + 1.0) - log_gamma(trials - candidate + 1.0)
}

#[cfg(test)]
mod test {
    use distributions::Distribution;
    use super::Binomial;

    /// The sample mean over 1000 draws should be close to `n * p = 15`.
    #[test]
    fn test_binomial() {
        let binomial = Binomial::new(150, 0.1);
        let mut rng = ::test::rng(123);
        let mut sum = 0;
        for _ in 0..1000 {
            sum += binomial.sample(&mut rng);
        }
        let avg = (sum as f64) / 1000.0;
        println!("Binomial average: {}", avg);
        assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
    }

    // The constructor only accepts 0 < p < 1; each bound is checked below.

    #[test]
    #[should_panic]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_binomial_invalid_p_zero() {
        Binomial::new(20, 0.0);
    }

    #[test]
    #[should_panic]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_binomial_invalid_p_neg() {
        Binomial::new(20, -10.0);
    }

    #[test]
    #[should_panic]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_binomial_invalid_p_one() {
        Binomial::new(20, 1.0);
    }

    #[test]
    #[should_panic]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_binomial_invalid_p_above_one() {
        Binomial::new(20, 10.0);
    }
}
51 changes: 51 additions & 0 deletions src/distributions/log_gamma.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

/// Computes `ln(gamma(x))`, the natural logarithm of the gamma function,
/// by means of the Lanczos approximation.
///
/// The approximation writes the gamma function as
/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)`,
/// where `g` is a free constant; this implementation uses `g=5`.
///
/// Since `gamma(z+1) = z*gamma(z)`, taking `ln` of both sides gives
/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)`.
///
/// `Ag(z)` is an infinite series whose coefficients can be computed in
/// advance. Only the constant term and the first 6 coefficients are used
/// here, which is accurate enough for most purposes.
pub fn log_gamma(x: f64) -> f64 {
    // sqrt(2*pi)
    const SQRT_TWO_PI: f64 = 2.5066282746310005;
    // constant term of the truncated series for Ag(x)
    const SERIES_CONSTANT: f64 = 1.000000000190015;
    // precalculated coefficients of the next 6 terms of the series
    const SERIES_COEFFICIENTS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    // (x+0.5)*ln(x+g+0.5)-(x+g+0.5) with g = 5
    let shifted = x + 5.5;
    let leading = (x + 0.5) * shifted.ln() - shifted;

    // Ag(x): term j is coefficient[j] / (x + j + 1); the denominator is
    // advanced by one before each term is added.
    let mut denominator = x;
    let mut series = SERIES_CONSTANT;
    for &coefficient in SERIES_COEFFICIENTS.iter() {
        denominator += 1.0;
        series += coefficient / denominator;
    }

    // combine the leading part with ln(sqrt(2*pi) * Ag(x) / x)
    leading + (SQRT_TWO_PI * series / x).ln()
}
9 changes: 9 additions & 0 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
pub use self::normal::{Normal, LogNormal, StandardNormal};
#[cfg(feature="std")]
pub use self::exponential::{Exp, Exp1};
#[cfg(feature = "std")]
pub use self::poisson::Poisson;
#[cfg(feature = "std")]
pub use self::binomial::Binomial;

pub mod range;
#[cfg(feature="std")]
Expand All @@ -33,9 +37,14 @@ pub mod gamma;
pub mod normal;
#[cfg(feature="std")]
pub mod exponential;
#[cfg(feature = "std")]
pub mod poisson;
#[cfg(feature = "std")]
pub mod binomial;

mod float;
mod integer;
// `log_gamma` relies on `f64::ln`, which is only available with `std`, and
// its only users (`poisson`, `binomial`) are gated on the same feature.
#[cfg(feature="std")]
mod log_gamma;
mod other;
#[cfg(feature="std")]
mod ziggurat_tables;
Expand Down
144 changes: 144 additions & 0 deletions src/distributions/poisson.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The Poisson distribution.

use Rng;
use distributions::Distribution;
use distributions::log_gamma::log_gamma;
use std::f64::consts::PI;

/// The Poisson distribution `Poisson(lambda)`.
///
/// This distribution has probability mass function:
/// `f(k) = lambda^k * exp(-lambda) / k!` for `k >= 0`.
///
/// Besides `lambda` itself, the struct caches several values derived from
/// it in `Poisson::new` so that they are not recomputed on every sample.
///
/// # Example
///
/// ```rust
/// use rand::distributions::{Poisson, Distribution};
///
/// let poi = Poisson::new(2.0);
/// let v = poi.sample(&mut rand::thread_rng());
/// println!("{} is from a Poisson(2) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Poisson {
// the rate parameter; `Poisson::new` requires lambda > 0
lambda: f64,
// precalculated values
// exp(-lambda) (note the sign): threshold used by the low-lambda sampling loop
exp_lambda: f64,
// ln(lambda)
log_lambda: f64,
// sqrt(2 * lambda): scale applied to the comparison deviate in the rejection method
sqrt_2lambda: f64,
// lambda * ln(lambda) - ln(gamma(1 + lambda)): subtracted in the exponent of the
// acceptance ratio of the rejection method
magic_val: f64,
}

impl Poisson {
    /// Creates a Poisson distribution with rate parameter `lambda` and
    /// precomputes the lambda-dependent values used while sampling.
    ///
    /// # Panics
    ///
    /// Panics if `lambda <= 0`. A NaN `lambda` fails the same check and
    /// therefore panics as well.
    pub fn new(lambda: f64) -> Poisson {
        assert!(lambda > 0.0, "Poisson::new called with lambda <= 0");

        let ln_lambda = lambda.ln();
        // note the sign: the cached value is exp(-lambda)
        let exp_neg_lambda = (-lambda).exp();
        let sqrt_two_lambda = (2.0 * lambda).sqrt();
        // lambda * ln(lambda) - ln(lambda!), used to scale the acceptance
        // ratio of the rejection method
        let scale_offset = lambda * ln_lambda - log_gamma(1.0 + lambda);

        Poisson {
            lambda,
            exp_lambda: exp_neg_lambda,
            log_lambda: ln_lambda,
            sqrt_2lambda: sqrt_two_lambda,
            magic_val: scale_offset,
        }
    }
}

impl Distribution<u64> for Poisson {
    /// Draws one sample (a non-negative event count) using the algorithm
    /// from Numerical Recipes in C.
    ///
    /// * For `lambda < 12` the Knuth method is used: uniform values are
    ///   multiplied together until the product drops to `exp(-lambda)` or
    ///   below; the result is the number of factors needed minus one.
    /// * Otherwise a rejection method is used, with a Lorentzian (Cauchy)
    ///   curve as the comparison function.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        // for low expected values use the Knuth method
        if self.lambda < 12.0 {
            // The first factor is drawn unconditionally and the count starts
            // at 0. This consumes the same random values and yields the same
            // result as starting from a product of 1 and subtracting 1 at the
            // end, but it cannot underflow: for a tiny lambda `exp(-lambda)`
            // rounds to exactly 1.0, the loop body never runs, and a trailing
            // `count - 1` would be `0 - 1` on a `u64`.
            let mut result = 0;
            let mut p = rng.gen::<f64>();
            while p > self.exp_lambda {
                p *= rng.gen::<f64>();
                result += 1;
            }
            result
        }
        // high expected values - rejection method
        else {
            let mut int_result: u64;

            loop {
                let mut result;
                let mut comp_dev;

                // we use the Lorentzian distribution as the comparison distribution
                // f(x) ~ 1/(1+x^2)
                loop {
                    // draw from the Lorentzian distribution
                    comp_dev = (PI * rng.gen::<f64>()).tan();
                    // scale the deviate and shift the peak of the comparison
                    // distribution to lambda
                    result = self.sqrt_2lambda * comp_dev + self.lambda;
                    // repeat the drawing until we are in the range of possible values
                    if result >= 0.0 {
                        break;
                    }
                }
                // now the result is a non-negative random variable with
                // Lorentzian distribution; the result should be an integer value
                result = result.floor();
                int_result = result as u64;

                // this is the ratio of the Poisson distribution to the comparison distribution
                // the magic value scales the distribution function to a range of approximately 0-1
                // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
                // this doesn't change the resulting distribution, only increases the rate of failed drawings
                let check = 0.9 * (1.0 + comp_dev * comp_dev)
                    * (result * self.log_lambda - log_gamma(1.0 + result) - self.magic_val).exp();

                // check with uniform random value - if below the threshold, we are within the target distribution
                if rng.gen::<f64>() <= check {
                    break;
                }
            }
            int_result
        }
    }
}

#[cfg(test)]
mod test {
    use distributions::Distribution;
    use super::Poisson;

    /// The sample mean over 1000 draws should be close to `lambda = 10`.
    #[test]
    fn test_poisson() {
        let poisson = Poisson::new(10.0);
        let mut rng = ::test::rng(123);
        let total: u64 = (0..1000).map(|_| poisson.sample(&mut rng)).sum();
        let avg = (total as f64) / 1000.0;
        println!("Poisson average: {}", avg);
        // not 100% certain, but probable enough
        assert!((avg - 10.0).abs() < 0.5);
    }

    // The constructor only accepts lambda > 0.

    #[test]
    #[should_panic]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_poisson_invalid_lambda_zero() {
        Poisson::new(0.0);
    }

    #[test]
    #[should_panic]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_poisson_invalid_lambda_neg() {
        Poisson::new(-10.0);
    }
}

0 comments on commit 8558b22

Please sign in to comment.