Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: experiment for reusing allocated buffers for Dyn multivariate #278

Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/distribution/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl Bernoulli {
}

impl std::fmt::Display for Bernoulli {
#[cfg_attr(coverage_nightly, coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Bernoulli({})", self.p())
}
Expand Down
1 change: 1 addition & 0 deletions src/distribution/chi_squared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl ChiSquared {
}

impl std::fmt::Display for ChiSquared {
#[cfg_attr(coverage_nightly, coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "χ^2_{}", self.freedom)
}
Expand Down
1 change: 1 addition & 0 deletions src/distribution/erlang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ impl Erlang {
}

impl std::fmt::Display for Erlang {
#[cfg_attr(coverage_nightly, coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "E({}, {})", self.rate(), self.shape())
}
Expand Down
162 changes: 98 additions & 64 deletions src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,76 +5,39 @@
use std::f64;
use std::f64::consts::{E, PI};

/// computes both the normalization and exponential argument in the normal distribution
/// # Errors
/// will error on dimension mismatch
pub(super) fn density_normalization_and_exponential<D>(
mu: &OVector<f64, D>,
cov: &OMatrix<f64, D, D>,
precision: &OMatrix<f64, D, D>,
x: &OVector<f64, D>,
) -> std::result::Result<(f64, f64), StatsError>
where
D: DimMin<D, Output = D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>
+ nalgebra::allocator::Allocator<f64, D, D>
+ nalgebra::allocator::Allocator<(usize, usize), D>,
{
Ok((
density_distribution_pdf_const(mu, cov)?,
density_distribution_exponential(mu, precision, x)?,
))
}

/// computes the argument of the exponential term in the normal distribution
/// ```text
/// ```
/// # Errors
/// will error on dimension mismatch
/// # Panics
/// will panic on dimension mismatch
#[inline]
pub(super) fn density_distribution_exponential<D>(
pub(super) fn pdf_exponent_unchecked<D>(
YeungOnion marked this conversation as resolved.
Show resolved Hide resolved
mu: &OVector<f64, D>,
precision: &OMatrix<f64, D, D>,
x: &OVector<f64, D>,
) -> std::result::Result<f64, StatsError>
) -> f64
where
D: Dim,
nalgebra::DefaultAllocator:
nalgebra::allocator::Allocator<f64, D> + nalgebra::allocator::Allocator<f64, D, D>,
{
if x.shape_generic().0 != precision.shape_generic().0
|| x.shape_generic().0 != mu.shape_generic().0
|| !precision.is_square()
{
return Err(StatsError::ContainersMustBeSameLength);
}
let dv = x - mu;
let exp_term: f64 = -0.5 * (precision * &dv).dot(&dv);
Ok(exp_term)
// TODO update to dimension mismatch error
-0.5 * (precision * &dv).dot(&dv)
}

/// computes the argument of the normalization term in the normal distribution
/// # Errors
/// will error on dimension mismatch
/// # Panics
/// will panic on dimension mismatch
#[inline]
pub(super) fn density_distribution_pdf_const<D>(
mu: &OVector<f64, D>,
cov: &OMatrix<f64, D, D>,
) -> std::result::Result<f64, StatsError>
pub(super) fn pdf_const_unchecked<D>(mu: &OVector<f64, D>, cov: &OMatrix<f64, D, D>) -> f64
where
D: DimMin<D, Output = D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>
+ nalgebra::allocator::Allocator<f64, D, D>
+ nalgebra::allocator::Allocator<(usize, usize), D>,
{
if cov.shape_generic().0 != mu.shape_generic().0 || !cov.is_square() {
return Err(StatsError::ContainersMustBeSameLength);
}
let cov_det = cov.determinant();
Ok(((2. * PI).powi(mu.nrows() as i32) * cov_det.abs())
((2. * PI).powi(mu.nrows() as i32) * cov_det.abs())
.recip()
.sqrt())
.sqrt()
}

/// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution)
Expand All @@ -92,18 +55,18 @@
/// assert_eq!(mvn.variance().unwrap(), matrix![1., 0.; 0., 1.]);
/// assert_eq!(mvn.pdf(&vector![1., 1.]), 0.05854983152431917);
/// ```
#[derive(Clone, PartialEq, Debug)]
#[derive(Clone, Debug)]
pub struct MultivariateNormal<D>
where
D: Dim,
nalgebra::DefaultAllocator:
nalgebra::allocator::Allocator<f64, D> + nalgebra::allocator::Allocator<f64, D, D>,
{
cov_chol_decomp: OMatrix<f64, D, D>,
mu: OVector<f64, D>,
cov: OMatrix<f64, D, D>,
precision: OMatrix<f64, D, D>,
pdf_const: f64,
cov_chol: Option<Cholesky<f64, D>>,
precision: Option<OMatrix<f64, D, D>>,
}

