Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WithinGibbs sampler implementations #131

Merged
merged 33 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
defe463
Update the markov blanket model
sunxd3 Nov 6, 2023
3f70389
update some tests
sunxd3 Nov 6, 2023
36e0db0
add option to return children only for mb
sunxd3 Nov 8, 2023
1171525
some dev scratchpad updates
sunxd3 Nov 8, 2023
cf39a6a
Remove dev codes to local space
sunxd3 Nov 10, 2023
056c50d
Add first draft for Gibbs
sunxd3 Nov 10, 2023
a7eb45c
Some formatting
sunxd3 Nov 10, 2023
2186462
MHwithinGIbbs works, but very slow
sunxd3 Nov 11, 2023
1f7a336
Formatting
sunxd3 Nov 11, 2023
bc9196b
Allow empty stats for `gen_chain`
sunxd3 Nov 11, 2023
1992fc5
Remove TODO item
sunxd3 Nov 11, 2023
5b3ed55
Adjust for performance improve
sunxd3 Nov 12, 2023
9fda137
Add test for simple gibbs implementation
sunxd3 Nov 13, 2023
b11cace
Add HMC draft
sunxd3 Nov 14, 2023
7acbcb1
Remove unnecessary files and
sunxd3 Nov 23, 2023
65bd0b8
Formatting and fix tests
sunxd3 Nov 23, 2023
627703f
Some refactor
sunxd3 Nov 23, 2023
a5a4891
small improvement
sunxd3 Nov 23, 2023
6079f23
Merge branch 'master' into sunxd/simple_gibbs
sunxd3 Nov 23, 2023
78c3895
Minor refactoring
sunxd3 Nov 23, 2023
a8b3964
formatting
sunxd3 Nov 23, 2023
54188f5
adjust test numerics
sunxd3 Nov 24, 2023
2a25358
Use StableRNGs
sunxd3 Nov 24, 2023
5b6fe86
Add some documentation
sunxd3 Nov 24, 2023
3b2470e
use larger atol
sunxd3 Nov 24, 2023
9b2405c
Formatting
sunxd3 Nov 24, 2023
ecedf09
Fix tests
sunxd3 Nov 24, 2023
55d2278
Remove StableRNGs, fix bugs in `getparams`
sunxd3 Nov 24, 2023
3bbbfbe
Larger atol; change `WithinGibbs` to `Gibbs`
sunxd3 Nov 28, 2023
c509737
Reverse name change in SymbolicExt; formatting
sunxd3 Nov 28, 2023
2939eaa
Bump minor version
sunxd3 Nov 28, 2023
5c02e84
Use rtol instead of atol
sunxd3 Nov 29, 2023
120d014
more generous rtol
sunxd3 Dec 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.2.5"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand All @@ -26,7 +27,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Expand All @@ -36,11 +36,11 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[extensions]
JuliaBUGSAdvancedHMCExt = ["AbstractMCMC", "AdvancedHMC", "MCMCChains"]
JuliaBUGSAdvancedMHExt = ["AbstractMCMC", "AdvancedMH", "MCMCChains"]
JuliaBUGSAdvancedHMCExt = ["AdvancedHMC", "MCMCChains"]
JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"]
JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"]
JuliaBUGSGraphPlotExt = ["GraphPlot"]
JuliaBUGSMCMCChainsExt = ["AbstractMCMC", "MCMCChains"]
JuliaBUGSMCMCChainsExt = ["MCMCChains"]
JuliaBUGSTikzGraphsExt = ["TikzGraphs"]

[compat]
Expand Down
27 changes: 19 additions & 8 deletions ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,25 @@ function JuliaBUGS.gen_chains(
collect(Iterators.flatten(gq)) for gq in generated_quantities
]

vals = [
convert(
Vector{Real},
vcat(
flattened_param_vals[i], flattened_generated_quantities[i], stats_values[i]
),
) for i in axes(samples)[1]
]
if !isempty(stats_values)
vals = [
convert(
Vector{Real},
vcat(
flattened_param_vals[i], flattened_generated_quantities[i], stats_values[i]
sunxd3 marked this conversation as resolved.
Show resolved Hide resolved
),
) for i in axes(samples)[1]
]
else
vals = [
convert(
Vector{Real},
vcat(
flattened_param_vals[i], flattened_generated_quantities[i]
),
sunxd3 marked this conversation as resolved.
Show resolved Hide resolved
) for i in axes(samples)[1]
]
end

@assert length(vals[1]) ==
length(param_name_leaves) +
Expand Down
3 changes: 3 additions & 0 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module JuliaBUGS

