Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WithinGibbs sampler implementations #131

Merged
merged 33 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
defe463
Update the markov blanket model
sunxd3 Nov 6, 2023
3f70389
update some tests
sunxd3 Nov 6, 2023
36e0db0
add option to return children only for mb
sunxd3 Nov 8, 2023
1171525
some dev scratchpad updates
sunxd3 Nov 8, 2023
cf39a6a
Remove dev codes to local space
sunxd3 Nov 10, 2023
056c50d
Add first draft for Gibbs
sunxd3 Nov 10, 2023
a7eb45c
Some formatting
sunxd3 Nov 10, 2023
2186462
MHwithinGIbbs works, but very slow
sunxd3 Nov 11, 2023
1f7a336
Formatting
sunxd3 Nov 11, 2023
bc9196b
Allow empty stats for `gen_chain`
sunxd3 Nov 11, 2023
1992fc5
Remove TODO item
sunxd3 Nov 11, 2023
5b3ed55
Adjust for performance improve
sunxd3 Nov 12, 2023
9fda137
Add test for simple gibbs implementation
sunxd3 Nov 13, 2023
b11cace
Add HMC draft
sunxd3 Nov 14, 2023
7acbcb1
Remove unnecessary files and
sunxd3 Nov 23, 2023
65bd0b8
Formatting and fix tests
sunxd3 Nov 23, 2023
627703f
Some refactor
sunxd3 Nov 23, 2023
a5a4891
small improvement
sunxd3 Nov 23, 2023
6079f23
Merge branch 'master' into sunxd/simple_gibbs
sunxd3 Nov 23, 2023
78c3895
Minor refactoring
sunxd3 Nov 23, 2023
a8b3964
formatting
sunxd3 Nov 23, 2023
54188f5
adjust test numerics
sunxd3 Nov 24, 2023
2a25358
Use StableRNGs
sunxd3 Nov 24, 2023
5b6fe86
Add some documentation
sunxd3 Nov 24, 2023
3b2470e
use larger atol
sunxd3 Nov 24, 2023
9b2405c
Formatting
sunxd3 Nov 24, 2023
ecedf09
Fix tests
sunxd3 Nov 24, 2023
55d2278
Remove StableRNGs, fix bugs in `getparams`
sunxd3 Nov 24, 2023
3bbbfbe
Larger atol; change `WithinGibbs` to `Gibbs`
sunxd3 Nov 28, 2023
c509737
Reverse name change in SymbolicExt; formatting
sunxd3 Nov 28, 2023
2939eaa
Bump minor version
sunxd3 Nov 28, 2023
5c02e84
Use rtol instead of atol
sunxd3 Nov 29, 2023
120d014
more generous rtol
sunxd3 Dec 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.2.5"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand All @@ -26,7 +27,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Expand All @@ -36,11 +36,11 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[extensions]
JuliaBUGSAdvancedHMCExt = ["AbstractMCMC", "AdvancedHMC", "MCMCChains"]
JuliaBUGSAdvancedMHExt = ["AbstractMCMC", "AdvancedMH", "MCMCChains"]
JuliaBUGSAdvancedHMCExt = ["AdvancedHMC", "MCMCChains"]
JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"]
JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"]
JuliaBUGSGraphPlotExt = ["GraphPlot"]
JuliaBUGSMCMCChainsExt = ["AbstractMCMC", "MCMCChains"]
JuliaBUGSMCMCChainsExt = ["MCMCChains"]
JuliaBUGSTikzGraphsExt = ["TikzGraphs"]

[compat]
Expand Down
4 changes: 2 additions & 2 deletions archive/SymbolicExt/src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ end

abstract type GibbsSampler <: AbstractMCMC.AbstractSampler end

abstract type MHWithinGibbs <: GibbsSampler end
abstract type MHGibbs <: GibbsSampler end

struct ProposeFromPrior <: MHWithinGibbs end
struct ProposeFromPrior <: MHGibbs end

