Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt AbstractMCMC.jl interface #259

Merged
merged 42 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
f36a3cc
initial work on adopting the AbstractMCMC interface
torfjelde Apr 6, 2021
16a233c
removed old sample method since is now redundant
torfjelde Apr 6, 2021
a1494ac
AbstractMCMCKernel is now a subtype of AbstractSampler
torfjelde Apr 6, 2021
5f01c5d
fixed impl and added back old one
torfjelde Apr 6, 2021
fc9e747
overload AbstractMCMC.sample to make interface a bit nicer
torfjelde Apr 7, 2021
ba0d623
reverted renaming of abstract kernel type
torfjelde Apr 7, 2021
b6ad39c
replicate logging behavior of current AHMC when using AbstractMCMC
torfjelde Apr 8, 2021
66b37b5
initial work on tests, but require update to testing deps
torfjelde Apr 8, 2021
ea83745
Merge branch 'master' into tor/abstractmcmc
torfjelde Apr 9, 2021
f041ff6
moved abstractmcmc interface into separate file
torfjelde Apr 9, 2021
dad71a8
version bump
torfjelde Apr 12, 2021
07be7ac
only bump patch version to see if tests succeed
torfjelde Apr 12, 2021
ae16796
Merge branch 'master' into tor/abstractmcmc
torfjelde Apr 18, 2021
d3e8542
Merge branch 'master' into tor/abstractmcmc
torfjelde Apr 20, 2021
486e9d1
move away from using extras in Project.toml
torfjelde Jul 14, 2021
0d74502
Merge branch 'tor/improved-testing' into tor/abstractmcmc
torfjelde Jul 14, 2021
608442b
added integration tests for Turing.jl
torfjelde Jul 14, 2021
cb77c9d
removed usage of Turing.jl and MCMCDebugging.jl in main testsuite
torfjelde Jul 14, 2021
75e2404
fixed bug in deprecated HMCDA constructor
torfjelde Jul 14, 2021
68884d7
allow specification of which testing suites to run
torfjelde Jul 14, 2021
bfc8f4d
added Turing.jl integration tests to CI
torfjelde Jul 14, 2021
d002bfc
fixed name for integration tests
torfjelde Jul 14, 2021
1ec8262
added using AdvancedHMC in runtests.jl
torfjelde Jul 14, 2021
f7f91e2
removed some now unnecessary usings
torfjelde Jul 14, 2021
639776a
fixed a bug in the downstream testing
torfjelde Jul 14, 2021
e47a5a4
give integration tests a separate CI
torfjelde Jul 14, 2021
678196c
forgot to remove the continue-on-error from CI
torfjelde Jul 14, 2021
01ba50d
Merge branch 'tor/improved-testing' into tor/abstractmcmc
torfjelde Jul 14, 2021
5aeefe6
added convenient constructor for DifferentiableDensityModel using Ham…
torfjelde Jul 14, 2021
2b59913
fixed tests for AbstractMCMC interface
torfjelde Jul 14, 2021
d164e22
added a bunch of docstrings
torfjelde Jul 14, 2021
de4e33a
bumped minor version
torfjelde Jul 14, 2021
fa17e41
increased number of samples used in abstractmcmc tests
torfjelde Jul 14, 2021
8f8e64c
remove thinning from tests
torfjelde Jul 15, 2021
8d13ff5
make initial Leapfrog step size smaller
torfjelde Jul 15, 2021
ebaec71
Merge branch 'master' into tor/abstractmcmc
torfjelde Jul 15, 2021
ccdc832
mistakenly removed AbstractMCMC as a test dep in previous commit
torfjelde Jul 15, 2021
5a417bc
increase adaptation to see if it helps
torfjelde Jul 15, 2021
32f6ff8
ensure we drop the adaptation samples in the test
torfjelde Jul 15, 2021
13a07e8
made a mistake apparently
torfjelde Jul 15, 2021
b82a965
think I finally fixed the tests
torfjelde Jul 15, 2021
38408cf
disable progress in test
torfjelde Jul 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.2.28"
version = "0.3.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
Expand All @@ -17,6 +18,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
AbstractMCMC = "3"
ArgCheck = "1, 2"
DocStringExtensions = "0.8"
InplaceOps = "0.3"
Expand Down
6 changes: 6 additions & 0 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const DEBUG = convert(Bool, parse(Int, get(ENV, "DEBUG_AHMC", "0")))
using Statistics: mean, var, middle
using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
using StatsFuns: logaddexp, logsumexp
import Random
using Random: GLOBAL_RNG, AbstractRNG
using ProgressMeter: ProgressMeter
using UnPack: @unpack
Expand All @@ -16,6 +17,8 @@ using ArgCheck: @argcheck

