Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated HMC implementation for new AHMC version #1660

Merged
merged 19 commits into from
Jul 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.16.5"
version = "0.16.6"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -34,7 +34,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "3.2"
AdvancedHMC = "0.2.24"
AdvancedHMC = "0.3.0"
AdvancedMH = "0.6"
AdvancedPS = "0.2.4"
AdvancedVI = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function gibbs_state(
state.z.θ .= θ_old
z = state.z

return HMCState(varinfo, state.i, state.traj, hamiltonian, z, state.adaptor)
return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor)
end

"""
Expand Down
38 changes: 19 additions & 19 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

struct HMCState{
TV<:AbstractVarInfo,
TTraj<:AHMC.AbstractTrajectory,
TKernel<:AHMC.HMCKernel,
THam<:AHMC.Hamiltonian,
PhType<:AHMC.PhasePoint,
TAdapt<:AHMC.Adaptation.AbstractAdaptor,
}
vi::TV
i::Int
traj::TTraj
kernel::TKernel
hamiltonian::THam
z::PhType
adaptor::TAdapt
Expand Down Expand Up @@ -190,18 +190,18 @@ function DynamicPPL.initialstep(
ϵ = spl.alg.ϵ
end

# Generate a trajectory.
traj = gen_traj(spl.alg, ϵ)
# Generate a kernel.
kernel = make_ahmc_kernel(spl.alg, ϵ)

# Create initial transition and state.
# Already perform one step since otherwise we don't get any statistics.
t = AHMC.step(rng, hamiltonian, traj, z)
t = AHMC.transition(rng, hamiltonian, kernel, z)

# Adaptation
adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
if spl.alg isa AdaptiveHamiltonian
hamiltonian, traj, _ =
AHMC.adapt!(hamiltonian, traj, adaptor,
hamiltonian, kernel, _ =
AHMC.adapt!(hamiltonian, kernel, adaptor,
1, nadapts, t.z.θ, t.stat.acceptance_rate)
end

Expand All @@ -215,7 +215,7 @@ function DynamicPPL.initialstep(
end

transition = HMCTransition(vi, t)
state = HMCState(vi, 1, traj, hamiltonian, t.z, adaptor)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)

return transition, state
end
Expand All @@ -234,16 +234,16 @@ function AbstractMCMC.step(
# Compute transition.
hamiltonian = state.hamiltonian
z = state.z
t = AHMC.step(rng, hamiltonian, state.traj, z)
t = AHMC.transition(rng, hamiltonian, state.kernel, z)

# Adaptation
i = state.i + 1
if spl.alg isa AdaptiveHamiltonian
hamiltonian, traj, _ =
AHMC.adapt!(hamiltonian, state.traj, state.adaptor,
hamiltonian, kernel, _ =
AHMC.adapt!(hamiltonian, state.kernel, state.adaptor,
i, nadapts, t.z.θ, t.stat.acceptance_rate)
else
traj = state.traj
kernel = state.kernel
end

# Update variables
Expand All @@ -255,7 +255,7 @@ function AbstractMCMC.step(

# Compute next transition and state.
transition = HMCTransition(vi, t)
newstate = HMCState(vi, i, traj, hamiltonian, t.z, state.adaptor)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)

return transition, newstate
end
Expand Down Expand Up @@ -459,9 +459,9 @@ function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
end

gen_traj(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog)
gen_traj(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ)
gen_traj(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max)
make_ahmc_kernel(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog)
make_ahmc_kernel(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ)
make_ahmc_kernel(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max)

####
#### Compiler interface, i.e. tilde operators.
Expand Down Expand Up @@ -584,14 +584,14 @@ function HMCState(
ϵ = spl.alg.ϵ
end

# Generate a trajectory.
traj = gen_traj(spl.alg, ϵ)
# Generate a kernel.
kernel = make_ahmc_kernel(spl.alg, ϵ)

# Generate a phasepoint. Replaced during sample_init!
h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ.

# Unlink everything.
invlink!(vi, spl)

return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ AdvancedPS = "0.2"
AdvancedVI = "0.1"
Clustering = "0.14"
CmdStan = "6.0.8"
Distributions = "0.23.8, 0.24, 0.25"
Distributions = "< 0.25.11"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.12"
Expand Down
2 changes: 1 addition & 1 deletion test/inference/Inference.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "io.jl" begin
@testset "inference.jl" begin
# Only test threading if 1.3+.
if VERSION > v"1.2"
@testset "threaded sampling" begin
Expand Down
9 changes: 0 additions & 9 deletions test/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,10 @@
chain = sample(gdemo(1.5, 2.0), alg, 5_000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)

setadsafe(true)

Random.seed!(200)
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2))
chain = sample(MoGtest_default, gibbs, 5_000)
check_MoGtest_default(chain, atol=0.15)

setadsafe(false)

Random.seed!(200)
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, gibbs, 5_000)
check_MoGtest_default(chain, atol=0.1)
end

@turing_testset "transitions" begin
Expand Down
20 changes: 15 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,27 @@ include("test_utils/AllUtils.jl")
include("core/ad.jl")
end

@testset "samplers (without AD)" begin
include("inference/AdvancedSMC.jl")
include("inference/emcee.jl")
include("inference/ess.jl")
include("inference/is.jl")
end

Turing.setrdcache(false)
for adbackend in (:forwarddiff, :tracker, :reversediff)
Turing.setadbackend(adbackend)
@info "Testing $(adbackend)"
start = time()
@testset "inference: $adbackend" begin
@testset "samplers" begin
include("inference/gibbs.jl")
include("inference/gibbs_conditional.jl")
include("inference/hmc.jl")
include("inference/is.jl")
include("inference/mh.jl")
include("inference/ess.jl")
include("inference/emcee.jl")
include("inference/AdvancedSMC.jl")
include("inference/Inference.jl")
include("contrib/inference/dynamichmc.jl")
include("contrib/inference/sghmc.jl")
include("inference/mh.jl")
end
end

Expand All @@ -72,6 +77,11 @@ include("test_utils/AllUtils.jl")
@testset "modes" begin
include("modes/ModeEstimation.jl")
end

# Useful for
# a) discovering performance regressions,
# b) figuring out why CI is timing out.
@info "Tests for $(adbackend) took $(time() - start) seconds"
end
@testset "variational optimisers" begin
include("variational/optimisers.jl")
Expand Down