Skip to content

Commit

Permalink
Add WithinGibbs sampler implementations (#131)
Browse files Browse the repository at this point in the history
This PR:
* Introduces `condition` and `decondition` interface similar to those of
`DynamicPPL`, addressing
#132
* Implements With-in-gibbs sampler that allow user to specify the
variable grouping and sampler(currently support AdvancedHMC, AdvancedMH,
and `JuliaBUGS.MHFromPrior`
  • Loading branch information
sunxd3 authored Dec 6, 2023
1 parent 49d38d4 commit 443f83c
Show file tree
Hide file tree
Showing 14 changed files with 496 additions and 206 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.2.5"
version = "0.3.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand All @@ -26,7 +27,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Expand All @@ -36,11 +36,11 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[extensions]
JuliaBUGSAdvancedHMCExt = ["AbstractMCMC", "AdvancedHMC", "MCMCChains"]
JuliaBUGSAdvancedMHExt = ["AbstractMCMC", "AdvancedMH", "MCMCChains"]
JuliaBUGSAdvancedHMCExt = ["AdvancedHMC", "MCMCChains"]
JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"]
JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"]
JuliaBUGSGraphPlotExt = ["GraphPlot"]
JuliaBUGSMCMCChainsExt = ["AbstractMCMC", "MCMCChains"]
JuliaBUGSMCMCChainsExt = ["MCMCChains"]
JuliaBUGSTikzGraphsExt = ["TikzGraphs"]

[compat]
Expand Down
38 changes: 30 additions & 8 deletions ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
module JuliaBUGSAdvancedHMCExt

# The main purpose of this extension is to add `generated_quantities` to the final chains.
# So directly calling the AdvancedHMCMCMCChainsExt is not feasible.

using AbstractMCMC
using AdvancedHMC
using AdvancedHMC: Transition, stat
using JuliaBUGS
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS:
AbstractBUGSModel,
BUGSModel,
Gibbs,
find_generated_vars,
LogDensityContext,
evaluate!!,
_eval
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.DynamicPPL
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.DynamicPPL
using AbstractMCMC
using JuliaBUGS.Bijectors
using JuliaBUGS.Random
using MCMCChains: Chains
using AdvancedHMC
using AdvancedHMC: Transition, stat
import JuliaBUGS: gibbs_internal

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition},
Expand Down Expand Up @@ -41,4 +48,19 @@ function AbstractMCMC.bundle_samples(
)
end

function JuliaBUGS.gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC
)
t, s = AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
),
sampler;
n_adapts=0,
initial_params=JuliaBUGS.getparams(cond_model; transformed=true), # for more advanced usage, probably save the state or transition
)
return JuliaBUGS.setparams!!(cond_model, t.z.θ; transformed=true)
end

end
26 changes: 22 additions & 4 deletions ext/JuliaBUGSAdvancedMHExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
module JuliaBUGSAdvancedMHExt

using AbstractMCMC
using AdvancedMH
using JuliaBUGS
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: BUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.DynamicPPL
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Random
using JuliaBUGS.Bijectors
using JuliaBUGS.UnPack
using JuliaBUGS.DynamicPPL
using AbstractMCMC
using AdvancedMH
using MCMCChains: Chains
import JuliaBUGS: gibbs_internal

function AbstractMCMC.bundle_samples(
ts::Vector{<:AdvancedMH.AbstractTransition},
Expand All @@ -35,4 +38,19 @@ function AbstractMCMC.bundle_samples(
)
end

function JuliaBUGS.gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::AdvancedMH.MHSampler
)
t, s = AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
),
sampler;
n_adapts=0,
initial_params=JuliaBUGS.getparams(cond_model; transformed=true),
)
return JuliaBUGS.setparams!!(cond_model, t.params; transformed=true)
end

end
4 changes: 3 additions & 1 deletion ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ function JuliaBUGS.gen_chains(
convert(
Vector{Real},
vcat(
flattened_param_vals[i], flattened_generated_quantities[i], stats_values[i]
flattened_param_vals[i],
flattened_generated_quantities[i],
isempty(stats_values) ? [] : stats_values[i],
),
) for i in axes(samples)[1]
]
Expand Down
2 changes: 2 additions & 0 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module JuliaBUGS

using AbstractMCMC
using AbstractPPL
using BangBang
using Bijectors
Expand Down Expand Up @@ -36,6 +37,7 @@ include("compiler_pass.jl")
include("graphs.jl")
include("model.jl")
include("logdensityproblems.jl")
include("gibbs.jl")

include("BUGSExamples/BUGSExamples.jl")

Expand Down
96 changes: 96 additions & 0 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
struct Gibbs <: AbstractMCMC.AbstractSampler
sampler_map::Dict{<:Any,<:AbstractMCMC.AbstractSampler}
end

function Gibbs(model, s::AbstractMCMC.AbstractSampler)
return Gibbs(Dict([v => s for v in model.parameters]))
end

