Skip to content

Commit

Permalink
Merge pull request #485 from rohitjoshi/master
Browse files Browse the repository at this point in the history
Support for Dirichlet distribution
  • Loading branch information
dhardy authored Jun 12, 2018
2 parents ec3d7ef + fde9567 commit 187b7d1
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 2 deletions.
138 changes: 138 additions & 0 deletions src/distributions/dirichlet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The dirichlet distribution.

use Rng;
use distributions::Distribution;
use distributions::gamma::Gamma;

/// The dirichelet distribution `Dirichlet(alpha)`.
///
/// The Dirichlet distribution is a family of continuous multivariate probability distributions parameterized by
/// a vector alpha of positive reals. https://en.wikipedia.org/wiki/Dirichlet_distribution
/// It is a multivariate generalization of the beta distribution.
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand::distributions::Dirichlet;
///
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]);
/// let samples = dirichlet.sample(&mut rand::thread_rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```

#[derive(Clone, Debug)]
pub struct Dirichlet {
/// Concentration parameters (alpha)
alpha: Vec<f64>,
}

impl Dirichlet {
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
///
/// # Panics
/// - if `alpha.len() < 2`
///
#[inline]
pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet {
let a = alpha.into();
assert!(a.len() > 1);
for i in 0..a.len() {
assert!(a[i] > 0.0);
}

Dirichlet { alpha: a }
}

/// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
///
/// # Panics
/// - if `alpha <= 0.0`
/// - if `size < 2`
///
#[inline]
pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet {
assert!(alpha > 0.0);
assert!(size > 1);
Dirichlet {
alpha: vec![alpha; size],
}
}
}

impl Distribution<Vec<f64>> for Dirichlet {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
let n = self.alpha.len();
let mut samples = vec![0.0f64; n];
let mut sum = 0.0f64;

for i in 0..n {
let g = Gamma::new(self.alpha[i], 1.0);
samples[i] = g.sample(rng);
sum += samples[i];
}
let invacc = 1.0 / sum;
for i in 0..n {
samples[i] *= invacc;
}
samples
}
}

#[cfg(test)]
mod test {
use super::Dirichlet;
use distributions::Distribution;

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(vec![1.0, 2.0, 3.0]);
let mut rng = ::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
.into_iter()
.map(|x| {
assert!(x > 0.0);
x
})
.collect();
}

#[test]
fn test_dirichlet_with_param() {
let alpha = 0.5f64;
let size = 2;
let d = Dirichlet::new_with_param(alpha, size);
let mut rng = ::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
.into_iter()
.map(|x| {
assert!(x > 0.0);
x
})
.collect();
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_length() {
Dirichlet::new_with_param(0.5f64, 1);
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_param(0.0f64, 2);
}
}
9 changes: 7 additions & 2 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
//! - Related to real-valued quantities that grow linearly
//! (e.g. errors, offsets):
//! - [`Normal`] distribution, and [`StandardNormal`] as a primitive
//! - [`Cauchy`] distribution
//! - Related to Bernoulli trials (yes/no events, with a given probability):
//! - [`Binomial`] distribution
//! - [`Bernoulli`] distribution, similar to [`Rng::gen_bool`].
Expand All @@ -96,7 +95,8 @@
//! - [`ChiSquared`] distribution
//! - [`StudentT`] distribution
//! - [`FisherF`] distribution
//!
//! - Related to continuous multivariate probability distributions
//! - [`Dirichlet`] distribution
//!
//! # Examples
//!
Expand Down Expand Up @@ -150,6 +150,7 @@
//! [`Binomial`]: struct.Binomial.html
//! [`Cauchy`]: struct.Cauchy.html
//! [`ChiSquared`]: struct.ChiSquared.html
//! [`Dirichlet`]: struct.Dirichlet.html
//! [`Exp`]: struct.Exp.html
//! [`Exp1`]: struct.Exp1.html
//! [`FisherF`]: struct.FisherF.html
Expand Down Expand Up @@ -185,6 +186,8 @@ use Rng;
#[doc(inline)] pub use self::bernoulli::Bernoulli;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::cauchy::Cauchy;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::dirichlet::Dirichlet;

pub mod uniform;
#[cfg(feature="std")]
Expand All @@ -202,6 +205,8 @@ pub mod uniform;
#[doc(hidden)] pub mod bernoulli;
#[cfg(feature = "std")]
#[doc(hidden)] pub mod cauchy;
#[cfg(feature = "std")]
#[doc(hidden)] pub mod dirichlet;

mod float;
mod integer;
Expand Down

0 comments on commit 187b7d1

Please sign in to comment.