using DocStringExtensions

import AbstractMCMC

import StatsBase: sample

include("utilities.jl")
Expand Down Expand Up @@ -128,6 +131,9 @@ include("diagnosis.jl")
include("sampler.jl")
export sample

include("abstractmcmc.jl")
export DifferentiableDensityModel

include("contrib/ad.jl")

### Init
Expand Down
293 changes: 293 additions & 0 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
"""
HMCSampler

A `AbstractMCMC.AbstractSampler` for kernels in AdvancedHMC.jl.

# Fields

$(FIELDS)

# Notes

Note that all the fields have the prefix `initial_` to indicate
that these will not necessarily correspond to the `kernel`, `metric`,
and `adaptor` after sampling.

To access the updated fields use the resulting [`HMCState`](@ref).
"""
struct HMCSampler{K, M, A} <: AbstractMCMC.AbstractSampler
"Initial [`AbstractMCMCKernel`](@ref)."
initial_kernel::K
"Initial [`AbstractMetric`](@ref)."
initial_metric::M
"Initial [`AbstractAdaptor`](@ref)."
initial_adaptor::A
end
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation())

"""
DifferentiableDensityModel(ℓπ, ∂ℓπ∂θ)
DifferentiableDensityModel(ℓπ, m::Module)

A `AbstractMCMC.AbstractMCMCModel` representing a differentiable log-density.

If a module `m` is given as the second argument, then `m` is assumed to be an
automatic-differentiation package and this will be used to compute the gradients.

Note that the module `m` must be imported before usage, e.g.
```julia
using Zygote: Zygote
model = DifferentiableDensityModel(ℓπ, Zygote)
```
results in a `model` which will use Zygote.jl as its AD-backend.

# Fields
$(FIELDS)
"""
struct DifferentiableDensityModel{Tlogπ, T∂logπ∂θ} <: AbstractMCMC.AbstractModel
"Log-density. Maps `AbstractArray` to value of the log-density."
ℓπ::Tlogπ
"Gradient of log-density. Returns a tuple of `ℓπ` and the gradient evaluated at the given point."
∂ℓπ∂θ::T∂logπ∂θ
end

struct DummyMetric <: AbstractMetric end
function DifferentiableDensityModel(ℓπ, m::Module)
h = Hamiltonian(DummyMetric(), ℓπ, m)
return DifferentiableDensityModel(h.ℓπ, h.∂ℓπ∂θ)
end

"""
HMCState

Represents the state of a [`HMCSampler`](@ref).

# Fields

$(FIELDS)

"""
struct HMCState{
TTrans<:Transition,
TMetric<:AbstractMetric,
TKernel<:AbstractMCMCKernel,
TAdapt<:Adaptation.AbstractAdaptor
}
"Index of current iteration."
i::Int
"Current [`Transition`](@ref)."
transition::TTrans
"Current [`AbstractMetric`](@ref), possibly adapted."
metric::TMetric
"Current [`AbstractMCMCKernel`](@ref)."
κ::TKernel
"Current [`AbstractAdaptor`](@ref)."
adaptor::TAdapt
end

"""
$(TYPEDSIGNATURES)

A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref).
"""
function AbstractMCMC.sample(
model::DifferentiableDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
N::Integer;
kwargs...
)
return AbstractMCMC.sample(Random.GLOBAL_RNG, model, kernel, metric, adaptor, N; kwargs...)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::DifferentiableDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
N::Integer;
progress = true,
verbose = false,
callback = nothing,
kwargs...
)
sampler = HMCSampler(kernel, metric, adaptor)
if callback === nothing
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
progress = false # don't use AMCMC's progress-funtionality
end

return AbstractMCMC.mcmcsample(
rng, model, sampler, N;
progress = progress,
verbose = verbose,
callback = callback,
kwargs...
)
end

function AbstractMCMC.sample(
model::DifferentiableDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
parallel::AbstractMCMC.AbstractMCMCParallel,
N::Integer,
nchains::Integer;
kwargs...
)
return AbstractMCMC.sample(
Random.GLOBAL_RNG, model, kernel, metric, adaptor, N, nchains;
kwargs...
)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::DifferentiableDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
parallel::AbstractMCMC.AbstractMCMCParallel,
N::Integer,
nchains::Integer;
progress = true,
verbose = false,
callback = nothing,
kwargs...
)
sampler = HMCSampler(kernel, metric, adaptor)
if callback === nothing
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
progress = false # don't use AMCMC's progress-funtionality
end

