From e6dd4ef8460fb045ddafc0893710983c9266412a Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Fri, 23 Jun 2023 18:45:01 +0300 Subject: [PATCH] Add `vectorize` method for `LKJCholesky` (#485) * using `LinearAlgebra.Cholesky` * add `vectorize` for `LKJCholesky` * add `vectorize` test * add forgotten `end` * Update test/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix typo * add `reconstruct` methods for LKJ/LKJCholesky inv bijectors * bump patch * bump Bijectors compat * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add Bijectors v0.13 compat * add `inittrans` method for `CholeskyVariate` * add `LKJ`/`LKJCholesky` tests Co-authored-by: torfjelde * include tests * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * make tests more accurate * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: Tor Erlend Fjelde * Update test/lkj.jl Co-authored-by: Tor Erlend Fjelde * Update test/lkj.jl Co-authored-by: Tor Erlend Fjelde * test `LKJCholesky` for both `'U'` and `'L'` * remove unnecessary `float` wrap * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde --- Project.toml | 2 +- src/DynamicPPL.jl | 2 ++ src/abstract_varinfo.jl | 2 +- src/utils.jl | 15 ++++++++++- test/lkj.jl | 55 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ test/utils.jl | 6 +++++ 7 files changed, 81 insertions(+), 3 deletions(-) create mode 100644 test/lkj.jl diff --git a/Project.toml b/Project.toml index 677598eb3..eb03e7c7c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.0" +version = "0.23.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 8904cfe81..04b08fb19 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -15,6 +15,8 @@ using Setfield: Setfield using ZygoteRules: ZygoteRules using LogDensityProblems: LogDensityProblems +using LinearAlgebra: Cholesky + using DocStringExtensions using Random: Random diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 116890d7b..c4cdda5a2 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -570,7 +570,7 @@ end # NOTE: `reconstruct` is no-op if `val` is already of correct shape. """ reconstruct_and_link(dist, val) - reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val) + reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val) Return linked `val` but reconstruct before linking, if necessary. diff --git a/src/utils.jl b/src/utils.jl index 525be61de..b1076daf4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -213,6 +213,7 @@ vectorize(d, r) = vec(r) vectorize(d::UnivariateDistribution, r::Real) = [r] vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r) vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) +vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL)) # NOTE: # We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real. @@ -235,6 +236,13 @@ reconstruct(f, dist, val) = reconstruct(dist, val) reconstruct(::UnivariateDistribution, val::Real) = val reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) +reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val) +function reconstruct( + ::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector +) + return copy(val) +end + # TODO: Implement no-op `reconstruct` for general array variates. reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) @@ -294,7 +302,12 @@ function inittrans(rng, dist::MatrixDistribution) sz = Bijectors.output_size(b, size(dist)) return Bijectors.invlink(dist, randrealuni(rng, sz...)) end - +function inittrans(rng, dist::Distribution{CholeskyVariate}) + # Get the size of the unconstrained vector + b = link_transform(dist) + sz = Bijectors.output_size(b, size(dist)) + return Bijectors.invlink(dist, randrealuni(rng, sz...)) +end ################################ # Multi-sample initialisations # ################################ diff --git a/test/lkj.jl b/test/lkj.jl new file mode 100644 index 000000000..4fe8a83af --- /dev/null +++ b/test/lkj.jl @@ -0,0 +1,55 @@ +using Bijectors: pd_from_upper, pd_from_lower + +function pd_from_triangular(X::AbstractMatrix, uplo::Char) + return uplo == 'U' ? pd_from_upper(X) : pd_from_lower(X) +end + +@model lkj_prior_demo() = x ~ LKJ(2, 1) +@model lkj_chol_prior_demo(uplo) = x ~ LKJCholesky(2, 1, uplo) + +# Same for both distributions +target_mean = vec(Matrix{Float64}(I, 2, 2)) + +_lkj_atol = 0.05 + +@testset "Sample from x ~ LKJ(2, 1)" begin + model = lkj_prior_demo() + # `SampleFromPrior` will sample in constrained space. + @testset "SampleFromPrior" begin + samples = sample(model, SampleFromPrior(), 1_000) + @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = + _lkj_atol + end + + # `SampleFromUniform` will sample in unconstrained space. + @testset "SampleFromUniform" begin + samples = sample(model, SampleFromUniform(), 1_000) + @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = + _lkj_atol + end +end + +@testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] + model = lkj_chol_prior_demo(uplo) + # `SampleFromPrior` will sample in unconstrained space. + @testset "SampleFromPrior" begin + samples = sample(model, SampleFromPrior(), 1_000) + # Build correlation matrix from factor + corr_matrices = map(samples) do s + M = reshape(s.metadata.vals, (2, 2)) + pd_from_triangular(M, uplo) + end + @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol + end + + # `SampleFromUniform` will sample in unconstrained space. + @testset "SampleFromUniform" begin + samples = sample(model, SampleFromUniform(), 1_000) + # Build correlation matrix from factor + corr_matrices = map(samples) do s + M = reshape(s.metadata.vals, (2, 2)) + pd_from_triangular(M, uplo) + end + @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 27889b5e5..b4099cbc7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,6 +50,8 @@ include("test_util.jl") include("serialization.jl") include("loglikelihoods.jl") + + include("lkj.jl") end @testset "compat" begin diff --git a/test/utils.jl b/test/utils.jl index 37f1aaa86..1fcf09ef1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -42,4 +42,10 @@ @test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing @test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing end + + @testset "vectorize" begin + dist = LKJCholesky(2, 1) + x = rand(dist) + @test vectorize(dist, x) == vec(x.UL) + end end