-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adopt AbstractMCMC.jl interface #259
Merged
Merged
Changes from all commits
Commits
Show all changes
42 commits
Select commit
Hold shift + click to select a range
f36a3cc
initial work on adopting the AbstractMCMC interface
torfjelde 16a233c
removed old sample method since is now redundant
torfjelde a1494ac
AbstractMCMCKernel is now a subtype of AbstractSampler
torfjelde 5f01c5d
fixed impl and added back old one
torfjelde fc9e747
overload AbstractMCMC.sample to make interface a bit nicer
torfjelde ba0d623
reverted renaming of abstract kernel type
torfjelde b6ad39c
replicate logging behavior of current AHMC when using AbstractMCMC
torfjelde 66b37b5
initial work on tests, but require update to testing deps
torfjelde ea83745
Merge branch 'master' into tor/abstractmcmc
torfjelde f041ff6
moved abstractmcmc interface into separate file
torfjelde dad71a8
version bump
torfjelde 07be7ac
only bump patch version to see if tests succeed
torfjelde ae16796
Merge branch 'master' into tor/abstractmcmc
torfjelde d3e8542
Merge branch 'master' into tor/abstractmcmc
torfjelde 486e9d1
move away from using extras in Project.toml
torfjelde 0d74502
Merge branch 'tor/improved-testing' into tor/abstractmcmc
torfjelde 608442b
added integration tests for Turing.jl
torfjelde cb77c9d
removed usage of Turing.jl and MCMCDebugging.jl in main testsuite
torfjelde 75e2404
fixed bug in deprecated HMCDA constructor
torfjelde 68884d7
allow specification of which testing suites to run
torfjelde bfc8f4d
added Turing.jl integration tests to CI
torfjelde d002bfc
fixed name for integration tests
torfjelde 1ec8262
added using AdvancedHMC in runtests.jl
torfjelde f7f91e2
removed some now unnecessary usings
torfjelde 639776a
fixed a bug in the downstream testing
torfjelde e47a5a4
give integration tests a separate CI
torfjelde 678196c
forgot to remove the continue-on-error from CI
torfjelde 01ba50d
Merge branch 'tor/improved-testing' into tor/abstractmcmc
torfjelde 5aeefe6
added convenient constructor for DifferentiableDensityModel using Ham…
torfjelde 2b59913
fixed tests for AbstractMCMC interface
torfjelde d164e22
added a bunch of docstrings
torfjelde de4e33a
bumped minor version
torfjelde fa17e41
increased number of samples used in abstractmcmc tests
torfjelde 8f8e64c
remove thinning from tests
torfjelde 8d13ff5
make initial Leapfrog step size smaller
torfjelde ebaec71
Merge branch 'master' into tor/abstractmcmc
torfjelde ccdc832
mistakenly removed AbstractMCMC as a test dep in previous commit
torfjelde 5a417bc
increase adaptation to see if it helps
torfjelde 32f6ff8
ensure we drop the adaptation samples in the test
torfjelde 13a07e8
made a mistake apparently
torfjelde b82a965
think I finally fixed the tests
torfjelde 38408cf
disable progress in test
torfjelde File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,293 @@ | ||
""" | ||
HMCSampler | ||
|
||
A `AbstractMCMC.AbstractSampler` for kernels in AdvancedHMC.jl. | ||
|
||
# Fields | ||
|
||
$(FIELDS) | ||
|
||
# Notes | ||
|
||
Note that all the fields have the prefix `initial_` to indicate | ||
that these will not necessarily correspond to the `kernel`, `metric`, | ||
and `adaptor` after sampling. | ||
|
||
To access the updated fields use the resulting [`HMCState`](@ref). | ||
""" | ||
struct HMCSampler{K, M, A} <: AbstractMCMC.AbstractSampler | ||
"Initial [`AbstractMCMCKernel`](@ref)." | ||
initial_kernel::K | ||
"Initial [`AbstractMetric`](@ref)." | ||
initial_metric::M | ||
"Initial [`AbstractAdaptor`](@ref)." | ||
initial_adaptor::A | ||
end | ||
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation()) | ||
|
||
""" | ||
DifferentiableDensityModel(ℓπ, ∂ℓπ∂θ) | ||
DifferentiableDensityModel(ℓπ, m::Module) | ||
|
||
A `AbstractMCMC.AbstractMCMCModel` representing a differentiable log-density. | ||
|
||
If a module `m` is given as the second argument, then `m` is assumed to be an | ||
automatic-differentiation package and this will be used to compute the gradients. | ||
|
||
Note that the module `m` must be imported before usage, e.g. | ||
```julia | ||
using Zygote: Zygote | ||
model = DifferentiableDensityModel(ℓπ, Zygote) | ||
``` | ||
results in a `model` which will use Zygote.jl as its AD-backend. | ||
|
||
# Fields | ||
$(FIELDS) | ||
""" | ||
struct DifferentiableDensityModel{Tlogπ, T∂logπ∂θ} <: AbstractMCMC.AbstractModel | ||
"Log-density. Maps `AbstractArray` to value of the log-density." | ||
ℓπ::Tlogπ | ||
"Gradient of log-density. Returns a tuple of `ℓπ` and the gradient evaluated at the given point." | ||
∂ℓπ∂θ::T∂logπ∂θ | ||
end | ||
|
||
struct DummyMetric <: AbstractMetric end | ||
function DifferentiableDensityModel(ℓπ, m::Module) | ||
h = Hamiltonian(DummyMetric(), ℓπ, m) | ||
return DifferentiableDensityModel(h.ℓπ, h.∂ℓπ∂θ) | ||
end | ||
|
||
""" | ||
HMCState | ||
|
||
Represents the state of a [`HMCSampler`](@ref). | ||
|
||
# Fields | ||
|
||
$(FIELDS) | ||
|
||
""" | ||
struct HMCState{ | ||
TTrans<:Transition, | ||
TMetric<:AbstractMetric, | ||
TKernel<:AbstractMCMCKernel, | ||
TAdapt<:Adaptation.AbstractAdaptor | ||
} | ||
"Index of current iteration." | ||
i::Int | ||
"Current [`Transition`](@ref)." | ||
transition::TTrans | ||
"Current [`AbstractMetric`](@ref), possibly adapted." | ||
metric::TMetric | ||
"Current [`AbstractMCMCKernel`](@ref)." | ||
κ::TKernel | ||
"Current [`AbstractAdaptor`](@ref)." | ||
adaptor::TAdapt | ||
end | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
|
||
A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref). | ||
""" | ||
function AbstractMCMC.sample( | ||
model::DifferentiableDensityModel, | ||
kernel::AbstractMCMCKernel, | ||
metric::AbstractMetric, | ||
adaptor::AbstractAdaptor, | ||
N::Integer; | ||
kwargs... | ||
) | ||
return AbstractMCMC.sample(Random.GLOBAL_RNG, model, kernel, metric, adaptor, N; kwargs...) | ||
end | ||
|
||
function AbstractMCMC.sample( | ||
rng::Random.AbstractRNG, | ||
model::DifferentiableDensityModel, | ||
kernel::AbstractMCMCKernel, | ||
metric::AbstractMetric, | ||
adaptor::AbstractAdaptor, | ||
N::Integer; | ||
progress = true, | ||
verbose = false, | ||
callback = nothing, | ||
kwargs... | ||
) | ||
sampler = HMCSampler(kernel, metric, adaptor) | ||
if callback === nothing | ||
callback = HMCProgressCallback(N, progress = progress, verbose = verbose) | ||
progress = false # don't use AMCMC's progress-funtionality | ||
end | ||
|
||
return AbstractMCMC.mcmcsample( | ||
rng, model, sampler, N; | ||
progress = progress, | ||
verbose = verbose, | ||
callback = callback, | ||
kwargs... | ||
) | ||
end | ||
|
||
function AbstractMCMC.sample( | ||
model::DifferentiableDensityModel, | ||
kernel::AbstractMCMCKernel, | ||
metric::AbstractMetric, | ||
adaptor::AbstractAdaptor, | ||
parallel::AbstractMCMC.AbstractMCMCParallel, | ||
N::Integer, | ||
nchains::Integer; | ||
kwargs... | ||
) | ||
return AbstractMCMC.sample( | ||
Random.GLOBAL_RNG, model, kernel, metric, adaptor, N, nchains; | ||
kwargs... | ||
) | ||
end | ||
|
||
function AbstractMCMC.sample( | ||
rng::Random.AbstractRNG, | ||
model::DifferentiableDensityModel, | ||
kernel::AbstractMCMCKernel, | ||
metric::AbstractMetric, | ||
adaptor::AbstractAdaptor, | ||
parallel::AbstractMCMC.AbstractMCMCParallel, | ||
N::Integer, | ||
nchains::Integer; | ||
progress = true, | ||
verbose = false, | ||
callback = nothing, | ||
kwargs... | ||
) | ||
sampler = HMCSampler(kernel, metric, adaptor) | ||
if callback === nothing | ||
callback = HMCProgressCallback(N, progress = progress, verbose = verbose) | ||
progress = false # don't use AMCMC's progress-funtionality | ||
end | ||
|
||
return AbstractMCMC.mcmcsample( | ||
rng, model, sampler, parallel, N, nchains; | ||
progress = progress, | ||
verbose = verbose, | ||
callback = callback, | ||
kwargs... | ||
) | ||
end | ||
|
||
function AbstractMCMC.step( | ||
rng::AbstractRNG, | ||
model::DifferentiableDensityModel, | ||
spl::HMCSampler; | ||
init_params = nothing, | ||
kwargs... | ||
) | ||
metric = spl.initial_metric | ||
κ = spl.initial_kernel | ||
adaptor = spl.initial_adaptor | ||
|
||
if init_params === nothing | ||
init_params = randn(size(metric, 1)) | ||
end | ||
|
||
# Construct the hamiltonian using the initial metric | ||
hamiltonian = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ) | ||
|
||
# Get an initial sample. | ||
h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params) | ||
|
||
# Compute next transition and state. | ||
state = HMCState(0, t, h.metric, κ, adaptor) | ||
|
||
# Take actual first step. | ||
return AbstractMCMC.step(rng, model, spl, state; kwargs...) | ||
end | ||
|
||
function AbstractMCMC.step( | ||
rng::AbstractRNG, | ||
model::DifferentiableDensityModel, | ||
spl::HMCSampler, | ||
state::HMCState; | ||
nadapts::Int = 0, | ||
kwargs... | ||
) | ||
# Get step size | ||
@debug "current ϵ" getstepsize(spl, state) | ||
|
||
# Compute transition. | ||
i = state.i + 1 | ||
t_old = state.transition | ||
adaptor = state.adaptor | ||
κ = state.κ | ||
metric = state.metric | ||
|
||
# Reconstruct hamiltonian. | ||
h = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ) | ||
|
||
# Make new transition. | ||
t = transition(rng, h, κ, t_old.z) | ||
|
||
# Adapt h and spl. | ||
tstat = stat(t) | ||
h, κ, isadapted = adapt!(h, κ, adaptor, i, nadapts, t.z.θ, tstat.acceptance_rate) | ||
tstat = merge(tstat, (is_adapt=isadapted,)) | ||
|
||
# Compute next transition and state. | ||
newstate = HMCState(i, t, h.metric, κ, adaptor) | ||
|
||
# Return `Transition` with additional stats added. | ||
return Transition(t.z, tstat), newstate | ||
end | ||
|
||
|
||
################ | ||
### Callback ### | ||
################ | ||
""" | ||
HMCProgressCallback | ||
|
||
A callback to be used with AbstractMCMC.jl's interface, replicating the | ||
logging behavior of the non-AbstractMCMC [`sample`](@ref). | ||
|
||
# Fields | ||
$(FIELDS) | ||
""" | ||
struct HMCProgressCallback{P} | ||
"`Progress` meter from ProgressMeters.jl." | ||
pm::P | ||
"Specifies whether or not to use display a progress bar." | ||
progress::Bool | ||
"If `progress` is not specified and this is `true` some information will be logged upon completion of adaptation." | ||
verbose::Bool | ||
end | ||
|
||
function HMCProgressCallback(n_samples; progress=true, verbose=false) | ||
pm = progress ? ProgressMeter.Progress(n_samples, desc="Sampling", barlen=31) : nothing | ||
HMCProgressCallback(pm, progress, verbose) | ||
end | ||
|
||
function (cb::HMCProgressCallback)( | ||
rng, model, spl, t, state, i; | ||
nadapts = 0, | ||
kwargs... | ||
) | ||
progress = cb.progress | ||
verbose = cb.verbose | ||
pm = cb.pm | ||
|
||
metric = state.metric | ||
adaptor = state.adaptor | ||
κ = state.κ | ||
tstat = t.stat | ||
isadapted = tstat.is_adapt | ||
|
||
# Update progress meter | ||
if progress | ||
# Do include current iteration and mass matrix | ||
pm_next!( | ||
pm, | ||
(iterations=i, tstat..., mass_matrix=metric) | ||
) | ||
# Report finish of adapation | ||
elseif verbose && isadapted && i == nadapts | ||
@info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to double check this is just the removal of a space?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep. But also, now
AbstractMCMCKernel
isAbstractMCMC.AbstractSampler
(I'll remove this alias, and make it explicit).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any side effect it might cause?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't think so. Anything particular in mind or just general question?
The only thing is that if we remove
AbstractMCMCKernel
completely, not even aliasing, e.g. CoupledHMC.jl won't work.Maybe best approach is to just make
AbstractMCMCKernel <: AbstractMCMC.AbstractSampler
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine as long as CoupledHMC is fine (with sutff like mixture kernels still working).