Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MvLogitNormal #1774

Merged
merged 21 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/multivariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Multinomial
Distributions.AbstractMvNormal
MvNormal
MvNormalCanon
MvLogitNormal
MvLogNormal
Dirichlet
Product
Expand Down
1 change: 1 addition & 0 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ export
Logistic,
LogNormal,
LogUniform,
MvLogitNormal,
LogitNormal,
MatrixBeta,
MatrixFDist,
Expand Down
140 changes: 140 additions & 0 deletions src/multivariate/mvlogitnormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
MvLogitNormal{<:AbstractMvNormal}

The [multivariate logit-normal distribution](https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization)
is a multivariate generalization of [`LogitNormal`](@ref) capable of handling correlations
between variables.

If ``\\mathbf{y} \\sim \\mathrm{MvNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma})`` is a
length ``d-1`` vector, then
```math
\\mathbf{x} = \\operatorname{softmax}\\left(\\begin{bmatrix}\\mathbf{y} \\\\ 0 \\end{bmatrix}\\right) \\sim \\mathrm{MvLogitNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma})
```
is a length ``d`` probability vector.

```julia
MvLogitNormal(μ, Σ) # MvLogitNormal with y ~ MvNormal(μ, Σ)
MvLogitNormal(MvNormal(μ, Σ)) # same as above
MvLogitNormal(MvNormalCanon(μ, J)) # MvLogitNormal with y ~ MvNormalCanon(μ, J)
```

# Fields