struct MHFromPrior <: AbstractMCMC.AbstractSampler end

abstract type AbstractGibbsState end

struct GibbsState <: AbstractGibbsState
varinfo::SimpleVarInfo
conditioning_schedule::Dict
sorted_nodes_cache::Dict
end

ensure_vector(x) = x isa Union{Number,VarName} ? [x] : x

function AbstractMCMC.step(
rng::Random.AbstractRNG,
l_model::AbstractMCMC.LogDensityModel{BUGSModel},
sampler::Gibbs;
model=l_model.logdensity,
kwargs...,
)
vi = deepcopy(model.varinfo)
sorted_nodes_cache = Dict{Any,Any}()

conditioning_schedule = Dict()
for vs in keys(sampler.sampler_map)
vs_complement = setdiff(model.parameters, ensure_vector(vs))
conditioning_schedule[vs_complement] = sampler.sampler_map[vs]
end

for vs in keys(conditioning_schedule)
cond_model = AbstractPPL.condition(model, vs)
sorted_nodes_cache[vs] = ensure_vector(cond_model.sorted_nodes)
end

return getparams(model, vi; transformed=model.transformed),
GibbsState(vi, conditioning_schedule, sorted_nodes_cache)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
l_model::AbstractMCMC.LogDensityModel{BUGSModel},
sampler::Gibbs,
state::AbstractGibbsState;
model=l_model.logdensity,
kwargs...,
)
vi = state.varinfo
for vs in keys(state.conditioning_schedule)
cond_model = AbstractPPL.condition(model, vs, vi, state.sorted_nodes_cache[vs])
vi = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs])
end
return getparams(model, vi; transformed=model.transformed),
GibbsState(vi, state.conditioning_schedule, state.sorted_nodes_cache)
end

function gibbs_internal end

function gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::MHFromPrior
)
transformed_original = getparams(cond_model, cond_model.varinfo; transformed=true)
transformed_proposal = getparams(
cond_model, evaluate!!(cond_model, SamplingContext())[1]; transformed=true
)

vi_proposed, logp_proposed = evaluate!!(
cond_model, LogDensityContext(), transformed_proposal
)
vi, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original)

if logp_proposed - logp > log(rand(rng))
vi = vi_proposed
end
return vi
end

function AbstractMCMC.bundle_samples(
ts,
logdensitymodel::AbstractMCMC.LogDensityModel{JuliaBUGS.BUGSModel},
sampler::Gibbs,
state,
::Type{T};
discard_initial=0,
kwargs...,
) where {T}
return JuliaBUGS.gen_chains(
logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs...
)
end
27 changes: 17 additions & 10 deletions src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,22 +231,29 @@ its children's other parents (reference: https://en.wikipedia.org/wiki/Markov_bl
In the case of vector, the Markov Blanket is the union of the Markov Blankets of each variable
minus the variables themselves (reference: Liu, X.-Q., & Liu, X.-S. (2018). Markov Blanket and Markov
Boundary of Multiple Variables. Journal of Machine Learning Research, 19(43), 1–50.)
In the case of M-H acceptance ratio evaluation, only the logps of the children are needed, because the logp of the parents
and co-parents are not changed (their values are still needed to compute the distributions).
"""
function markov_blanket(g::BUGSGraph, v::VarName)
parents = stochastic_inneighbors(g, v)
children = stochastic_outneighbors(g, v)
co_parents = VarName[]
for p in children
co_parents = vcat(co_parents, stochastic_inneighbors(g, p))
function markov_blanket(g::BUGSGraph, v::VarName; children_only=false)
if !children_only
parents = stochastic_inneighbors(g, v)
children = stochastic_outneighbors(g, v)
co_parents = VarName[]
for p in children
co_parents = vcat(co_parents, stochastic_inneighbors(g, p))
end
blanket = unique(vcat(parents, children, co_parents...))
return [x for x in blanket if x != v]
else
return stochastic_outneighbors(g, v)
end
blanket = unique(vcat(parents, children, co_parents...))
return [x for x in blanket if x != v]
end

function markov_blanket(g::BUGSGraph, v)
function markov_blanket(g::BUGSGraph, v; children_only=false)
blanket = VarName[]
for vn in v
blanket = vcat(blanket, markov_blanket(g, vn))
blanket = vcat(blanket, markov_blanket(g, vn; children_only=children_only))
end
return [x for x in unique(blanket) if x v]
end
Expand Down
8 changes: 0 additions & 8 deletions src/logdensityproblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@ function LogDensityProblems.dimension(model::BUGSModel)
end
end

function LogDensityProblems.dimension(model::MarkovBlanketCoveredBUGSModel)
return if model.transformed
model.mb_transformed_param_length
else
model.mb_untransformed_param_length
end
end

function LogDensityProblems.capabilities(::AbstractBUGSModel)
return LogDensityProblems.LogDensityOrder{0}()
end
Loading

0 comments on commit 443f83c

Please sign in to comment.