From f643520b1a0d3afa424b62d3874de6a7a311c7ac Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 19 Nov 2020 12:12:20 -0800
Subject: [PATCH] Use un-pivoted Cholesky triangle to sample from sparse
 MvNormalCanon (#1218)

* Use unpivoted Cholesky triangle to generate sparse MvNormalCanon samples

* Remove stray `end`

* Check that PDSparseMat is defined in tests
---
 src/multivariate/mvnormalcanon.jl |  4 +-
 test/mvnormal.jl                  | 66 +++++++++++++++++++++----------
 2 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/src/multivariate/mvnormalcanon.jl b/src/multivariate/mvnormalcanon.jl
index 3bf7cdc90..140d38ebd 100644
--- a/src/multivariate/mvnormalcanon.jl
+++ b/src/multivariate/mvnormalcanon.jl
@@ -177,9 +177,9 @@ unwhiten_winv!(J::AbstractPDMat, x::AbstractVecOrMat) = unwhiten!(inv(J), x)
 unwhiten_winv!(J::PDiagMat, x::AbstractVecOrMat) = whiten!(J, x)
 unwhiten_winv!(J::ScalMat, x::AbstractVecOrMat) = whiten!(J, x)
 if isdefined(PDMats, :PDSparseMat)
-    unwhiten_winv!(J::PDSparseMat, x::AbstractVecOrMat) = x[:] = J.chol.U \ x
+    unwhiten_winv!(J::PDSparseMat, x::AbstractVecOrMat) = x[:] = J.chol.PtL' \ x
 end
-
+
 _rand!(rng::AbstractRNG, d::MvNormalCanon, x::AbstractVector) =
     add!(unwhiten_winv!(d.J, randn!(rng,x)), d.μ)
 _rand!(rng::AbstractRNG, d::MvNormalCanon, x::AbstractMatrix) =
diff --git a/test/mvnormal.jl b/test/mvnormal.jl
index 7407f377a..be3d590e0 100644
--- a/test/mvnormal.jl
+++ b/test/mvnormal.jl
@@ -1,6 +1,9 @@
 # Tests on Multivariate Normal distributions
 
 import PDMats: ScalMat, PDiagMat, PDMat
+if isdefined(PDMats, :PDSparseMat)
+    import PDMats: PDSparseMat
+end
 
 using Distributions
 using LinearAlgebra, Random, Test
@@ -178,28 +181,51 @@ end
 end
 
 ##### Random sampling from MvNormalCanon with sparse precision matrix
-@testset "Sparse MvNormalCanon random sampling" begin
-    # Random samples from MvNormalCanon and MvNormal diverge as
-    # 1) Dimension of cov/prec matrix increases (n)
-    # 2) Determinant of J increases
-    # ...hence, the relative tolerance for testing their equality
-    n = 10
-    k = 0.1
-    rtol = n*k^2
-    seed = 1234
-    J = sprandn(n, n, 0.25) * k
-    J = J'J + I
-    Σ = inv(Matrix(J))
-    J = PDSparseMat(J)
-    μ = zeros(n)
-    d = MvNormalCanon(μ, J*μ, J)
-    d1 = MvNormal(μ, PDMat(Symmetric(Σ)))
-    r = rand(MersenneTwister(seed), d)
-    r1 = rand(MersenneTwister(seed), d1)
-    @test all(isapprox.(r, r1, rtol=rtol))
-    @test mean(abs2.(r .- r1)) < rtol
+if isdefined(PDMats, :PDSparseMat)
+    @testset "Sparse MvNormalCanon random sampling" begin
+        n = 20
+        nsamp = 100_000
+        Random.seed!(1234)
+
+        J = sprandn(n, n, 0.25)
+        J = J'J + I
+        Σ = inv(Matrix(J))
+        J = PDSparseMat(J)
+        μ = zeros(n)
+
+        d_prec_sparse = MvNormalCanon(μ, J*μ, J)
+        d_prec_dense = MvNormalCanon(μ, J*μ, PDMat(Matrix(J)))
+        d_cov_dense = MvNormal(μ, PDMat(Symmetric(Σ)))
+
+        x_prec_sparse = rand(d_prec_sparse, nsamp)
+        x_prec_dense = rand(d_prec_dense, nsamp)
+        x_cov_dense = rand(d_cov_dense, nsamp)
+
+        dists = [d_prec_sparse, d_prec_dense, d_cov_dense]
+        samples = [x_prec_sparse, x_prec_dense, x_cov_dense]
+        tol = 1e-16
+        se = sqrt.(diag(Σ) ./ nsamp)
+        #=
+        The Cholesky factorization of the sparse precision matrix is computed by
+        `SuiteSparse.CHOLMOD`, which applies a fill-reducing permutation and therefore
+        returns a different triangular factor than the dense `LinearAlgebra`/LAPACK
+        routine. Because the factors differ, random samples drawn from an `MvNormalCanon`
+        with a sparse precision matrix are not, in general, numerically identical to those
+        drawn from an `MvNormalCanon` or `MvNormal` with dense parameters, even when the
+        RNG seeds are identical. These tests therefore only check approximate statistical
+        equality of the samples, rather than strict numerical equality.
+        =#
+        for i in 1:3, j in 1:3
+            @test all(abs.(mean(samples[i]) .- μ) .< 2se)
+            loglik_ii = [logpdf(dists[i], samples[i][:, k]) for k in 1:nsamp]
+            loglik_ji = [logpdf(dists[j], samples[i][:, k]) for k in 1:nsamp]
+            # check that the average log-likelihood ratio between distributions i and j is small
+            @test mean((loglik_ii .- loglik_ji).^2) < tol
+        end
+    end
 end
 
+
 ##### MLE
 
 # a slow but safe way to implement MLE for verification
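Editor's note (not part of the patch): the substantive change above swaps `J.chol.U` for `J.chol.PtL'` when unwhitening. CHOLMOD factorizes a permuted matrix, P*J*P' = L*L', so backsolving against U = L' produces samples with covariance P*inv(J)*P' rather than inv(J); the un-pivoted triangle PtL = P'*L satisfies PtL*PtL' = J, so backsolving against PtL' yields the intended covariance. The standalone Julia sketch below illustrates this outside Distributions.jl; the dimension, sparsity, seed, sample count, and tolerance are arbitrary illustrative choices, not values taken from the patch.

    using LinearAlgebra, Random, SparseArrays, Statistics

    n, nsamp = 20, 100_000
    Random.seed!(1)
    J = sprandn(n, n, 0.25)
    J = J'J + I                 # sparse symmetric positive definite precision matrix

    F = cholesky(J)             # CHOLMOD factorization; may apply a fill-reducing permutation
    PtL = F.PtL                 # un-pivoted lower triangle: PtL * PtL' reconstructs J itself
    @assert PtL * PtL' ≈ J

    # If z ~ N(0, I) and x = PtL' \ z, then cov(x) = inv(PtL * PtL') = inv(J),
    # which is the covariance implied by MvNormalCanon's precision matrix J.
    Z = randn(n, nsamp)
    X = PtL' \ Z
    @assert maximum(abs, cov(X, dims=2) - inv(Matrix(J))) < 0.05   # loose Monte Carlo check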