return AbstractMCMC.mcmcsample(
rng, model, sampler, parallel, N, nchains;
progress = progress,
verbose = verbose,
callback = callback,
kwargs...
)
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::DifferentiableDensityModel,
spl::HMCSampler;
init_params = nothing,
kwargs...
)
metric = spl.initial_metric
κ = spl.initial_kernel
adaptor = spl.initial_adaptor

if init_params === nothing
init_params = randn(size(metric, 1))
end

# Construct the hamiltonian using the initial metric
hamiltonian = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)

# Get an initial sample.
h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params)

# Compute next transition and state.
state = HMCState(0, t, h.metric, κ, adaptor)

# Take actual first step.
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::DifferentiableDensityModel,
spl::HMCSampler,
state::HMCState;
nadapts::Int = 0,
kwargs...
)
# Get step size
@debug "current ϵ" getstepsize(spl, state)

# Compute transition.
i = state.i + 1
t_old = state.transition
adaptor = state.adaptor
κ = state.κ
metric = state.metric

# Reconstruct hamiltonian.
h = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)

# Make new transition.
t = transition(rng, h, κ, t_old.z)

# Adapt h and spl.
tstat = stat(t)
h, κ, isadapted = adapt!(h, κ, adaptor, i, nadapts, t.z.θ, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt=isadapted,))

# Compute next transition and state.
newstate = HMCState(i, t, h.metric, κ, adaptor)

# Return `Transition` with additional stats added.
return Transition(t.z, tstat), newstate
end


################
### Callback ###
################
"""
HMCProgressCallback

A callback to be used with AbstractMCMC.jl's interface, replicating the
logging behavior of the non-AbstractMCMC [`sample`](@ref).

# Fields
$(FIELDS)
"""
struct HMCProgressCallback{P}
"`Progress` meter from ProgressMeters.jl."
pm::P
"Specifies whether or not to use display a progress bar."
progress::Bool
"If `progress` is not specified and this is `true` some information will be logged upon completion of adaptation."
verbose::Bool
end

function HMCProgressCallback(n_samples; progress=true, verbose=false)
pm = progress ? ProgressMeter.Progress(n_samples, desc="Sampling", barlen=31) : nothing
HMCProgressCallback(pm, progress, verbose)
end

function (cb::HMCProgressCallback)(
rng, model, spl, t, state, i;
nadapts = 0,
kwargs...
)
progress = cb.progress
verbose = cb.verbose
pm = cb.pm

metric = state.metric
adaptor = state.adaptor
κ = state.κ
tstat = t.stat
isadapted = tstat.is_adapt

# Update progress meter
if progress
# Do include current iteration and mass matrix
pm_next!(
pm,
(iterations=i, tstat..., mass_matrix=metric)
)
# Report finish of adapation
elseif verbose && isadapted && i == nadapts
@info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric
end
end
4 changes: 2 additions & 2 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat}, V<:DualValue}
@warn "The current proposal will be rejected due to numerical error(s)." isfinite.((θ, r, ℓπ, ℓκ))
# NOTE eltype has to be inlined to avoid type stability issue; see #267
ℓπ = DualValue(
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value),
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value),
ℓπ.gradient
)
ℓκ = DualValue(
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓκ.value),
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓκ.value),
ℓκ.gradient
)
end
Expand Down
2 changes: 0 additions & 2 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ end
##
## Interface functions
##

function sample_init(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
h::Hamiltonian,
Expand Down Expand Up @@ -143,7 +142,6 @@ sample(
verbose::Bool=true,
progress::Bool=false
)

Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`.
- The randomness is controlled by `rng`.
- If `rng` is not provided, `GLOBAL_RNG` will be used.
Expand Down
2 changes: 1 addition & 1 deletion src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ nsteps(τ::Trajectory{TS, I, TC}) where {TS, I, TC<:FixedIntegrationTime} =
## Kernel interface
##

struct HMCKernel{R, T<:Trajectory} <: AbstractMCMCKernel
struct HMCKernel{R, T<:Trajectory} <: AbstractMCMCKernel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double check this is just the removal of a space?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. But also, now AbstractMCMCKernel is AbstractMCMC.AbstractSampler (I'll remove this alias, and make it explicit).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any side effect it might cause?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think so. Anything particular in mind or just general question?

The only thing is that if we remove AbstractMCMCKernel completely, not even aliasing, e.g. CoupledHMC.jl won't work.

Maybe best approach is to just make AbstractMCMCKernel <: AbstractMCMC.AbstractSampler?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine as long as CoupledHMC is fine (with sutff like mixture kernels still working).

refreshment::R
τ::T
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Expand Down
Loading