/// Represents the errors that can occur when creating a [`MultivariateNormal`].
Expand Down Expand Up @@ -199,20 +162,95 @@
// for sampling
match Cholesky::new(cov.clone()) {
None => Err(MultivariateNormalError::CholeskyFailed),
Some(cholesky_decomp) => {
let precision = cholesky_decomp.inverse();
Ok(MultivariateNormal {
// .unwrap() because prerequisites are already checked above
pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(),
cov_chol_decomp: cholesky_decomp.unpack(),
mu: mean,
cov_chol @ Some(_) => Ok(MultivariateNormal {
precision: cov_chol.as_ref().map(|ll| ll.inverse()),
cov_chol,
pdf_const: pdf_const_unchecked(&mean, &cov),
mu: mean,
cov,
}),
}
}

/// Constructs a new multivariate normal distribution with a mean of `mean`
/// and covariance matrix `cov` from an existing `MultivariateNormal` without
/// requiring reallocation.
///
/// # Errors
///
/// Returns an error if the given covariance matrix is not
/// symmetric or positive-definite
pub fn into_new_from_nalgebra(
self,
other_mean: &OVector<f64, D>,
other_cov: &OMatrix<f64, D, D>,
) -> Result<Self, StatsError> {
let Self {
mut mu,
mut cov,
pdf_const,
cov_chol,
..
} = self;

mu.clone_from(other_mean);
cov.clone_from(other_cov);

// clone if storage needed, take from previous cholesky otherwise
let other_cov = cov_chol.map_or((*other_cov).clone(), |chol| chol.unpack());

match other_cov.cholesky() {
None => Err(StatsError::BadParams),
cov_chol @ Some(_) => {
let precision = cov_chol.as_ref().map(|chol| chol.inverse());

Ok(Self {
mu,

Check warning on line 208 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L183-L208

Added lines #L183 - L208 were not covered by tests
cov,
pdf_const,
cov_chol,

Check warning on line 211 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L210-L211

Added lines #L210 - L211 were not covered by tests
precision,
})
}
}
}

/// updates the covariance of the distribution without requiring reallocation
///
/// internally uses `clone_from` and [`set_cov_with`]
/// # Errors
/// if dimensions change, then this will error with dimension mismatch
#[inline]
pub fn set_cov(&mut self, new_cov: &OMatrix<f64, D, D>) -> Result<(), StatsError> {
self.set_cov_with(|old_cov: &mut OMatrix<f64, D, D>| old_cov.clone_from(new_cov))
}

Check warning on line 226 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L224-L226

Added lines #L224 - L226 were not covered by tests

/// updates the covariance of the distribution without requiring reallocation
/// by calling a function
/// # Errors
/// if dimensions change, then this will error with dimension mismatch
pub fn set_cov_with(
&mut self,
f: impl FnOnce(&mut OMatrix<f64, D, D>),
) -> Result<(), StatsError> {
let old_shape = self.cov.shape_generic();
f(&mut self.cov);
if old_shape != self.cov.shape_generic() {
Err(StatsError::ContainersMustBeSameLength)

Check warning on line 239 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L232-L239

Added lines #L232 - L239 were not covered by tests
} else {
match self.cov_chol.take() {
None => (),
Some(l_old) => {
let mut l = l_old.unpack();
l.clone_from(&self.cov);
// ignore possible fallibility for now
self.cov_chol = Some(l.cholesky().unwrap());
}

Check warning on line 248 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L241-L248

Added lines #L241 - L248 were not covered by tests
}
Ok(())

Check warning on line 250 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L250

Added line #L250 was not covered by tests
}
}

Check warning on line 252 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L252

Added line #L252 was not covered by tests

/// Returns the entropy of the multivariate normal distribution
///
/// # Formula
Expand Down Expand Up @@ -266,7 +304,7 @@
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
let d = crate::distribution::Normal::new(0., 1.).unwrap();
let z = OVector::from_distribution_generic(self.mu.shape_generic().0, Const::<1>, &d, rng);
(&self.cov_chol_decomp * z) + &self.mu
self.cov_chol.as_ref().unwrap().l() * z + &self.mu

Check warning on line 307 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L307

Added line #L307 was not covered by tests
}
}

Expand Down Expand Up @@ -362,17 +400,13 @@
/// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant
/// of the covariance matrix, and `k` is the dimension of the distribution
fn pdf(&self, x: &OVector<f64, D>) -> f64 {
self.pdf_const
* density_distribution_exponential(&self.mu, &self.precision, x)
.unwrap()
.exp()
self.pdf_const * pdf_exponent_unchecked(&self.mu, self.precision.as_ref().unwrap(), x).exp()
}

/// Calculates the log probability density function for the multivariate
/// normal distribution at `x`. Equivalent to pdf(x).ln().
fn ln_pdf(&self, x: &OVector<f64, D>) -> f64 {
self.pdf_const.ln()
+ density_distribution_exponential(&self.mu, &self.precision, x).unwrap()
self.pdf_const.ln() + pdf_exponent_unchecked(&self.mu, self.precision.as_ref().unwrap(), x)
}
}

Expand Down
22 changes: 6 additions & 16 deletions src/distribution/multivariate_students_t.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,9 @@ where
/// of the scale matrix, and `k` is the dimension of the distribution.
fn pdf(&self, x: &'a OVector<f64, D>) -> f64 {
if self.freedom.is_infinite() {
use super::multivariate_normal::density_normalization_and_exponential;
let (pdf_const, exp_arg) = density_normalization_and_exponential(
&self.location,
&self.scale,
&self.precision,
x,
)
.unwrap();
use super::multivariate_normal as mvn;
let pdf_const = mvn::pdf_const_unchecked(&self.location, &self.scale);
let exp_arg = mvn::pdf_exponent_unchecked(&self.location, &self.precision, x);
return pdf_const * exp_arg.exp();
}

Expand All @@ -367,14 +362,9 @@ where
/// student distribution at `x`. Equivalent to pdf(x).ln().
fn ln_pdf(&self, x: &'a OVector<f64, D>) -> f64 {
if self.freedom.is_infinite() {
use super::multivariate_normal::density_normalization_and_exponential;
let (pdf_const, exp_arg) = density_normalization_and_exponential(
&self.location,
&self.scale,
&self.precision,
x,
)
.unwrap();
use super::multivariate_normal as mvn;
let pdf_const = mvn::pdf_const_unchecked(&self.location, &self.scale);
let exp_arg = mvn::pdf_exponent_unchecked(&self.location, &self.precision, x);
return pdf_const.ln() + exp_arg;
}

Expand Down