From 15cb13296edac6f845bdd3e093f7111ef8375672 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Mon, 30 Jul 2018 14:03:01 +0200 Subject: [PATCH] Implement Beta distribution --- src/distributions/gamma.rs | 56 +++++++++++++++++++++++++++++++++++++- src/distributions/mod.rs | 7 +++-- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/src/distributions/gamma.rs b/src/distributions/gamma.rs index f02cf3b21e2..b5bd960a4ab 100644 --- a/src/distributions/gamma.rs +++ b/src/distributions/gamma.rs @@ -305,10 +305,49 @@ impl Distribution for StudentT { } } +/// The Beta distribution with shape parameters `alpha` and `beta`. +/// +/// # Example +/// +/// ``` +/// use rand::distributions::{Distribution, Beta}; +/// +/// let beta = Beta::new(2.0, 5.0); +/// let v = beta.sample(&mut rand::thread_rng()); +/// println!("{} is from a Beta(2, 5) distribution", v); +/// ``` +#[derive(Clone, Copy, Debug)] +pub struct Beta { + gamma_a: Gamma, + gamma_b: Gamma, +} + +impl Beta { + /// Construct an object representing the `Beta(alpha, beta)` + /// distribution. + /// + /// Panics if `shape <= 0` or `scale <= 0`. + pub fn new(alpha: f64, beta: f64) -> Beta { + assert!((alpha > 0.) & (beta > 0.)); + Beta { + gamma_a: Gamma::new(alpha, 1.), + gamma_b: Gamma::new(beta, 1.), + } + } +} + +impl Distribution for Beta { + fn sample(&self, rng: &mut R) -> f64 { + let x = self.gamma_a.sample(rng); + let y = self.gamma_b.sample(rng); + x / (x + y) + } +} + #[cfg(test)] mod test { use distributions::Distribution; - use super::{ChiSquared, StudentT, FisherF}; + use super::{Beta, ChiSquared, StudentT, FisherF}; #[test] fn test_chi_squared_one() { @@ -357,4 +396,19 @@ mod test { t.sample(&mut rng); } } + + #[test] + fn test_beta() { + let beta = Beta::new(1.0, 2.0); + let mut rng = ::test::rng(201); + for _ in 0..1000 { + beta.sample(&mut rng); + } + } + + #[test] + #[should_panic] + fn test_beta_invalid_dof() { + Beta::new(0., 0.); + } } diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 9afc2bc2e87..da862eba1f3 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -99,7 +99,8 @@ //! - [`StudentT`] distribution //! - [`FisherF`] distribution //! - Triangular distribution: -//! - [`Triangular`] distribution +//! - [`Beta`] distribution +//! - [`Triangular`] distribution //! - Multivariate probability distributions //! - [`Dirichlet`] distribution //! - [`UnitSphereSurface`] distribution @@ -153,6 +154,7 @@ // distributions //! [`Alphanumeric`]: struct.Alphanumeric.html //! [`Bernoulli`]: struct.Bernoulli.html +//! [`Beta`]: struct.Beta.html //! [`Binomial`]: struct.Binomial.html //! [`Cauchy`]: struct.Cauchy.html //! [`ChiSquared`]: struct.ChiSquared.html @@ -187,7 +189,8 @@ pub use self::bernoulli::Bernoulli; #[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError}; #[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface; #[cfg(feature="std")] pub use self::unit_circle::UnitCircle; -#[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT}; +#[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, + StudentT, Beta}; #[cfg(feature="std")] pub use self::normal::{Normal, LogNormal, StandardNormal}; #[cfg(feature="std")] pub use self::exponential::{Exp, Exp1}; #[cfg(feature="std")] pub use self::pareto::Pareto;