Skip to content

Commit

Permalink
Add vectorize method for LKJCholesky (#485)
Browse files Browse the repository at this point in the history
* using `LinearAlgebra.Cholesky`

* add `vectorize` for `LKJCholesky`

* add `vectorize` test

* add forgotten `end`

* Update test/utils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix typo

* add `reconstruct` methods for LKJ/LKJCholesky inv bijectors

* bump patch

* bump Bijectors compat

* Update src/utils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add Bijectors v0.13 compat

* add `inittrans` method for `CholeskyVariate`

* add `LKJ`/`LKJCholesky` tests
Co-authored-by: torfjelde

* include tests

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* make tests more accurate

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>

* Update test/lkj.jl

Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>

* Update test/lkj.jl

Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>

* test `LKJCholesky` for both `'U'` and `'L'`

* remove unnecessary `float` wrap

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>
  • Loading branch information
3 people authored Jun 23, 2023
1 parent 5f74696 commit e6dd4ef
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.0"
version = "0.23.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ using Setfield: Setfield
using ZygoteRules: ZygoteRules
using LogDensityProblems: LogDensityProblems

using LinearAlgebra: Cholesky

using DocStringExtensions

using Random: Random
Expand Down
2 changes: 1 addition & 1 deletion src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ end
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
"""
reconstruct_and_link(dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
Return linked `val` but reconstruct before linking, if necessary.
Expand Down
15 changes: 14 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ vectorize(d, r) = vec(r)
vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL))

# NOTE:
# We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real.
Expand All @@ -235,6 +236,13 @@ reconstruct(f, dist, val) = reconstruct(dist, val)
reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)
function reconstruct(
::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector
)
return copy(val)
end

# TODO: Implement no-op `reconstruct` for general array variates.

reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
Expand Down Expand Up @@ -294,7 +302,12 @@ function inittrans(rng, dist::MatrixDistribution)
sz = Bijectors.output_size(b, size(dist))
return Bijectors.invlink(dist, randrealuni(rng, sz...))
end

function inittrans(rng, dist::Distribution{CholeskyVariate})
# Get the size of the unconstrained vector
b = link_transform(dist)
sz = Bijectors.output_size(b, size(dist))
return Bijectors.invlink(dist, randrealuni(rng, sz...))
end
################################
# Multi-sample initialisations #
################################
Expand Down
55 changes: 55 additions & 0 deletions test/lkj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Bijectors: pd_from_upper, pd_from_lower

function pd_from_triangular(X::AbstractMatrix, uplo::Char)
return uplo == 'U' ? pd_from_upper(X) : pd_from_lower(X)
end

@model lkj_prior_demo() = x ~ LKJ(2, 1)
@model lkj_chol_prior_demo(uplo) = x ~ LKJCholesky(2, 1, uplo)

# Same for both distributions
target_mean = vec(Matrix{Float64}(I, 2, 2))

_lkj_atol = 0.05

@testset "Sample from x ~ LKJ(2, 1)" begin
model = lkj_prior_demo()
# `SampleFromPrior` will sample in constrained space.
@testset "SampleFromPrior" begin
samples = sample(model, SampleFromPrior(), 1_000)
@test mean(map(Base.Fix2(getindex, Colon()), samples)) target_mean atol =
_lkj_atol
end

# `SampleFromUniform` will sample in unconstrained space.
@testset "SampleFromUniform" begin
samples = sample(model, SampleFromUniform(), 1_000)
@test mean(map(Base.Fix2(getindex, Colon()), samples)) target_mean atol =
_lkj_atol
end
end

@testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L']
model = lkj_chol_prior_demo(uplo)
# `SampleFromPrior` will sample in unconstrained space.
@testset "SampleFromPrior" begin
samples = sample(model, SampleFromPrior(), 1_000)
# Build correlation matrix from factor
corr_matrices = map(samples) do s
M = reshape(s.metadata.vals, (2, 2))
pd_from_triangular(M, uplo)
end
@test vec(mean(corr_matrices)) target_mean atol = _lkj_atol
end

# `SampleFromUniform` will sample in unconstrained space.
@testset "SampleFromUniform" begin
samples = sample(model, SampleFromUniform(), 1_000)
# Build correlation matrix from factor
corr_matrices = map(samples) do s
M = reshape(s.metadata.vals, (2, 2))
pd_from_triangular(M, uplo)
end
@test vec(mean(corr_matrices)) target_mean atol = _lkj_atol
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ include("test_util.jl")
include("serialization.jl")

include("loglikelihoods.jl")

include("lkj.jl")
end

@testset "compat" begin
Expand Down
6 changes: 6 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,10 @@
@test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
@test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
end

@testset "vectorize" begin
dist = LKJCholesky(2, 1)
x = rand(dist)
@test vectorize(dist, x) == vec(x.UL)
end
end

2 comments on commit e6dd4ef

@yebai
Copy link
Member

@yebai yebai commented on e6dd4ef Jun 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/86195

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.23.1 -m "<description of version>" e6dd4ef8460fb045ddafc0893710983c9266412a
git push origin v0.23.1

Please sign in to comment.