Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of logdensity computation #228

Merged
merged 7 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@ struct NodeInfo{F}
loop_vars::NamedTuple
end

"""
BUGSGraph

The `BUGSGraph` object represents the graph structure for a BUGS model. It is a type alias for
`MetaGraphsNext.MetaGraph`.
"""
const BUGSGraph = MetaGraph
const BUGSGraph = MetaGraph{
Int,Graphs.SimpleDiGraph{Int},<:VarName,<:NodeInfo,Nothing,Nothing,<:Any,Float64
}

"""
find_generated_vars(g::BUGSGraph)
Expand Down
78 changes: 40 additions & 38 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,18 @@ Return a vector of `VarName` containing the names of all the variables in the mo
"""
variables(m::BUGSModel) = collect(labels(m.g))

function prepare_arg_values(
args::Tuple{Vararg{Symbol}}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars}
) where {lvars}
return NamedTuple{args}(Tuple(
map(args) do arg
if arg in lvars
loop_vars[arg]
else
AbstractPPL.get(evaluation_env, @varname($arg))
end
end,
))
@generated function prepare_arg_values(
::Val{args}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars}
) where {args,lvars}
fields = []
for arg in args
if arg in lvars
push!(fields, :(loop_vars[$(QuoteNode(arg))]))
else
push!(fields, :(evaluation_env[$(QuoteNode(arg))]))
end
end
return :(NamedTuple{$(args)}(($(fields...),)))
end

function BUGSModel(
Expand All @@ -99,7 +99,7 @@ function BUGSModel(

for vn in sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = Base.invokelatest(node_function; args...)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
Expand Down Expand Up @@ -179,7 +179,7 @@ function initialize!(model::BUGSModel, initial_params::NamedTuple)
check_input(initial_params)
for vn in model.sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = model.g[vn]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
if !is_stochastic
value = Base.invokelatest(node_function; args...)
BangBang.@set!! model.evaluation_env = setindex!!(
Expand Down Expand Up @@ -243,7 +243,7 @@ function getparams(model::BUGSModel)
end
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
dist = node_function(; args...)
transformed_value = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
Expand All @@ -267,7 +267,7 @@ function getparams_as_ordereddict(model::BUGSModel)
d[v] = AbstractPPL.get(model.evaluation_env, v)
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
dist = node_function(; args...)
d[v] = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
Expand Down Expand Up @@ -321,7 +321,20 @@ function AbstractPPL.condition(
)
end

return BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env)
g = copy(model.g)
for vn in sorted_blanket_with_vars
if vn in new_parameters
continue
end
ni = g[vn]
if ni.is_stochastic && !ni.is_observed
ni = @set ni.is_observed = true
g[vn] = ni
end
end

new_model = BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env)
return BangBang.setproperty!!(new_model, :g, g)
end

function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
Expand Down Expand Up @@ -387,7 +400,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand All @@ -410,7 +423,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand All @@ -436,51 +449,40 @@ end
function AbstractPPL.evaluate!!(
model::BUGSModel, ::LogDensityContext, flattened_values::AbstractVector
)
param_lengths = if model.transformed
model.transformed_param_length
else
model.untransformed_param_length
end

if length(flattened_values) != param_lengths
error(
"The length of `flattened_values` does not match the length of the parameters in the model",
)
end

var_lengths = if model.transformed
model.transformed_var_lengths
else
model.untransformed_var_lengths
end

sorted_nodes = model.sorted_nodes
g = model.g
evaluation_env = deepcopy(model.evaluation_env)
current_idx = 1
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
for vn in model.sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
if vn in model.parameters
if !is_observed
l = var_lengths[vn]
if model.transformed
b = Bijectors.bijector(dist)
b_inv = Bijectors.inverse(b)
reconstructed_value = reconstruct(
b_inv, dist, flattened_values[current_idx:(current_idx + l - 1)]
b_inv,
dist,
view(flattened_values, current_idx:(current_idx + l - 1)),
)
value, logjac = Bijectors.with_logabsdet_jacobian(
b_inv, reconstructed_value
)
else
value = reconstruct(
dist, flattened_values[current_idx:(current_idx + l - 1)]
dist, view(flattened_values, current_idx:(current_idx + l - 1))
)
logjac = 0.0
end
Expand Down
3 changes: 2 additions & 1 deletion test/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ decond_model = AbstractPPL.decondition(cond_model, [a, l])
c_value = 4.0
mb_logp = begin
logp = 0
logp += logpdf(dnorm(1.0, c_value), 1.0) # a
f = 2.0 - 1.0
logp += logpdf(dnorm(f, c_value), 1.0) # a
logp += logpdf(dnorm(0.0, 1.0), 2.0) # b
logp += logpdf(dnorm(0.0, 1.0), -2.0) # l
logp += logpdf(dnorm(-2.0, 1.0), c_value) # c
Expand Down
Loading