function AbstractMCMC.step(
rng::Random.AbstractRNG, model::GraphModel, sampler::GibbsSampler; kwargs...
Expand Down
38 changes: 30 additions & 8 deletions ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
module JuliaBUGSAdvancedHMCExt

# The main purpose of this extension is to add `generated_quantities` to the final chains.
# So directly calling the AdvancedHMCMCMCChainsExt is not feasible.

using AbstractMCMC
using AdvancedHMC
using AdvancedHMC: Transition, stat
using JuliaBUGS
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS:
AbstractBUGSModel,
BUGSModel,
Gibbs,
find_generated_vars,
LogDensityContext,
evaluate!!,
_eval
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.DynamicPPL
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.DynamicPPL
using AbstractMCMC
using JuliaBUGS.Bijectors
using JuliaBUGS.Random
using MCMCChains: Chains
using AdvancedHMC
using AdvancedHMC: Transition, stat
import JuliaBUGS: gibbs_internal

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition},
Expand Down Expand Up @@ -41,4 +48,19 @@ function AbstractMCMC.bundle_samples(
)
end

function JuliaBUGS.gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC
)
t, s = AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
),
sampler;
n_adapts=0,
initial_params=JuliaBUGS.getparams(cond_model; transformed=true), # for more advanced usage, probably save the state or transition
)
return JuliaBUGS.setparams!!(cond_model, t.z.θ; transformed=true)
end

end
26 changes: 22 additions & 4 deletions ext/JuliaBUGSAdvancedMHExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
module JuliaBUGSAdvancedMHExt

using AbstractMCMC
using AdvancedMH
using JuliaBUGS
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: BUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.DynamicPPL
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Random
using JuliaBUGS.Bijectors
using JuliaBUGS.UnPack
using JuliaBUGS.DynamicPPL
using AbstractMCMC
using AdvancedMH
using MCMCChains: Chains
import JuliaBUGS: gibbs_internal

function AbstractMCMC.bundle_samples(
ts::Vector{<:AdvancedMH.AbstractTransition},
Expand All @@ -35,4 +38,19 @@
)
end

function JuliaBUGS.gibbs_internal(

Check warning on line 41 in ext/JuliaBUGSAdvancedMHExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuliaBUGSAdvancedMHExt.jl#L41

Added line #L41 was not covered by tests
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::AdvancedMH.MHSampler
)
t, s = AbstractMCMC.step(

Check warning on line 44 in ext/JuliaBUGSAdvancedMHExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuliaBUGSAdvancedMHExt.jl#L44

Added line #L44 was not covered by tests
rng,
AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
),
sampler;
n_adapts=0,
initial_params=JuliaBUGS.getparams(cond_model; transformed=true),
)
return JuliaBUGS.setparams!!(cond_model, t.params; transformed=true)

Check warning on line 53 in ext/JuliaBUGSAdvancedMHExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuliaBUGSAdvancedMHExt.jl#L53

Added line #L53 was not covered by tests
end

end
4 changes: 3 additions & 1 deletion ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ function JuliaBUGS.gen_chains(
convert(
Vector{Real},
vcat(
flattened_param_vals[i], flattened_generated_quantities[i], stats_values[i]
flattened_param_vals[i],
flattened_generated_quantities[i],
isempty(stats_values) ? [] : stats_values[i],
),
) for i in axes(samples)[1]
]
Expand Down
2 changes: 2 additions & 0 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module JuliaBUGS

using AbstractMCMC
using AbstractPPL
using BangBang
using Bijectors
Expand Down Expand Up @@ -36,6 +37,7 @@ include("compiler_pass.jl")
include("graphs.jl")
include("model.jl")
include("logdensityproblems.jl")
include("gibbs.jl")

include("BUGSExamples/BUGSExamples.jl")

Expand Down
96 changes: 96 additions & 0 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
struct Gibbs <: AbstractMCMC.AbstractSampler
sampler_map::Dict{<:Any,<:AbstractMCMC.AbstractSampler}
end

function Gibbs(model, s::AbstractMCMC.AbstractSampler)
return Gibbs(Dict([v => s for v in model.parameters]))
end

struct MHFromPrior <: AbstractMCMC.AbstractSampler end

abstract type AbstractGibbsState end

struct GibbsState <: AbstractGibbsState
varinfo::SimpleVarInfo
conditioning_schedule::Dict
sorted_nodes_cache::Dict
end

ensure_vector(x) = x isa Union{Number,VarName} ? [x] : x

function AbstractMCMC.step(
rng::Random.AbstractRNG,
l_model::AbstractMCMC.LogDensityModel{BUGSModel},
sampler::Gibbs;
model=l_model.logdensity,
kwargs...,
)
vi = deepcopy(model.varinfo)
sorted_nodes_cache = Dict{Any,Any}()

