-
Notifications
You must be signed in to change notification settings - Fork 421
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create MvLogitNormal * Add MvLogitNormal to docs * Simplify constructors * Fix conversions * Rearrange code * Fix computation of -Inf * Add meanform and canonform * Add back type constructor * Add MvLogitNormal tests * Update and test show method * Fix testset name * Fix for older Julia versions * Restrict testing of `show` method to newer versions * Add kldivergence tests * Improve documentation * Remove constructor with type and AbstractMvNormal params * Update show method * Update docstring * Remove reference to Dirichlet * Apply suggestions from code review Co-authored-by: David Widmann <devmotion@users.noreply.github.com> --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
- Loading branch information
Showing
6 changed files
with
302 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -122,6 +122,7 @@ export | |
Logistic, | ||
LogNormal, | ||
LogUniform, | ||
MvLogitNormal, | ||
LogitNormal, | ||
MatrixBeta, | ||
MatrixFDist, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
""" | ||
MvLogitNormal{<:AbstractMvNormal} | ||
The [multivariate logit-normal distribution](https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization) | ||
is a multivariate generalization of [`LogitNormal`](@ref) capable of handling correlations | ||
between variables. | ||
If ``\\mathbf{y} \\sim \\mathrm{MvNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma})`` is a | ||
length ``d-1`` vector, then | ||
```math | ||
\\mathbf{x} = \\operatorname{softmax}\\left(\\begin{bmatrix}\\mathbf{y} \\\\ 0 \\end{bmatrix}\\right) \\sim \\mathrm{MvLogitNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma}) | ||
``` | ||
is a length ``d`` probability vector. | ||
```julia | ||
MvLogitNormal(μ, Σ) # MvLogitNormal with y ~ MvNormal(μ, Σ) | ||
MvLogitNormal(MvNormal(μ, Σ)) # same as above | ||
MvLogitNormal(MvNormalCanon(μ, J)) # MvLogitNormal with y ~ MvNormalCanon(μ, J) | ||
``` | ||
# Fields | ||
- `normal::AbstractMvNormal`: contains the ``d-1``-dimensional distribution of ``y`` | ||
""" | ||
struct MvLogitNormal{D<:AbstractMvNormal} <: ContinuousMultivariateDistribution | ||
normal::D | ||
MvLogitNormal{D}(normal::D) where {D<:AbstractMvNormal} = new{D}(normal) | ||
end | ||
MvLogitNormal(d::AbstractMvNormal) = MvLogitNormal{typeof(d)}(d) | ||
MvLogitNormal(args...) = MvLogitNormal(MvNormal(args...)) | ||
|
||
function Base.show(io::IO, d::MvLogitNormal; indent::String=" ") | ||
print(io, distrname(d)) | ||
println(io, "(") | ||
normstr = strip(sprint(show, d.normal; context=IOContext(io))) | ||
normstr = replace(normstr, "\n" => "\n$indent") | ||
print(io, indent) | ||
println(io, normstr) | ||
println(io, ")") | ||
end | ||
|
||
# Conversions | ||
|
||
function convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal) where {D} | ||
return MvLogitNormal(convert(D, d.normal)) | ||
end | ||
Base.convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal{D}) where {D} = d | ||
|
||
meanform(d::MvLogitNormal{<:MvNormalCanon}) = MvLogitNormal(meanform(d.normal)) | ||
canonform(d::MvLogitNormal{<:MvNormal}) = MvLogitNormal(canonform(d.normal)) | ||
|
||
# Properties | ||
|
||
length(d::MvLogitNormal) = length(d.normal) + 1 | ||
Base.eltype(::Type{<:MvLogitNormal{D}}) where {D} = eltype(D) | ||
Base.eltype(d::MvLogitNormal) = eltype(d.normal) | ||
params(d::MvLogitNormal) = params(d.normal) | ||
@inline partype(d::MvLogitNormal) = partype(d.normal) | ||
|
||
location(d::MvLogitNormal) = mean(d.normal) | ||
minimum(d::MvLogitNormal) = fill(zero(eltype(d)), length(d)) | ||
maximum(d::MvLogitNormal) = fill(oneunit(eltype(d)), length(d)) | ||
|
||
function insupport(d::MvLogitNormal, x::AbstractVector{<:Real}) | ||
return length(d) == length(x) && all(≥(0), x) && sum(x) ≈ 1 | ||
end | ||
|
||
# Evaluation | ||
|
||
function _logpdf(d::MvLogitNormal, x::AbstractVector{<:Real}) | ||
if !insupport(d, x) | ||
return oftype(logpdf(d.normal, _inv_softmax1(abs.(x))), -Inf) | ||
else | ||
return logpdf(d.normal, _inv_softmax1(x)) - sum(log, x) | ||
end | ||
end | ||
|
||
function gradlogpdf(d::MvLogitNormal, x::AbstractVector{<:Real}) | ||
y = _inv_softmax1(x) | ||
∂y = gradlogpdf(d.normal, y) | ||
∂x = (vcat(∂y, -sum(∂y)) .- 1) ./ x | ||
return ∂x | ||
end | ||
|
||
# Statistics | ||
|
||
kldivergence(p::MvLogitNormal, q::MvLogitNormal) = kldivergence(p.normal, q.normal) | ||
|
||
# Sampling | ||
|
||
function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real}) | ||
y = @views _drop1(x) | ||
rand!(rng, d.normal, y) | ||
_softmax1!(x, y) | ||
return x | ||
end | ||
|
||
# Fitting | ||
|
||
function fit_mle(::Type{MvLogitNormal{D}}, x::AbstractMatrix{<:Real}; kwargs...) where {D} | ||
y = similar(x, size(x, 1) - 1, size(x, 2)) | ||
map(_inv_softmax1!, eachcol(y), eachcol(x)) | ||
normal = fit_mle(D, y; kwargs...) | ||
return MvLogitNormal(normal) | ||
end | ||
function fit_mle(::Type{MvLogitNormal}, x::AbstractMatrix{<:Real}; kwargs...) | ||
return fit_mle(MvLogitNormal{MvNormal}, x; kwargs...) | ||
end | ||
|
||
# Utility | ||
|
||
function _softmax1!(x::AbstractVector, y::AbstractVector) | ||
u = max(0, maximum(y)) | ||
_drop1(x) .= exp.(y .- u) | ||
x[end] = exp(-u) | ||
LinearAlgebra.normalize!(x, 1) | ||
return x | ||
end | ||
function _softmax1!(x::AbstractMatrix, y::AbstractMatrix) | ||
map(_softmax1!, eachcol(x), eachcol(y)) | ||
return x | ||
end | ||
|
||
_drop1(x::AbstractVector) = @views x[firstindex(x, 1):(end - 1)] | ||
_drop1(x::AbstractMatrix) = @views x[firstindex(x, 1):(end - 1), :] | ||
|
||
_last1(x::AbstractVector) = x[end] | ||
_last1(x::AbstractMatrix) = @views x[end, :] | ||
|
||
function _inv_softmax1!(y::AbstractVecOrMat, x::AbstractVecOrMat) | ||
x₋ = _drop1(x) | ||
xd = _last1(x) | ||
@. y = log(x₋) - log(xd) | ||
return y | ||
end | ||
function _inv_softmax1(x::AbstractVecOrMat) | ||
y = similar(_drop1(x)) | ||
_inv_softmax1!(y, x) | ||
return y | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Tests on Multivariate Logit-Normal distributions | ||
using Distributions | ||
using ForwardDiff | ||
using LinearAlgebra | ||
using Random | ||
using Test | ||
|
||
####### Core testing procedure | ||
|
||
function test_mvlogitnormal(d::MvLogitNormal; nsamples::Int=10^6) | ||
@test d.normal isa AbstractMvNormal | ||
dnorm = d.normal | ||
|
||
@testset "properties" begin | ||
@test length(d) == length(dnorm) + 1 | ||
@test params(d) == params(dnorm) | ||
@test partype(d) == partype(dnorm) | ||
@test eltype(d) == eltype(dnorm) | ||
@test eltype(typeof(d)) == eltype(typeof(dnorm)) | ||
@test location(d) == mean(dnorm) | ||
@test minimum(d) == fill(0, length(d)) | ||
@test maximum(d) == fill(1, length(d)) | ||
@test insupport(d, normalize(rand(length(d)), 1)) | ||
@test !insupport(d, normalize(rand(length(d) + 1), 1)) | ||
@test !insupport(d, rand(length(d))) | ||
x = rand(length(d) - 1) | ||
x = vcat(x, -sum(x)) | ||
@test !insupport(d, x) | ||
end | ||
|
||
@testset "conversions" begin | ||
@test convert(typeof(d), d) === d | ||
T = partype(d) <: Float64 ? Float32 : Float64 | ||
if dnorm isa MvNormal | ||
@test convert(MvLogitNormal{MvNormal{T}}, d).normal == | ||
convert(MvNormal{T}, dnorm) | ||
@test partype(convert(MvLogitNormal{MvNormal{T}}, d)) <: T | ||
@test canonform(d) isa MvLogitNormal{<:MvNormalCanon} | ||
@test canonform(d).normal == canonform(dnorm) | ||
elseif dnorm isa MvNormalCanon | ||
@test convert(MvLogitNormal{MvNormalCanon{T}}, d).normal == | ||
convert(MvNormalCanon{T}, dnorm) | ||
@test partype(convert(MvLogitNormal{MvNormalCanon{T}}, d)) <: T | ||
@test meanform(d) isa MvLogitNormal{<:MvNormal} | ||
@test meanform(d).normal == meanform(dnorm) | ||
end | ||
end | ||
|
||
@testset "sampling" begin | ||
X = rand(d, nsamples) | ||
Y = @views log.(X[1:(end - 1), :]) .- log.(X[end, :]') | ||
Ymean = vec(mean(Y; dims=2)) | ||
Ycov = cov(Y; dims=2) | ||
for i in 1:(length(d) - 1) | ||
@test isapprox( | ||
Ymean[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8 | ||
) | ||
end | ||
for i in 1:(length(d) - 1), j in 1:(length(d) - 1) | ||
@test isapprox( | ||
Ycov[i, j], | ||
cov(dnorm)[i, j], | ||
atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20, | ||
) | ||
end | ||
end | ||
|
||
@testset "fitting" begin | ||
X = rand(d, nsamples) | ||
dfit = fit_mle(MvLogitNormal, X) | ||
dfit_norm = dfit.normal | ||
for i in 1:(length(d) - 1) | ||
@test isapprox( | ||
mean(dfit_norm)[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8 | ||
) | ||
end | ||
for i in 1:(length(d) - 1), j in 1:(length(d) - 1) | ||
@test isapprox( | ||
cov(dfit_norm)[i, j], | ||
cov(dnorm)[i, j], | ||
atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20, | ||
) | ||
end | ||
@test fit_mle(MvLogitNormal{IsoNormal}, X) isa MvLogitNormal{<:IsoNormal} | ||
end | ||
|
||
@testset "evaluation" begin | ||
X = rand(d, nsamples) | ||
for i in 1:min(100, nsamples) | ||
@test @inferred(logpdf(d, X[:, i])) ≈ log(pdf(d, X[:, i])) | ||
if dnorm isa MvNormal | ||
@test @inferred(gradlogpdf(d, X[:, i])) ≈ | ||
ForwardDiff.gradient(x -> logpdf(d, x), X[:, i]) | ||
end | ||
end | ||
@test logpdf(d, X) ≈ log.(pdf(d, X)) | ||
@test isequal(logpdf(d, zeros(length(d))), -Inf) | ||
@test isequal(logpdf(d, ones(length(d))), -Inf) | ||
@test isequal(pdf(d, zeros(length(d))), 0) | ||
@test isequal(pdf(d, ones(length(d))), 0) | ||
end | ||
end | ||
|
||
@testset "Results MvLogitNormal consistent with univariate LogitNormal" begin | ||
μ = randn() | ||
σ = rand() | ||
d = MvLogitNormal([μ], fill(σ^2, 1, 1)) | ||
duni = LogitNormal(μ, σ) | ||
@test location(d) ≈ [location(duni)] | ||
x = normalize(rand(2), 1) | ||
@test logpdf(d, x) ≈ logpdf(duni, x[1]) | ||
@test pdf(d, x) ≈ pdf(duni, x[1]) | ||
@test (Random.seed!(9274); rand(d)[1]) ≈ (Random.seed!(9274); rand(duni)) | ||
end | ||
|
||
###### General Testing | ||
|
||
@testset "MvLogitNormal tests" begin | ||
mvnorm_params = [ | ||
(randn(5), I * rand()), | ||
(randn(4), Diagonal(rand(4))), | ||
(Diagonal(rand(6)),), | ||
(randn(5), exp(Symmetric(randn(5, 5)))), | ||
(exp(Symmetric(randn(5, 5))),), | ||
] | ||
@testset "wraps MvNormal" begin | ||
@testset "$(typeof(prms))" for prms in mvnorm_params | ||
d = MvLogitNormal(prms...) | ||
@test d == MvLogitNormal(MvNormal(prms...)) | ||
test_mvlogitnormal(d; nsamples=10^4) | ||
end | ||
end | ||
@testset "wraps MvNormalCanon" begin | ||
@testset "$(typeof(prms))" for prms in mvnorm_params | ||
d = MvLogitNormal(MvNormalCanon(prms...)) | ||
test_mvlogitnormal(d; nsamples=10^4) | ||
end | ||
end | ||
|
||
@testset "kldivergence" begin | ||
d1 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5)))) | ||
d2 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5)))) | ||
@test kldivergence(d1, d2) ≈ kldivergence(d1.normal, d2.normal) | ||
end | ||
|
||
VERSION ≥ v"1.8" && @testset "show" begin | ||
d = MvLogitNormal([1.0, 2.0, 3.0], Diagonal([4.0, 5.0, 6.0])) | ||
@test sprint(show, d) === """ | ||
MvLogitNormal{DiagNormal}( | ||
DiagNormal( | ||
dim: 3 | ||
μ: [1.0, 2.0, 3.0] | ||
Σ: [4.0 0.0 0.0; 0.0 5.0 0.0; 0.0 0.0 6.0] | ||
) | ||
) | ||
""" | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters