From ae1489046fd8542e62a2e4a5dc20299c4f14dec1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 3 Oct 2023 17:12:27 +0100 Subject: [PATCH 1/4] use immutable link in the initialstep for HMC --- src/mcmc/hmc.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 4e6733938..c80679235 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -132,13 +132,13 @@ function DynamicPPL.initialstep( rng::AbstractRNG, model::AbstractModel, spl::Sampler{<:Hamiltonian}, - vi::AbstractVarInfo; + vi_original::AbstractVarInfo; init_params=nothing, nadapts=0, kwargs... ) # Transform the samples to unconstrained space and compute the joint log probability. - vi = link!!(vi, spl, model) + vi = DynamicPPL.link(vi_original, spl, model) # Extract parameters. theta = vi[spl] From ed03952fd1dc59b2fa0b0bc9de6736c5fc5fc32f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 3 Oct 2023 17:19:04 +0100 Subject: [PATCH 2/4] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5ccbbf752..2ef9edc8b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.29.2" +version = "0.29.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 9423bf35e50004f8c4b6424705a24d0ef1b74f01 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 3 Oct 2023 17:22:34 +0100 Subject: [PATCH 3/4] added test --- test/mcmc/hmc.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 52aff59e9..f462ff6a6 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -246,4 +246,15 @@ sample(demo_warn_init_params(), NUTS(), 5) end end + + @turing_testset "(partially) issue: #2095" begin + @model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV} + xs = Vector{TV}(undef, 2) + xs[1] ~ Dirichlet(ones(5)) + xs[2] ~ Dirichlet(ones(5)) + end + model = vector_of_dirichlet() + chain = sample(model, NUTS(), 1000) + @test mean(Array(chain)) ≈ 0.2 + end end From 0aa9b8aa3e2fb866641dac70bf97d2f3746f6ef2 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:42:44 +0000 Subject: [PATCH 4/4] Update hmc.jl --- src/mcmc/hmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index fa3d5f95f..a2d70e34b 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -131,7 +131,7 @@ function DynamicPPL.initialstep( rng::AbstractRNG, model::AbstractModel, spl::Sampler{<:Hamiltonian}, - vi::AbstractVarInfo; + vi_original::AbstractVarInfo; initial_params=nothing, nadapts=0, kwargs...