From 997f2bbdc7d40982ec0a90e9aba7d0124b78bb52 Mon Sep 17 00:00:00 2001 From: Rory Finnegan Date: Wed, 13 Jan 2021 03:09:18 -0600 Subject: [PATCH] MvNormal and MvTDist 32-bit logpdf fixes (#1250) logpdf should probably return 32-bit results if only 32-bit inputs are provided. This only fixes it for MvNormal and MvTDist since that's what I'm working with. Not sure if my tests fit with the current organizational system. --- src/multivariate/mvnormal.jl | 2 +- src/multivariate/mvtdist.jl | 8 +++++--- test/mvnormal.jl | 14 ++++++++++++++ test/mvtdist.jl | 3 +++ 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 0e1b033a8..9117ea900 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -96,7 +96,7 @@ function entropy(d::AbstractMvNormal) (length(d) * (T(log2π) + one(T)) + ldcd)/2 end -mvnormal_c0(g::AbstractMvNormal) = -(length(g) * Float64(log2π) + logdetcov(g))/2 +mvnormal_c0(g::AbstractMvNormal) = -(length(g) * convert(eltype(g), log2π) + logdetcov(g))/2 """ invcov(d::AbstractMvNormal) diff --git a/src/multivariate/mvtdist.jl b/src/multivariate/mvtdist.jl index c631410e8..edc9cec64 100644 --- a/src/multivariate/mvtdist.jl +++ b/src/multivariate/mvtdist.jl @@ -123,10 +123,12 @@ sqmahal(d::AbstractMvTDist, x::AbstractMatrix{T}) where {T<:Real} = sqmahal!(Vec function mvtdist_consts(d::AbstractMvTDist) - hdf = 0.5 * d.df - hdim = 0.5 * d.dim + H = convert(eltype(d), 0.5) + logpi = convert(eltype(d), log(pi)) + hdf = H * d.df + hdim = H * d.dim shdfhdim = hdf + hdim - v = loggamma(shdfhdim) - loggamma(hdf) - hdim*log(d.df) - hdim*log(pi) - 0.5*logdet(d.Σ) + v = loggamma(shdfhdim) - loggamma(hdf) - hdim*log(d.df) - hdim*logpi - H*logdet(d.Σ) return (shdfhdim, v) end diff --git a/test/mvnormal.jl b/test/mvnormal.jl index be3d590e0..8b2740afa 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -180,6 +180,20 @@ end @test MvNormal(mu, 0.25f0 * I) === MvNormal(mu, 0.5) end +@testset "MvNormal 32-bit logpdf" begin + # Test 32-bit logpdf + mu = [1., 2., 3.] + C = [4. -2. -1.; -2. 5. -1.; -1. -1. 6.] + d = MvNormal(mu, PDMat(C)) + X = [1., 2., 3.] + + d32 = convert(MvNormal{Float32}, d) + X32 = convert(AbstractArray{Float32}, X) + + @test isa(logpdf(d32, X32), Float32) + @test logpdf(d32, X32) ≈ logpdf(d, X) +end + ##### Random sampling from MvNormalCanon with sparse precision matrix if isdefined(PDMats, :PDSparseMat) @testset "Sparse MvNormalCanon random sampling" begin diff --git a/test/mvtdist.jl b/test/mvtdist.jl index 8d1ffe893..4e1b20324 100644 --- a/test/mvtdist.jl +++ b/test/mvtdist.jl @@ -55,8 +55,11 @@ mu_static = @SVector [1., 2] for i in 1:length(df) d = GenericMvTDist(df[i], mu_static, PDMat(Sigma)) + d32 = convert(GenericMvTDist{Float32}, d) @test d.μ isa SVector @test isapprox(logpdf(d, [-2., 3]), rvalues[i], atol=1.0e-8) + @test isa(logpdf(d32, [-2f0, 3f0]), Float32) + @test isapprox(logpdf(d32, [-2f0, 3f0]), convert(Float32, rvalues[i]), atol=1.0e-4) dd = typeof(d)(params(d)...) @test d.df == dd.df @test d.μ == dd.μ