Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Multivariate normal distribution #583

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
Cargo.lock
target/
.idea/
8 changes: 8 additions & 0 deletions ndarray-rand/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ ndarray = { version = "0.13", path = ".." }
rand_distr = "0.2.1"
quickcheck = { version = "0.9", default-features = false, optional = true }

[features]
normaldist = ["ndarray-linalg"]

[dependencies.ndarray-linalg]
version = "0.11"
optional = true
features = ["openblas"]

[dependencies.rand]
version = "0.7.0"
features = ["small_rng"]
Expand Down
4 changes: 3 additions & 1 deletion ndarray-rand/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pub mod rand_distr {
pub use rand_distr::*;
}

pub mod normal;

/// Constructors for n-dimensional arrays with random elements.
///
/// This trait extends ndarray’s `ArrayBase` and can not be implemented
Expand Down Expand Up @@ -90,7 +92,7 @@ where
IdS: Distribution<S::Elem>,
Sh: ShapeBuilder<Dim = D>;

/// Create an array with shape `dim` with elements drawn from
/// Create an array with shape `shape` with elements drawn from
/// `distribution`, using a specific Rng `rng`.
///
/// ***Panics*** if the number of elements overflows usize.
Expand Down
55 changes: 55 additions & 0 deletions ndarray-rand/src/normal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//! Implementation of the multiavariate normal distribution.
use crate::RandomExt;
use ndarray::{Array, IntoDimension, Dimension};
use crate::rand::Rng;
use crate::rand::distributions::Distribution;
use crate::rand_distr::{StandardNormal};

#[cfg(feature = "normaldist")]
pub mod advanced;

/// Standard multivariate normal distribution `N(0,1)` for any-dimensional arrays.
///
/// ```
/// use rand;
/// use rand_distr::Distribution;
/// use ndarray;
/// use ndarray_rand::normal::MultivariateStandardNormal;
///
/// let shape = (2, 3); // create (2,3)-matrix of standard normal variables
/// let n = MultivariateStandardNormal::new(shape);
/// let ref mut rng = rand::thread_rng();
/// println!("{:?}", n.sample(rng));
/// ```
pub struct MultivariateStandardNormal<D>
where D: Dimension
{
shape: D
}

impl<D> MultivariateStandardNormal<D>
where D: Dimension
{
pub fn new<Sh>(shape: Sh) -> Self
where Sh: IntoDimension<Dim=D>
{
MultivariateStandardNormal {
shape: shape.into_dimension()
}
}

pub fn shape(&self) -> D {
self.shape.clone()
}
}

impl<D> Distribution<Array<f64, D>> for MultivariateStandardNormal<D>
where D: Dimension
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Array<f64, D> {
let shape = self.shape.clone();
let res = Array::random_using(
shape, StandardNormal, rng);
res
}
}
53 changes: 53 additions & 0 deletions ndarray-rand/src/normal/advanced.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/// The normal distribution `N(mean, covariance)`.
use rand::Rng;
use rand::distributions::{
Distribution, StandardNormal
};

use ndarray::prelude::*;
use ndarray_linalg::error::Result as LAResult;

/// Multivariate normal distribution for 1D arrays,
/// with mean vector and covariance matrix.
pub struct MultivariateNormal {
shape: Ix1,
mean: Array1<f64>,
covariance: Array2<f64>,
/// Lower triangular matrix (Cholesky decomposition of the coviariance matrix)
lower: Array2<f64>
}

impl MultivariateNormal {
pub fn new(mean: Array1<f64>, covariance: Array2<f64>) -> LAResult<Self> {
let shape: Ix1 = Ix1(mean.shape()[0]);
use ndarray_linalg::cholesky::*;
let lower = covariance.cholesky(UPLO::Lower)?;
Ok(MultivariateNormal {
shape, mean, covariance, lower
})
}

pub fn shape(&self) -> Ix1 {
self.shape
}

pub fn mean(&self) -> ArrayView1<f64> {
self.mean.view()
}

pub fn covariance(&self) -> ArrayView2<f64> {
self.covariance.view()
}
}

impl Distribution<Array1<f64>> for MultivariateNormal {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
let shape = self.shape.clone();
// standard normal distribution
use crate::RandomExt;
let res = Array1::random_using(
shape, StandardNormal, rng);
// use Cholesky decomposition to obtain a sample of our general multivariate normal
self.mean.clone() + self.lower.view().dot(&res)
}
}
25 changes: 25 additions & 0 deletions ndarray-rand/tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use ndarray_rand::rand::{distributions::Distribution, thread_rng};

use ndarray::ShapeBuilder;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::normal::MultivariateStandardNormal;
use ndarray_rand::{RandomExt, SamplingStrategy};
use quickcheck::quickcheck;

Expand Down Expand Up @@ -36,6 +37,30 @@ fn test_dim_f() {
}

#[test]
fn test_standard_normal() {
let shape = 2usize;
let n = MultivariateStandardNormal::new(shape);
let ref mut rng = rand::thread_rng();
let s: ndarray::Array1<f64> = n.sample(rng);
assert_eq!(s.shape(), &[2]);
}

#[cfg(features = "normaldist")]
#[test]
fn test_normal() {
use ndarray::IntoDimension;
use ndarray::{Array1, arr2};
use ndarray_rand::normal::advanced::MultivariateNormal;
let mean = Array1::from_vec([1., 0.]);
let covar = arr2([
[1., 0.8], [0.8, 1.]]);
let ref mut rng = rand::thread_rng();
let n = MultivariateNormal::new(mean, covar);
if let Ok(n) = n {
let x = n.sample(rng);
assert_eq!(x.shape(), &[2, 2]);
}
}
#[should_panic]
fn oversampling_without_replacement_should_panic() {
let m = 5;
Expand Down