Skip to content

Commit

Permalink
Updated HMC implementation for new AHMC version (#1660)
Browse files Browse the repository at this point in the history
* updated HMC implementation according to new AHMC interface

* bump compat bound for AdvancedHMC

* bumped patch version

* disable GMM Gibbs conditional test to see if it fixes CI

* include tests again

* dont test non-AD samplers for every AD backend

* added back a test

* added back a test

* removed some redundant tests and fixed a typo

* added macro timed_testset

* upper-bound Distributions.jl apparently fixes the test-freeze

* hyphen compat specifies arent compatible with Julia 1.3

* removed the timed_testset stuff
  • Loading branch information
torfjelde authored Jul 17, 2021
1 parent d029198 commit 7cb94d6
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.16.5"
version = "0.16.6"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -34,7 +34,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "3.2"
AdvancedHMC = "0.2.24"
AdvancedHMC = "0.3.0"
AdvancedMH = "0.6"
AdvancedPS = "0.2.4"
AdvancedVI = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function gibbs_state(
state.z.θ .= θ_old
z = state.z

return HMCState(varinfo, state.i, state.traj, hamiltonian, z, state.adaptor)
return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor)
end

"""
Expand Down
38 changes: 19 additions & 19 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

struct HMCState{
TV<:AbstractVarInfo,
TTraj<:AHMC.AbstractTrajectory,
TKernel<:AHMC.HMCKernel,
THam<:AHMC.Hamiltonian,
PhType<:AHMC.PhasePoint,
TAdapt<:AHMC.Adaptation.AbstractAdaptor,
}
vi::TV
i::Int
traj::TTraj
kernel::TKernel
hamiltonian::THam
z::PhType
adaptor::TAdapt
Expand Down Expand Up @@ -190,18 +190,18 @@ function DynamicPPL.initialstep(
ϵ = spl.alg.ϵ
end

# Generate a trajectory.
traj = gen_traj(spl.alg, ϵ)
# Generate a kernel.
kernel = make_ahmc_kernel(spl.alg, ϵ)

# Create initial transition and state.
# Already perform one step since otherwise we don't get any statistics.
t = AHMC.step(rng, hamiltonian, traj, z)
t = AHMC.transition(rng, hamiltonian, kernel, z)

# Adaptation
adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
if spl.alg isa AdaptiveHamiltonian
hamiltonian, traj, _ =
AHMC.adapt!(hamiltonian, traj, adaptor,
hamiltonian, kernel, _ =
AHMC.adapt!(hamiltonian, kernel, adaptor,
1, nadapts, t.z.θ, t.stat.acceptance_rate)
end

Expand All @@ -215,7 +215,7 @@ function DynamicPPL.initialstep(
end

transition = HMCTransition(vi, t)
state = HMCState(vi, 1, traj, hamiltonian, t.z, adaptor)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)

return transition, state
end
Expand All @@ -234,16 +234,16 @@ function AbstractMCMC.step(
# Compute transition.
hamiltonian = state.hamiltonian
z = state.z
t = AHMC.step(rng, hamiltonian, state.traj, z)
t = AHMC.transition(rng, hamiltonian, state.kernel, z)

# Adaptation
i = state.i + 1
if spl.alg isa AdaptiveHamiltonian
hamiltonian, traj, _ =
AHMC.adapt!(hamiltonian, state.traj, state.adaptor,
hamiltonian, kernel, _ =
AHMC.adapt!(hamiltonian, state.kernel, state.adaptor,
i, nadapts, t.z.θ, t.stat.acceptance_rate)
else
traj = state.traj
kernel = state.kernel
end

# Update variables
Expand All @@ -255,7 +255,7 @@ function AbstractMCMC.step(

# Compute next transition and state.
transition = HMCTransition(vi, t)
newstate = HMCState(vi, i, traj, hamiltonian, t.z, state.adaptor)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)

return transition, newstate
end
Expand Down Expand Up @@ -459,9 +459,9 @@ function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
end

gen_traj(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog)
gen_traj(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ)
gen_traj(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max)
make_ahmc_kernel(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog)
make_ahmc_kernel(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ)
make_ahmc_kernel(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max)

####
#### Compiler interface, i.e. tilde operators.
Expand Down Expand Up @@ -584,14 +584,14 @@ function HMCState(
ϵ = spl.alg.ϵ
end

# Generate a trajectory.
traj = gen_traj(spl.alg, ϵ)
# Generate a kernel.
kernel = make_ahmc_kernel(spl.alg, ϵ)

# Generate a phasepoint. Replaced during sample_init!
h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ.

# Unlink everything.
invlink!(vi, spl)

return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ AdvancedPS = "0.2"
AdvancedVI = "0.1"
Clustering = "0.14"
CmdStan = "6.0.8"
Distributions = "0.23.8, 0.24, 0.25"
Distributions = "< 0.25.11"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.12"
Expand Down
2 changes: 1 addition & 1 deletion test/inference/Inference.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "io.jl" begin
@testset "inference.jl" begin
# Only test threading if 1.3+.
if VERSION > v"1.2"
@testset "threaded sampling" begin
Expand Down
9 changes: 0 additions & 9 deletions test/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,10 @@
chain = sample(gdemo(1.5, 2.0), alg, 5_000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)

setadsafe(true)

Random.seed!(200)
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2))
chain = sample(MoGtest_default, gibbs, 5_000)
check_MoGtest_default(chain, atol=0.15)

setadsafe(false)

Random.seed!(200)
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, gibbs, 5_000)
check_MoGtest_default(chain, atol=0.1)
end

@turing_testset "transitions" begin
Expand Down
20 changes: 15 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,27 @@ include("test_utils/AllUtils.jl")
include("core/ad.jl")
end

@testset "samplers (without AD)" begin
include("inference/AdvancedSMC.jl")
include("inference/emcee.jl")
include("inference/ess.jl")
include("inference/is.jl")
end

Turing.setrdcache(false)
for adbackend in (:forwarddiff, :tracker, :reversediff)
Turing.setadbackend(adbackend)
@info "Testing $(adbackend)"
start = time()
@testset "inference: $adbackend" begin
@testset "samplers" begin
include("inference/gibbs.jl")
include("inference/gibbs_conditional.jl")
include("inference/hmc.jl")
include("inference/is.jl")
include("inference/mh.jl")
include("inference/ess.jl")
include("inference/emcee.jl")
include("inference/AdvancedSMC.jl")
include("inference/Inference.jl")
include("contrib/inference/dynamichmc.jl")
include("contrib/inference/sghmc.jl")
include("inference/mh.jl")
end
end

Expand All @@ -72,6 +77,11 @@ include("test_utils/AllUtils.jl")
@testset "modes" begin
include("modes/ModeEstimation.jl")
end

# Useful for
# a) discovering performance regressions,
# b) figuring out why CI is timing out.
@info "Tests for $(adbackend) took $(time() - start) seconds"
end
@testset "variational optimisers" begin
include("variational/optimisers.jl")
Expand Down

2 comments on commit 7cb94d6

@yebai
Copy link
Member

@yebai yebai commented on 7cb94d6 Jul 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/41071

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.16.6 -m "<description of version>" 7cb94d60a368a2850d5eb8a3f4f329799c0eee4f
git push origin v0.16.6

Please sign in to comment.