Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Dirichlet distribution #485

Merged
merged 12 commits into from
Jun 12, 2018
139 changes: 139 additions & 0 deletions src/distributions/dirichlet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The dirichlet distribution.

use Rng;
use distributions::Distribution;
use distributions::gamma::Gamma;

/// The dirichelet distribution `Dirichlet(alpha)`.
///
/// The Dirichlet distribution } is a family of continuous multivariate probability distributions parameterized by
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stray }

Copying a fancy description from Wikipedia doesn't really explain much, especially since the links are missing. Not that I have a better idea.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the Mathematica explanation a bit more than Wikipedia's.

/// a vector alpha of positive reals
/// It is a multivariate generalization of the beta distribution.
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand::distributions::Dirichlet;
///
/// let dirichlet = Dirichlet::new(&vec![1.0, 2.0, 3.0]);
/// let samples = dirichlet.sample(&mut rand::thread_rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```

#[derive(Clone, Debug)]
pub struct Dirichlet {
/// Concentration parameters (alpha)
alpha: Vec<f64>,
Copy link
Member

@dhardy dhardy Jun 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it, we could probably use alpha: [f64] here. It makes the type "unsized" (i.e. users have to write Box<Dirichlet>) but is more flexible (potentially more optimal).

On the other hand it may not be worth it since it makes the type less ergonomic to use for what is probably not a lot of gain.

Another option would be Dirichlet<N: usize> { alpha: [f64; N] } — except I don't think Rust supports that yet (though it would also allow sample(..) -> [f64; N], thus side-stepping @vks's concerns).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do our distributions even work when you write Box<Dirichlet>?

Copy link
Member

@dhardy dhardy Jun 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rng::sample won't of course but Distribution::sample will. Either way it's not really a great choice (less convenient for users).

}

impl Dirichlet {
/// Construct a new `Dirichlet` with the given alpha parameter
/// `alpha`. Panics if `alpha.len() < 2`.
#[inline]
pub fn new(alpha: &[f64]) -> Dirichlet {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest you accept any vec-like thing here: new<V: Into<Vec<f64>>(alpha: V) and then let alpha = alpha.into();. Drop the to_vec() later.

assert!(
alpha.len() > 1,
"Dirichlet::new called with `alpha` with length < 2"
);
for i in 0..alpha.len() {
assert!(
alpha[i] > 0.0,
"Dirichlet::new called with `alpha` <= 0.0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think you need the quotes. Maybe just write alpha[i] <= 0, and alpha.len() < 2 above.

);
}

Dirichlet {
alpha: alpha.to_vec(),
}
}

/// Construct a new `Dirichlet` with the given shape parameter and size
/// `alpha`. Panics if `alpha <= 0.0`.
/// `size` . Panic if `size < 2`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't render well. If you want a list, leave a blank line, the prefix each item with - (it's Markdown). Otherwise just rewrite as two sentences.

#[inline]
pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet {
assert!(alpha > 0.0, "Dirichlet::new called with `alpha` <= 0.0");
assert!(size > 1, "Dirichlet::new called with `size` <= 1");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is new_with_param not new. Again you can drop the extra quotes on parameter names.

Dirichlet {
alpha: vec![alpha; size],
}
}
}

impl Distribution<Vec<f64>> for Dirichlet {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure our current distribution trait is well suited for multivariate distributions. It would be nice to sample without allocating, but this requires different method. Something like fn sample_multi(&self, &mut Rng, &mut [f64]).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is why you opened #496. I agree. On the other hand, I'm not too fussed about having to make breaking changes to this distribution later (it's still better for users than not having it, and we're not close to 1.0).

fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
let n = self.alpha.len();
let mut samples = vec![0.0f64; n];
let mut sum = 0.0f64;

for i in 0..n {
let g = Gamma::new(self.alpha[i], 1.0);
samples[i] = g.sample(rng);
sum += samples[i];
}
let invacc = 1.0 / sum;
for i in 0..n {
samples[i] *= invacc;
}
samples
}
}

#[cfg(test)]
mod test {
use super::Dirichlet;
use distributions::Distribution;

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(&vec![1.0, 2.0, 3.0]);
let mut rng = ::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
.into_iter()
.map(|x| {
assert!(x > 0.0);
x
})
.collect();
}

#[test]
fn test_dirichlet_with_param() {
let alpha = 0.5f64;
let size = 2;
let d = Dirichlet::new_with_param(alpha, size);
let mut rng = ::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
.into_iter()
.map(|x| {
assert!(x > 0.0);
x
})
.collect();
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_length() {
Dirichlet::new_with_param(0.5f64, 1);
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_param(0.0f64, 2);
}
}
8 changes: 7 additions & 1 deletion src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
//! - [`ChiSquared`] distribution
//! - [`StudentT`] distribution
//! - [`FisherF`] distribution
//!
//! - Dirichlet distribution
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The [`Dirichlet`] link is missing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's been added now, but Dirichlet seems to be mentioned twice in a row

//!
//! # Examples
//!
Expand Down Expand Up @@ -148,6 +148,7 @@
//! [`Bernoulli`]: struct.Bernoulli.html
//! [`Binomial`]: struct.Binomial.html
//! [`ChiSquared`]: struct.ChiSquared.html
//! [`Dirichlet`]: struct.Dirichlet.html
//! [`Exp`]: struct.Exp.html
//! [`Exp1`]: struct.Exp1.html
//! [`FisherF`]: struct.FisherF.html
Expand Down Expand Up @@ -180,6 +181,8 @@ pub use self::uniform::Uniform as Range;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::binomial::Binomial;
#[doc(inline)] pub use self::bernoulli::Bernoulli;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::dirichlet::Dirichlet;

pub mod uniform;
#[cfg(feature="std")]
Expand All @@ -193,6 +196,8 @@ pub mod uniform;
#[cfg(feature = "std")]
#[doc(hidden)] pub mod binomial;
#[doc(hidden)] pub mod bernoulli;
#[cfg(feature = "std")]
#[doc(hidden)] pub mod dirichlet;

mod float;
mod integer;
Expand All @@ -204,6 +209,7 @@ mod ziggurat_tables;
#[cfg(feature="std")]
use distributions::float::IntoFloat;


/// Types that can be used to create a random instance of `Support`.
#[deprecated(since="0.5.0", note="use Distribution instead")]
pub trait Sample<Support> {
Expand Down