From ce29cc1c903384684db5b723484daa9618794372 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Fri, 25 Oct 2024 12:42:40 +0100 Subject: [PATCH] Improve performance of logdensity computation (#228) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/graphs.jl | 10 ++----- src/model.jl | 78 ++++++++++++++++++++++++++------------------------ test/graphs.jl | 3 +- 3 files changed, 45 insertions(+), 46 deletions(-) diff --git a/src/graphs.jl b/src/graphs.jl index 71d8e7cd3..05aae7fbf 100644 --- a/src/graphs.jl +++ b/src/graphs.jl @@ -7,13 +7,9 @@ struct NodeInfo{F} loop_vars::NamedTuple end -""" - BUGSGraph - -The `BUGSGraph` object represents the graph structure for a BUGS model. It is a type alias for -`MetaGraphsNext.MetaGraph`. -""" -const BUGSGraph = MetaGraph +const BUGSGraph = MetaGraph{ + Int,Graphs.SimpleDiGraph{Int},<:VarName,<:NodeInfo,Nothing,Nothing,<:Any,Float64 +} """ find_generated_vars(g::BUGSGraph) diff --git a/src/model.jl b/src/model.jl index 82e8f98c0..874a98a26 100644 --- a/src/model.jl +++ b/src/model.jl @@ -71,18 +71,18 @@ Return a vector of `VarName` containing the names of all the variables in the mo """ variables(m::BUGSModel) = collect(labels(m.g)) -function prepare_arg_values( - args::Tuple{Vararg{Symbol}}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars} -) where {lvars} - return NamedTuple{args}(Tuple( - map(args) do arg - if arg in lvars - loop_vars[arg] - else - AbstractPPL.get(evaluation_env, @varname($arg)) - end - end, - )) +@generated function prepare_arg_values( + ::Val{args}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars} +) where {args,lvars} + fields = [] + for arg in args + if arg in lvars + push!(fields, :(loop_vars[$(QuoteNode(arg))])) + else + push!(fields, :(evaluation_env[$(QuoteNode(arg))])) + end + end + return :(NamedTuple{$(args)}(($(fields...),))) end function BUGSModel( @@ -99,7 +99,7 @@ function BUGSModel( for vn in sorted_nodes (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(node_args, evaluation_env, loop_vars) + args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars) if !is_stochastic value = Base.invokelatest(node_function; args...) evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) @@ -179,7 +179,7 @@ function initialize!(model::BUGSModel, initial_params::NamedTuple) check_input(initial_params) for vn in model.sorted_nodes (; is_stochastic, is_observed, node_function, node_args, loop_vars) = model.g[vn] - args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) + args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars) if !is_stochastic value = Base.invokelatest(node_function; args...) BangBang.@set!! model.evaluation_env = setindex!!( @@ -243,7 +243,7 @@ function getparams(model::BUGSModel) end else (; node_function, node_args, loop_vars) = model.g[v] - args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) + args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars) dist = node_function(; args...) transformed_value = Bijectors.transform( Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v) @@ -267,7 +267,7 @@ function getparams_as_ordereddict(model::BUGSModel) d[v] = AbstractPPL.get(model.evaluation_env, v) else (; node_function, node_args, loop_vars) = model.g[v] - args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) + args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars) dist = node_function(; args...) d[v] = Bijectors.transform( Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v) @@ -321,7 +321,20 @@ function AbstractPPL.condition( ) end - return BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env) + g = copy(model.g) + for vn in sorted_blanket_with_vars + if vn in new_parameters + continue + end + ni = g[vn] + if ni.is_stochastic && !ni.is_observed + ni = @set ni.is_observed = true + g[vn] = ni + end + end + + new_model = BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env) + return BangBang.setproperty!!(new_model, :g, g) end function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) @@ -387,7 +400,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext) logp = 0.0 for vn in sorted_nodes (; is_stochastic, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(node_args, evaluation_env, loop_vars) + args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) evaluation_env = setindex!!(evaluation_env, value, vn) @@ -410,7 +423,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext) logp = 0.0 for vn in sorted_nodes (; is_stochastic, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(node_args, evaluation_env, loop_vars) + args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) evaluation_env = setindex!!(evaluation_env, value, vn) @@ -436,51 +449,40 @@ end function AbstractPPL.evaluate!!( model::BUGSModel, ::LogDensityContext, flattened_values::AbstractVector ) - param_lengths = if model.transformed - model.transformed_param_length - else - model.untransformed_param_length - end - - if length(flattened_values) != param_lengths - error( - "The length of `flattened_values` does not match the length of the parameters in the model", - ) - end - var_lengths = if model.transformed model.transformed_var_lengths else model.untransformed_var_lengths end - sorted_nodes = model.sorted_nodes g = model.g evaluation_env = deepcopy(model.evaluation_env) current_idx = 1 logp = 0.0 - for vn in sorted_nodes - (; is_stochastic, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(node_args, evaluation_env, loop_vars) + for vn in model.sorted_nodes + (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] + args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) else dist = node_function(; args...) - if vn in model.parameters + if !is_observed l = var_lengths[vn] if model.transformed b = Bijectors.bijector(dist) b_inv = Bijectors.inverse(b) reconstructed_value = reconstruct( - b_inv, dist, flattened_values[current_idx:(current_idx + l - 1)] + b_inv, + dist, + view(flattened_values, current_idx:(current_idx + l - 1)), ) value, logjac = Bijectors.with_logabsdet_jacobian( b_inv, reconstructed_value ) else value = reconstruct( - dist, flattened_values[current_idx:(current_idx + l - 1)] + dist, view(flattened_values, current_idx:(current_idx + l - 1)) ) logjac = 0.0 end diff --git a/test/graphs.jl b/test/graphs.jl index 124515be4..c8499752e 100644 --- a/test/graphs.jl +++ b/test/graphs.jl @@ -57,7 +57,8 @@ decond_model = AbstractPPL.decondition(cond_model, [a, l]) c_value = 4.0 mb_logp = begin logp = 0 - logp += logpdf(dnorm(1.0, c_value), 1.0) # a + f = 2.0 - 1.0 + logp += logpdf(dnorm(f, c_value), 1.0) # a logp += logpdf(dnorm(0.0, 1.0), 2.0) # b logp += logpdf(dnorm(0.0, 1.0), -2.0) # l logp += logpdf(dnorm(-2.0, 1.0), c_value) # c