- `normal::AbstractMvNormal`: contains the ``d-1``-dimensional distribution of ``y``
"""
struct MvLogitNormal{D<:AbstractMvNormal} <: ContinuousMultivariateDistribution
normal::D
MvLogitNormal{D}(normal::D) where {D<:AbstractMvNormal} = new{D}(normal)
end
MvLogitNormal(d::AbstractMvNormal) = MvLogitNormal{typeof(d)}(d)
MvLogitNormal(args...) = MvLogitNormal(MvNormal(args...))

function Base.show(io::IO, d::MvLogitNormal; indent::String=" ")
print(io, distrname(d))
println(io, "(")
normstr = strip(sprint(show, d.normal; context=IOContext(io)))
normstr = replace(normstr, "\n" => "\n$indent")
print(io, indent)
println(io, normstr)
println(io, ")")
end

# Conversions

function convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal) where {D}
return MvLogitNormal(convert(D, d.normal))
end
Base.convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal{D}) where {D} = d

meanform(d::MvLogitNormal{<:MvNormalCanon}) = MvLogitNormal(meanform(d.normal))
canonform(d::MvLogitNormal{<:MvNormal}) = MvLogitNormal(canonform(d.normal))

# Properties

length(d::MvLogitNormal) = length(d.normal) + 1
Base.eltype(::Type{<:MvLogitNormal{D}}) where {D} = eltype(D)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
Base.eltype(d::MvLogitNormal) = eltype(d.normal)
params(d::MvLogitNormal) = params(d.normal)
@inline partype(d::MvLogitNormal) = partype(d.normal)

location(d::MvLogitNormal) = mean(d.normal)
minimum(d::MvLogitNormal) = fill(zero(eltype(d)), length(d))
maximum(d::MvLogitNormal) = fill(oneunit(eltype(d)), length(d))

function insupport(d::MvLogitNormal, x::AbstractVector{<:Real})
return length(d) == length(x) && all(≥(0), x) && sum(x) ≈ 1
end

# Evaluation

function _logpdf(d::MvLogitNormal, x::AbstractVector{<:Real})
if !insupport(d, x)
return oftype(logpdf(d.normal, _inv_softmax1(abs.(x))), -Inf)
else
return logpdf(d.normal, _inv_softmax1(x)) - sum(log, x)
end
end

function gradlogpdf(d::MvLogitNormal, x::AbstractVector{<:Real})
y = _inv_softmax1(x)
∂y = gradlogpdf(d.normal, y)
∂x = (vcat(∂y, -sum(∂y)) .- 1) ./ x
return ∂x
end

# Statistics

kldivergence(p::MvLogitNormal, q::MvLogitNormal) = kldivergence(p.normal, q.normal)

# Sampling

function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real})
y = @views _drop1(x)
rand!(rng, d.normal, y)
_softmax1!(x, y)
return x
end

# Fitting

function fit_mle(::Type{MvLogitNormal{D}}, x::AbstractMatrix{<:Real}; kwargs...) where {D}
y = similar(x, size(x, 1) - 1, size(x, 2))
map(_inv_softmax1!, eachcol(y), eachcol(x))
normal = fit_mle(D, y; kwargs...)
return MvLogitNormal(normal)
end
function fit_mle(::Type{MvLogitNormal}, x::AbstractMatrix{<:Real}; kwargs...)
return fit_mle(MvLogitNormal{MvNormal}, x; kwargs...)
end

# Utility

function _softmax1!(x::AbstractVector, y::AbstractVector)
u = max(0, maximum(y))
_drop1(x) .= exp.(y .- u)
x[end] = exp(-u)
LinearAlgebra.normalize!(x, 1)
return x
end
function _softmax1!(x::AbstractMatrix, y::AbstractMatrix)
map(_softmax1!, eachcol(x), eachcol(y))
return x
end

_drop1(x::AbstractVector) = @views x[firstindex(x, 1):(end - 1)]
_drop1(x::AbstractMatrix) = @views x[firstindex(x, 1):(end - 1), :]
devmotion marked this conversation as resolved.
Show resolved Hide resolved

_last1(x::AbstractVector) = x[end]
_last1(x::AbstractMatrix) = @views x[end, :]

function _inv_softmax1!(y::AbstractVecOrMat, x::AbstractVecOrMat)
x₋ = _drop1(x)
xd = _last1(x)
@. y = log(x₋) - log(xd)
return y
end
function _inv_softmax1(x::AbstractVecOrMat)
y = similar(_drop1(x))
_inv_softmax1!(y, x)
return y
end
1 change: 1 addition & 0 deletions src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ for fname in ["dirichlet.jl",
"jointorderstatistics.jl",
"mvnormal.jl",
"mvnormalcanon.jl",
"mvlogitnormal.jl",
"mvlognormal.jl",
"mvtdist.jl",
"product.jl", # deprecated
Expand Down
158 changes: 158 additions & 0 deletions test/multivariate/mvlogitnormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Tests on Multivariate Logit-Normal distributions
using Distributions
using ForwardDiff
using LinearAlgebra
using Random
using Test

####### Core testing procedure

function test_mvlogitnormal(d::MvLogitNormal; nsamples::Int=10^6)
@test d.normal isa AbstractMvNormal
dnorm = d.normal

@testset "properties" begin
@test length(d) == length(dnorm) + 1
@test params(d) == params(dnorm)
@test partype(d) == partype(dnorm)
@test eltype(d) == eltype(dnorm)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
@test eltype(typeof(d)) == eltype(typeof(dnorm))
@test location(d) == mean(dnorm)
@test minimum(d) == fill(0, length(d))
@test maximum(d) == fill(1, length(d))
@test insupport(d, normalize(rand(length(d)), 1))
@test !insupport(d, normalize(rand(length(d) + 1), 1))
@test !insupport(d, rand(length(d)))
x = rand(length(d) - 1)
x = vcat(x, -sum(x))
@test !insupport(d, x)
end

@testset "conversions" begin
@test convert(typeof(d), d) === d
T = partype(d) <: Float64 ? Float32 : Float64
if dnorm isa MvNormal
@test convert(MvLogitNormal{MvNormal{T}}, d).normal ==
convert(MvNormal{T}, dnorm)
@test partype(convert(MvLogitNormal{MvNormal{T}}, d)) <: T
@test canonform(d) isa MvLogitNormal{<:MvNormalCanon}
@test canonform(d).normal == canonform(dnorm)
elseif dnorm isa MvNormalCanon
@test convert(MvLogitNormal{MvNormalCanon{T}}, d).normal ==
convert(MvNormalCanon{T}, dnorm)
@test partype(convert(MvLogitNormal{MvNormalCanon{T}}, d)) <: T
@test meanform(d) isa MvLogitNormal{<:MvNormal}
@test meanform(d).normal == meanform(dnorm)
end
end

@testset "sampling" begin
X = rand(d, nsamples)
Y = @views log.(X[1:(end - 1), :]) .- log.(X[end, :]')
Ymean = vec(mean(Y; dims=2))
Ycov = cov(Y; dims=2)
for i in 1:(length(d) - 1)
@test isapprox(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not completely sure anymore but I assume we don't have any statistical test utilities for e.g. hypothesis tests of samples? IIRC KS tests are used for the univariate distributions but I assume we don't have anything for multi- or matrixvariate distributions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the closest I can find, which just checks that a 1x1 matrix distribution is consistent with the univariate distribution it reduces to:

function test_draws_against_univariate_cdf(D::MatrixDistribution, d::UnivariateDistribution)
α = 0.025
M = 100000
matvardraws = [rand(D)[1] for m in 1:M]
@test pvalue_kolmogorovsmirnoff(matvardraws, d) >= α
nothing
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for checking! Would be useful to have something similar for multivariate distributions (generally, IMO we should provide test utilities such that you can check more easily that you implement the interface correctly) but better to do this in a separate PR.

Ymean[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8
)
end
for i in 1:(length(d) - 1), j in 1:(length(d) - 1)
@test isapprox(
Ycov[i, j],
cov(dnorm)[i, j],
atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20,
)
end
end

@testset "fitting" begin
X = rand(d, nsamples)
dfit = fit_mle(MvLogitNormal, X)
dfit_norm = dfit.normal
for i in 1:(length(d) - 1)
@test isapprox(
mean(dfit_norm)[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8
)
end
for i in 1:(length(d) - 1), j in 1:(length(d) - 1)
@test isapprox(
cov(dfit_norm)[i, j],
cov(dnorm)[i, j],
atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20,
)
end
@test fit_mle(MvLogitNormal{IsoNormal}, X) isa MvLogitNormal{<:IsoNormal}
end

@testset "evaluation" begin
X = rand(d, nsamples)
for i in 1:min(100, nsamples)
@test @inferred(logpdf(d, X[:, i])) ≈ log(pdf(d, X[:, i]))
if dnorm isa MvNormal
@test @inferred(gradlogpdf(d, X[:, i])) ≈
ForwardDiff.gradient(x -> logpdf(d, x), X[:, i])
end
end
@test logpdf(d, X) ≈ log.(pdf(d, X))
@test isequal(logpdf(d, zeros(length(d))), -Inf)
@test isequal(logpdf(d, ones(length(d))), -Inf)
@test isequal(pdf(d, zeros(length(d))), 0)
@test isequal(pdf(d, ones(length(d))), 0)
end
end

@testset "Results MvLogitNormal consistent with univariate LogitNormal" begin
μ = randn()
σ = rand()
d = MvLogitNormal([μ], fill(σ^2, 1, 1))
duni = LogitNormal(μ, σ)
@test location(d) ≈ [location(duni)]
x = normalize(rand(2), 1)
@test logpdf(d, x) ≈ logpdf(duni, x[1])
@test pdf(d, x) ≈ pdf(duni, x[1])
@test (Random.seed!(9274); rand(d)[1]) ≈ (Random.seed!(9274); rand(duni))
end

###### General Testing

@testset "MvLogitNormal tests" begin
mvnorm_params = [
(randn(5), I * rand()),
(randn(4), Diagonal(rand(4))),
(Diagonal(rand(6)),),
(randn(5), exp(Symmetric(randn(5, 5)))),
(exp(Symmetric(randn(5, 5))),),
]
@testset "wraps MvNormal" begin
@testset "$(typeof(prms))" for prms in mvnorm_params
d = MvLogitNormal(prms...)
@test d == MvLogitNormal(MvNormal(prms...))
test_mvlogitnormal(d; nsamples=10^4)
end
end
@testset "wraps MvNormalCanon" begin
@testset "$(typeof(prms))" for prms in mvnorm_params
d = MvLogitNormal(MvNormalCanon(prms...))
test_mvlogitnormal(d; nsamples=10^4)
end
end

@testset "kldivergence" begin
d1 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5))))
d2 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5))))
@test kldivergence(d1, d2) ≈ kldivergence(d1.normal, d2.normal)
end

VERSION ≥ v"1.8" && @testset "show" begin
d = MvLogitNormal([1.0, 2.0, 3.0], Diagonal([4.0, 5.0, 6.0]))
@test sprint(show, d) === """
MvLogitNormal{DiagNormal}(
DiagNormal(
dim: 3
μ: [1.0, 2.0, 3.0]
Σ: [4.0 0.0 0.0; 0.0 5.0 0.0; 0.0 0.0 6.0]
)
)
"""
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const tests = [
"univariate/continuous/uniform",
"univariate/continuous/lognormal",
"multivariate/mvnormal",
"multivariate/mvlogitnormal",
"multivariate/mvlognormal",
"types", # extra file compared to /src
"utils",
Expand Down