From 1ebe8bcdbd17443f06594f84c54460580a30c6b0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Aug 2023 06:13:38 +0100 Subject: [PATCH] Bugfix for linking (#513) * fixed setall for TypedVarInfo * removed unnused argument to _setall! * bumped patch version * added tests for dirichlet * added invlinked tests for existing tests * added missing linking.jl tests in test/runtests.jl * Apply suggestions from code review Co-authored-by: David Widmann --------- Co-authored-by: David Widmann --- Project.toml | 2 +- src/varinfo.jl | 4 +-- test/linking.jl | 71 ++++++++++++++++++++++++++++++++---------------- test/runtests.jl | 1 + 4 files changed, 52 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 4f7507c62..f5ddb0886 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.10" +version = "0.23.11" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/varinfo.jl b/src/varinfo.jl index a30c9ea24..bda979eef 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -360,11 +360,11 @@ The values may or may not be transformed to Euclidean space. """ setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) -@generated function _setall!(metadata::NamedTuple{names}, val, start=0) where {names} +@generated function _setall!(metadata::NamedTuple{names}, val) where {names} expr = Expr(:block) start = :(1) for f in names - length = :(length(metadata.$f.vals)) + length = :(sum(length, metadata.$f.ranges)) finish = :($start + $length - 1) push!(expr.args, :(metadata.$f.vals .= val[($start):($finish)])) start = :($start + $length) diff --git a/test/linking.jl b/test/linking.jl index f81895788..26d28c13d 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -35,7 +35,7 @@ end Base.size(d::MyMatrixDistribution) = (d.dim, d.dim) function Distributions._rand!( - rng::AbstractRNG, d::MyMatrixDistribution, x::AbstractMatrix{<:Real} + rng::Random.AbstractRNG, d::MyMatrixDistribution, x::AbstractMatrix{<:Real} ) return randn!(rng, x) end @@ -58,29 +58,54 @@ function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, istrans::Boo end @testset "Linking" begin - # Just making sure the transformations are okay. - x = randn(3, 3) - f = TrilToVec((3, 3)) - f_inv = inverse(f) - y = f(x) - @test y isa AbstractVector - @test f_inv(f(x)) == LowerTriangular(x) + @testset "simple matrix distribution" begin + # Just making sure the transformations are okay. + x = randn(3, 3) + f = TrilToVec((3, 3)) + f_inv = inverse(f) + y = f(x) + @test y isa AbstractVector + @test f_inv(f(x)) == LowerTriangular(x) - # Within a model. - dist = MyMatrixDistribution(3) - @model demo() = m ~ dist - model = demo() + # Within a model. + dist = MyMatrixDistribution(3) + @model demo() = m ~ dist + model = demo() - vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),)) - @testset "$(short_varinfo_name(vi))" for vi in vis - # Evaluate once to ensure we have `logp` value. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - vi_linked = DynamicPPL.link!!(deepcopy(vi), model) - # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) - @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) - # Linked one should be working with a lower-dimensional representation. - @test length(vi_linked[:]) < length(vi[:]) - @test length(vi_linked[:]) == 3 + vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),)) + @testset "$(short_varinfo_name(vi))" for vi in vis + # Evaluate once to ensure we have `logp` value. + vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + # Difference should just be the log-absdet-jacobian "correction". + @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) + @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) + # Linked one should be working with a lower-dimensional representation. + @test length(vi_linked[:]) < length(vi[:]) + @test length(vi_linked[:]) == length(y) + # Invlinked. + vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) + @test length(vi_invlinked[:]) == length(vi[:]) + @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) + @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi) + end + end + + @testset "dirichlet" begin + @model demo_dirichlet() = x ~ Dirichlet(2, 1.0) + model = demo_dirichlet() + vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),)) + @testset "$(short_varinfo_name(vi))" for vi in vis + @test length(vi[:]) == 2 + @test iszero(getlogp(vi)) + # Linked. + vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + @test length(vi_linked[:]) == 1 + @test !iszero(getlogp(vi_linked)) # should now include the log-absdet-jacobian correction + # Invlinked. + vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) + @test length(vi_invlinked[:]) == 2 + @test iszero(getlogp(vi_invlinked)) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 5aa91e2f0..2d1b521ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,6 +44,7 @@ include("test_util.jl") include("contexts.jl") include("context_implementations.jl") include("logdensityfunction.jl") + include("linking.jl") include("threadsafe.jl")