using AbstractMCMC
using AbstractPPL
using BangBang
using Bijectors
Expand Down Expand Up @@ -33,7 +34,9 @@ include("variable_types.jl")
include("compiler_pass.jl")
include("graphs.jl")
include("model.jl")
include("markov_blanket_model.jl")
include("logdensityproblems.jl")
include("gibbs.jl")

include("BUGSExamples/BUGSExamples.jl")

Expand Down
101 changes: 101 additions & 0 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
struct WithinGibbs{S} <: AbstractMCMC.AbstractSampler
sampler_map::S # map from a group of variables to the sampler
end

struct MHFromPrior end

struct MHState
varinfo
markov_blanket_cache
sorted_nodes_cache
end
# TODO: need to cache the markov blankets to avoid recomputing them

ensure_vector(x) = x isa Union{Number,VarName} ? [x] : x

function AbstractMCMC.step(
rng::Random.AbstractRNG,
l_model::AbstractMCMC.LogDensityModel{BUGSModel},
sampler::WithinGibbs{MHFromPrior};
model=l_model.logdensity,
kwargs...,
)
vi = deepcopy(model.varinfo)
markov_blanket_cache = Dict{Any,Any}()
sorted_nodes_cache = Dict{Any,Any}()
for v in model.parameters
mb_model = JuliaBUGS.MarkovBlanketBUGSModel(model, v)
markov_blanket_cache[v] = ensure_vector(mb_model.members)
sorted_nodes_cache[v] = ensure_vector(mb_model.sorted_nodes)
end

init_state = MHState(vi, markov_blanket_cache, sorted_nodes_cache)

vi = gibbs_steps(rng, model, sampler, init_state)
return getparams(model, vi), MHState(vi, markov_blanket_cache, sorted_nodes_cache)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
l_model::AbstractMCMC.LogDensityModel{BUGSModel},
sampler::WithinGibbs{MHFromPrior},
state::MHState;
model=l_model.logdensity,
kwargs...,
)
vi = state.varinfo
vi = gibbs_steps(rng, model, sampler, state)
return getparams(model, vi),
MHState(vi, state.markov_blanket_cache, state.sorted_nodes_cache)
end

function gibbs_steps(
rng::Random.AbstractRNG,
model::BUGSModel,
::WithinGibbs{MHFromPrior},
state,
var_iterator=model.parameters,
)
g = model.g
vi = state.varinfo
for v in var_iterator
ni = g[v]
args = Dict(getsym(arg) => vi[arg] for arg in ni.node_args)
dist = _eval(ni.node_function_expr.args[2], args)

transformed_original = ensure_vector(Bijectors.link(dist, vi[v]))
transformed_proposal = ensure_vector(Bijectors.link(dist, rand(rng, dist)))

mb_model = JuliaBUGS.MarkovBlanketBUGSModel(
vi,
ensure_vector(v),
state.markov_blanket_cache[v],
state.sorted_nodes_cache[v],
model,
)
_, logp = evaluate!!(mb_model, LogDensityContext(), transformed_original)
vi_proposed, logp_proposed = evaluate!!(
mb_model, LogDensityContext(), transformed_proposal
)

logr = logp_proposed - logp
if logr > log(rand(rng))
vi = vi_proposed
end
end
return vi
end

function AbstractMCMC.bundle_samples(
ts,
logdensitymodel::AbstractMCMC.LogDensityModel{JuliaBUGS.BUGSModel},
sampler::WithinGibbs{MHFromPrior},
state,
::Type{T};
discard_initial=0,
kwargs...,
) where {T}
return JuliaBUGS.gen_chains(
logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs...
)
end
27 changes: 17 additions & 10 deletions src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,22 +231,29 @@ its children's other parents (reference: https://en.wikipedia.org/wiki/Markov_bl
In the case of vector, the Markov Blanket is the union of the Markov Blankets of each variable
minus the variables themselves (reference: Liu, X.-Q., & Liu, X.-S. (2018). Markov Blanket and Markov
Boundary of Multiple Variables. Journal of Machine Learning Research, 19(43), 1–50.)

