Skip to content

Commit

Permalink
MvNormal and MvTDist 32-bit logpdf fixes (#1250)
Browse files Browse the repository at this point in the history
logpdf should probably return 32-bit results if only 32-bit inputs are provided.
This only fixes it for MvNormal and MvTDist since that's what I'm working with.
Not sure if my tests fit with the current organizational system.
  • Loading branch information
rofinn authored Jan 13, 2021
1 parent ccebbd7 commit 997f2bb
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function entropy(d::AbstractMvNormal)
(length(d) * (T(log2π) + one(T)) + ldcd)/2
end

mvnormal_c0(g::AbstractMvNormal) = -(length(g) * Float64(log2π) + logdetcov(g))/2
mvnormal_c0(g::AbstractMvNormal) = -(length(g) * convert(eltype(g), log2π) + logdetcov(g))/2

"""
invcov(d::AbstractMvNormal)
Expand Down
8 changes: 5 additions & 3 deletions src/multivariate/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,12 @@ sqmahal(d::AbstractMvTDist, x::AbstractMatrix{T}) where {T<:Real} = sqmahal!(Vec


function mvtdist_consts(d::AbstractMvTDist)
hdf = 0.5 * d.df
hdim = 0.5 * d.dim
H = convert(eltype(d), 0.5)
logpi = convert(eltype(d), log(pi))
hdf = H * d.df
hdim = H * d.dim
shdfhdim = hdf + hdim
v = loggamma(shdfhdim) - loggamma(hdf) - hdim*log(d.df) - hdim*log(pi) - 0.5*logdet(d.Σ)
v = loggamma(shdfhdim) - loggamma(hdf) - hdim*log(d.df) - hdim*logpi - H*logdet(d.Σ)
return (shdfhdim, v)
end

Expand Down
14 changes: 14 additions & 0 deletions test/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,20 @@ end
@test MvNormal(mu, 0.25f0 * I) === MvNormal(mu, 0.5)
end

@testset "MvNormal 32-bit logpdf" begin
# Test 32-bit logpdf
mu = [1., 2., 3.]
C = [4. -2. -1.; -2. 5. -1.; -1. -1. 6.]
d = MvNormal(mu, PDMat(C))
X = [1., 2., 3.]

d32 = convert(MvNormal{Float32}, d)
X32 = convert(AbstractArray{Float32}, X)

@test isa(logpdf(d32, X32), Float32)
@test logpdf(d32, X32) logpdf(d, X)
end

##### Random sampling from MvNormalCanon with sparse precision matrix
if isdefined(PDMats, :PDSparseMat)
@testset "Sparse MvNormalCanon random sampling" begin
Expand Down
3 changes: 3 additions & 0 deletions test/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ mu_static = @SVector [1., 2]

for i in 1:length(df)
d = GenericMvTDist(df[i], mu_static, PDMat(Sigma))
d32 = convert(GenericMvTDist{Float32}, d)
@test d.μ isa SVector
@test isapprox(logpdf(d, [-2., 3]), rvalues[i], atol=1.0e-8)
@test isa(logpdf(d32, [-2f0, 3f0]), Float32)
@test isapprox(logpdf(d32, [-2f0, 3f0]), convert(Float32, rvalues[i]), atol=1.0e-4)
dd = typeof(d)(params(d)...)
@test d.df == dd.df
@test d.μ == dd.μ
Expand Down

0 comments on commit 997f2bb

Please sign in to comment.