diff --git a/Project.toml b/Project.toml index d37f3196f..a7d365d9f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.16.5" +version = "0.16.6" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -34,7 +34,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "3.2" -AdvancedHMC = "0.2.24" +AdvancedHMC = "0.3.0" AdvancedMH = "0.6" AdvancedPS = "0.2.4" AdvancedVI = "0.1" diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl index 6f1507805..1d4fa6f64 100644 --- a/src/inference/gibbs.jl +++ b/src/inference/gibbs.jl @@ -126,7 +126,7 @@ function gibbs_state( state.z.θ .= θ_old z = state.z - return HMCState(varinfo, state.i, state.traj, hamiltonian, z, state.adaptor) + return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor) end """ diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 5b0b63418..3f614d7aa 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -4,14 +4,14 @@ struct HMCState{ TV<:AbstractVarInfo, - TTraj<:AHMC.AbstractTrajectory, + TKernel<:AHMC.HMCKernel, THam<:AHMC.Hamiltonian, PhType<:AHMC.PhasePoint, TAdapt<:AHMC.Adaptation.AbstractAdaptor, } vi::TV i::Int - traj::TTraj + kernel::TKernel hamiltonian::THam z::PhType adaptor::TAdapt @@ -190,18 +190,18 @@ function DynamicPPL.initialstep( ϵ = spl.alg.ϵ end - # Generate a trajectory. - traj = gen_traj(spl.alg, ϵ) + # Generate a kernel. + kernel = make_ahmc_kernel(spl.alg, ϵ) # Create initial transition and state. # Already perform one step since otherwise we don't get any statistics. - t = AHMC.step(rng, hamiltonian, traj, z) + t = AHMC.transition(rng, hamiltonian, kernel, z) # Adaptation adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ) if spl.alg isa AdaptiveHamiltonian - hamiltonian, traj, _ = - AHMC.adapt!(hamiltonian, traj, adaptor, + hamiltonian, kernel, _ = + AHMC.adapt!(hamiltonian, kernel, adaptor, 1, nadapts, t.z.θ, t.stat.acceptance_rate) end @@ -215,7 +215,7 @@ function DynamicPPL.initialstep( end transition = HMCTransition(vi, t) - state = HMCState(vi, 1, traj, hamiltonian, t.z, adaptor) + state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) return transition, state end @@ -234,16 +234,16 @@ function AbstractMCMC.step( # Compute transition. hamiltonian = state.hamiltonian z = state.z - t = AHMC.step(rng, hamiltonian, state.traj, z) + t = AHMC.transition(rng, hamiltonian, state.kernel, z) # Adaptation i = state.i + 1 if spl.alg isa AdaptiveHamiltonian - hamiltonian, traj, _ = - AHMC.adapt!(hamiltonian, state.traj, state.adaptor, + hamiltonian, kernel, _ = + AHMC.adapt!(hamiltonian, state.kernel, state.adaptor, i, nadapts, t.z.θ, t.stat.acceptance_rate) else - traj = state.traj + kernel = state.kernel end # Update variables @@ -255,7 +255,7 @@ function AbstractMCMC.step( # Compute next transition and state. transition = HMCTransition(vi, t) - newstate = HMCState(vi, i, traj, hamiltonian, t.z, state.adaptor) + newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor) return transition, newstate end @@ -459,9 +459,9 @@ function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state) return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc)) end -gen_traj(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog) -gen_traj(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ) -gen_traj(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max) +make_ahmc_kernel(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog) +make_ahmc_kernel(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ) +make_ahmc_kernel(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max) #### #### Compiler interface, i.e. tilde operators. @@ -584,8 +584,8 @@ function HMCState( ϵ = spl.alg.ϵ end - # Generate a trajectory. - traj = gen_traj(spl.alg, ϵ) + # Generate a kernel. + kernel = make_ahmc_kernel(spl.alg, ϵ) # Generate a phasepoint. Replaced during sample_init! h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ. @@ -593,5 +593,5 @@ function HMCState( # Unlink everything. invlink!(vi, spl) - return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) + return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) end diff --git a/test/Project.toml b/test/Project.toml index efb9fa08c..6d7ba0b91 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -34,7 +34,7 @@ AdvancedPS = "0.2" AdvancedVI = "0.1" Clustering = "0.14" CmdStan = "6.0.8" -Distributions = "0.23.8, 0.24, 0.25" +Distributions = "< 0.25.11" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.12" diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index 52e2e1a76..6d6955664 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -1,4 +1,4 @@ -@testset "io.jl" begin +@testset "inference.jl" begin # Only test threading if 1.3+. if VERSION > v"1.2" @testset "threaded sampling" begin diff --git a/test/inference/gibbs.jl b/test/inference/gibbs.jl index a9425b630..9077aecbe 100644 --- a/test/inference/gibbs.jl +++ b/test/inference/gibbs.jl @@ -50,19 +50,10 @@ chain = sample(gdemo(1.5, 2.0), alg, 5_000) check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) - setadsafe(true) - Random.seed!(200) gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2)) chain = sample(MoGtest_default, gibbs, 5_000) check_MoGtest_default(chain, atol=0.15) - - setadsafe(false) - - Random.seed!(200) - gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) - chain = sample(MoGtest_default, gibbs, 5_000) - check_MoGtest_default(chain, atol=0.1) end @turing_testset "transitions" begin diff --git a/test/runtests.jl b/test/runtests.jl index f9fa2cb70..0fa857dee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,22 +46,27 @@ include("test_utils/AllUtils.jl") include("core/ad.jl") end + @testset "samplers (without AD)" begin + include("inference/AdvancedSMC.jl") + include("inference/emcee.jl") + include("inference/ess.jl") + include("inference/is.jl") + end + Turing.setrdcache(false) for adbackend in (:forwarddiff, :tracker, :reversediff) Turing.setadbackend(adbackend) + @info "Testing $(adbackend)" + start = time() @testset "inference: $adbackend" begin @testset "samplers" begin include("inference/gibbs.jl") include("inference/gibbs_conditional.jl") include("inference/hmc.jl") - include("inference/is.jl") - include("inference/mh.jl") - include("inference/ess.jl") - include("inference/emcee.jl") - include("inference/AdvancedSMC.jl") include("inference/Inference.jl") include("contrib/inference/dynamichmc.jl") include("contrib/inference/sghmc.jl") + include("inference/mh.jl") end end @@ -72,6 +77,11 @@ include("test_utils/AllUtils.jl") @testset "modes" begin include("modes/ModeEstimation.jl") end + + # Useful for + # a) discovering performance regressions, + # b) figuring out why CI is timing out. + @info "Tests for $(adbackend) took $(time() - start) seconds" end @testset "variational optimisers" begin include("variational/optimisers.jl")