In the case of M-H acceptance ratio evaluation, only the logps of the children are needed, because the logp of the parents
and co-parents are not changed (their values are still needed to compute the distributions).
"""
function markov_blanket(g::BUGSGraph, v::VarName)
parents = stochastic_inneighbors(g, v)
children = stochastic_outneighbors(g, v)
co_parents = VarName[]
for p in children
co_parents = vcat(co_parents, stochastic_inneighbors(g, p))
function markov_blanket(g::BUGSGraph, v::VarName; children_only=false)
if !children_only
parents = stochastic_inneighbors(g, v)
children = stochastic_outneighbors(g, v)
co_parents = VarName[]
for p in children
co_parents = vcat(co_parents, stochastic_inneighbors(g, p))
end
blanket = unique(vcat(parents, children, co_parents...))
return [x for x in blanket if x != v]
else
return stochastic_outneighbors(g, v)
end
blanket = unique(vcat(parents, children, co_parents...))
return [x for x in blanket if x != v]
end

function markov_blanket(g::BUGSGraph, v)
function markov_blanket(g::BUGSGraph, v; children_only=false)
blanket = VarName[]
for vn in v
blanket = vcat(blanket, markov_blanket(g, vn))
blanket = vcat(blanket, markov_blanket(g, vn; children_only=children_only))
end
return [x for x in unique(blanket) if x ∉ v]
end
Expand Down
8 changes: 0 additions & 8 deletions src/logdensityproblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@ function LogDensityProblems.dimension(model::BUGSModel)
end
end

function LogDensityProblems.dimension(model::MarkovBlanketCoveredBUGSModel)
return if model.transformed
model.mb_transformed_param_length
else
model.mb_untransformed_param_length
end
end

function LogDensityProblems.capabilities(::AbstractBUGSModel)
return LogDensityProblems.LogDensityOrder{0}()
end
101 changes: 101 additions & 0 deletions src/markov_blanket_model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
struct MarkovBlanketBUGSModel <: AbstractBUGSModel
varinfo # parent_model.varinfo serves as prototype instead of concrete state, we'll save the varinfo in gibbs state, and this is to be passed to LogDensityProblems
target_vars::Vector{VarName}
members::Vector{VarName}
sorted_nodes::Vector{VarName}
parent_model::BUGSModel
end

function MarkovBlanketBUGSModel(
m::BUGSModel, var_group::Union{VarName,Vector{VarName}}, varinfo=m.varinfo
)
var_group = var_group isa VarName ? [var_group] : var_group

# check inputs
non_vars = VarName[]
logical_vars = VarName[]
for var in var_group
if var ∉ labels(m.g)
push!(non_vars, var)
elseif m.g[var].node_type == Logical
push!(logical_vars, var)
end
end
isempty(non_vars) || error("Variables $(non_vars) are not in the model")
isempty(logical_vars) ||
warn("Variables $(logical_vars) are not stochastic variables, they will be ignored")

blanket = markov_blanket(m.g, var_group)
blanket_with_vars = union(blanket, var_group)
sorted_blanket_with_vars = VarName[]
for vn in m.sorted_nodes # keep the order of the original model
if vn in blanket_with_vars
push!(sorted_blanket_with_vars, vn)
end
end
return MarkovBlanketBUGSModel(varinfo, var_group, blanket, sorted_blanket_with_vars, m)
end

# need a function that compute the logp of the target_vars

function AbstractPPL.evaluate!!(
model::MarkovBlanketBUGSModel, ::LogDensityContext, flattened_values::AbstractVector
)
transformed = model.parent_model.transformed
var_lengths = if transformed
model.parent_model.transformed_var_lengths
else
model.parent_model.untransformed_var_lengths
end
param_length = sum(var_lengths[v] for v in model.target_vars)
sorted_nodes = model.sorted_nodes
@assert length(flattened_values) == param_length
g = model.parent_model.g
vi = deepcopy(model.parent_model.varinfo)
current_idx = 1
logp = 0.0
for vn in sorted_nodes
ni = g[vn]
@unpack node_type, node_function_expr, node_args = ni
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = _eval(expr, args)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
if vn in model.target_vars
if transformed
l = var_lengths[vn]
value_transformed = flattened_values[current_idx:(current_idx + l - 1)]
current_idx += l
value, logjac = DynamicPPL.with_logabsdet_jacobian_and_reconstruct(
Bijectors.inverse(bijector(dist)), dist, value_transformed
)
logp += logpdf(dist, value) + logjac
vi = setindex!!(vi, value, vn)
else
l = var_lengths[vn]
value = DynamicPPL.reconstruct(
dist, flattened_values[current_idx:(current_idx + l - 1)]
)
current_idx += l
logp += logpdf(dist, value)
vi = setindex!!(vi, value, vn)
end
else
logp += logpdf(dist, vi[vn])
end
end
end
return vi, logp
end

function LogDensityProblems.dimension(model::MarkovBlanketBUGSModel)
length_dict = if model.parent_model.transformed
model.parent_model.transformed_var_lengths
else
model.parent_model.untransformed_var_lengths
end
return sum(length_dict[v] for v in model.target_vars)
end
Loading