conditioning_schedule = Dict()
for vs in keys(sampler.sampler_map)
vs_complement = setdiff(model.parameters, ensure_vector(vs))
conditioning_schedule[vs_complement] = sampler.sampler_map[vs]
end

for vs in keys(conditioning_schedule)
cond_model = AbstractPPL.condition(model, vs)
sorted_nodes_cache[vs] = ensure_vector(cond_model.sorted_nodes)
end

return getparams(model, vi; transformed=model.transformed),
GibbsState(vi, conditioning_schedule, sorted_nodes_cache)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
l_model::AbstractMCMC.LogDensityModel{BUGSModel},
sampler::Gibbs,
state::AbstractGibbsState;
model=l_model.logdensity,
kwargs...,
)
vi = state.varinfo
for vs in keys(state.conditioning_schedule)
cond_model = AbstractPPL.condition(model, vs, vi, state.sorted_nodes_cache[vs])
vi = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs])
end
return getparams(model, vi; transformed=model.transformed),
GibbsState(vi, state.conditioning_schedule, state.sorted_nodes_cache)
end

function gibbs_internal end

function gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::MHFromPrior
)
transformed_original = getparams(cond_model, cond_model.varinfo; transformed=true)
transformed_proposal = getparams(
cond_model, evaluate!!(cond_model, SamplingContext())[1]; transformed=true
)

vi_proposed, logp_proposed = evaluate!!(
cond_model, LogDensityContext(), transformed_proposal
)
vi, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original)

if logp_proposed - logp > log(rand(rng))
vi = vi_proposed
end
return vi
end

function AbstractMCMC.bundle_samples(
ts,
logdensitymodel::AbstractMCMC.LogDensityModel{JuliaBUGS.BUGSModel},
sampler::Gibbs,
state,
::Type{T};
discard_initial=0,
kwargs...,
) where {T}
return JuliaBUGS.gen_chains(
logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs...
)
end
27 changes: 17 additions & 10 deletions src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,22 +231,29 @@
In the case of vector, the Markov Blanket is the union of the Markov Blankets of each variable
minus the variables themselves (reference: Liu, X.-Q., & Liu, X.-S. (2018). Markov Blanket and Markov
Boundary of Multiple Variables. Journal of Machine Learning Research, 19(43), 1–50.)

In the case of M-H acceptance ratio evaluation, only the logps of the children are needed, because the logp of the parents
and co-parents are not changed (their values are still needed to compute the distributions).
"""
function markov_blanket(g::BUGSGraph, v::VarName)
parents = stochastic_inneighbors(g, v)
children = stochastic_outneighbors(g, v)
co_parents = VarName[]
for p in children
co_parents = vcat(co_parents, stochastic_inneighbors(g, p))
function markov_blanket(g::BUGSGraph, v::VarName; children_only=false)
if !children_only
parents = stochastic_inneighbors(g, v)
children = stochastic_outneighbors(g, v)
co_parents = VarName[]
for p in children
co_parents = vcat(co_parents, stochastic_inneighbors(g, p))
end
blanket = unique(vcat(parents, children, co_parents...))
return [x for x in blanket if x != v]
else
return stochastic_outneighbors(g, v)

Check warning on line 249 in src/graphs.jl

View check run for this annotation

Codecov / codecov/patch

src/graphs.jl#L249

Added line #L249 was not covered by tests
end
blanket = unique(vcat(parents, children, co_parents...))
return [x for x in blanket if x != v]
end

function markov_blanket(g::BUGSGraph, v)
function markov_blanket(g::BUGSGraph, v; children_only=false)
blanket = VarName[]
for vn in v
blanket = vcat(blanket, markov_blanket(g, vn))
blanket = vcat(blanket, markov_blanket(g, vn; children_only=children_only))
end
return [x for x in unique(blanket) if x ∉ v]
end
Expand Down
8 changes: 0 additions & 8 deletions src/logdensityproblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@ function LogDensityProblems.dimension(model::BUGSModel)
end
end

function LogDensityProblems.dimension(model::MarkovBlanketCoveredBUGSModel)
return if model.transformed
model.mb_transformed_param_length
else
model.mb_untransformed_param_length
end
end

function LogDensityProblems.capabilities(::AbstractBUGSModel)
return LogDensityProblems.LogDensityOrder{0}()
end
Loading
Loading