Skip to content

Commit

Permalink
Merge pull request #12 from roualdes/feature/11-julia-cleanup
Browse files Browse the repository at this point in the history
[Julia] Interface Improvements
  • Loading branch information
roualdes authored Sep 13, 2022
2 parents 8e2c286 + 5cac08d commit b663d2c
Show file tree
Hide file tree
Showing 9 changed files with 834 additions and 269 deletions.
17 changes: 9 additions & 8 deletions julia/MCMC.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct HMCDiag
model::Bridgestan.StanModel
model::BridgeStan.StanModel
stepsize::Float64
steps::Int64
metric::Vector{Float64}
Expand All @@ -11,22 +11,23 @@ function HMCDiag(model, stepsize, steps)
model,
stepsize,
steps,
ones(Bridgestan.param_unc_num(model)),
randn(Bridgestan.param_unc_num(model)))
ones(BridgeStan.param_unc_num(model)),
randn(BridgeStan.param_unc_num(model)),
)
end

function joint_logp(hmc::HMCDiag, theta, rho)
logp, _ = Bridgestan.log_density_gradient(hmc.model, theta)
logp, _ = BridgeStan.log_density_gradient(hmc.model, theta)
return logp - 0.5 * rho' * (hmc.metric .* rho)
end

function leapfrog(hmc::HMCDiag, theta, rho)
e = hmc.stepsize .* hmc.metric
lp, grad = Bridgestan.log_density_gradient(hmc.model, theta)
lp, grad = BridgeStan.log_density_gradient(hmc.model, theta)
rho_p = rho + 0.5 * hmc.stepsize .* grad
for n in 1:hmc.steps
for n = 1:hmc.steps
theta .+= e .* rho_p
lp, grad = Bridgestan.log_density_gradient(hmc.model, theta)
lp, grad = BridgeStan.log_density_gradient(hmc.model, theta)
if n != hmc.steps
rho_p .+= e .* grad
end
Expand All @@ -36,7 +37,7 @@ function leapfrog(hmc::HMCDiag, theta, rho)
end

function sample(hmc::HMCDiag)
rho = randn(Bridgestan.param_unc_num(model))
rho = randn(BridgeStan.param_unc_num(model))
logp = joint_logp(hmc, hmc.theta, rho)
theta_prop, rho_prop = leapfrog(hmc, hmc.theta, rho)
logp_prop = joint_logp(hmc, theta_prop, rho_prop)
Expand Down
2 changes: 1 addition & 1 deletion julia/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name = "Bridgestan"
name = "BridgeStan"
uuid = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
authors = ["Brian Ward <bward@flatironinstitute.org>", "Bob Carpenter <bcarpenter@flatironinstitute.org", "Edward Roualdes <eroualdes@csuchico.edu>"]
version = "0.1.0"
22 changes: 12 additions & 10 deletions julia/example.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
using Bridgestan
using BridgeStan

const BS = BridgeStan

# Bernoulli
# CMDSTAN=/path/to/cmdstan/ make stan/bernoulli/bernoulli

bernoulli_lib = joinpath(@__DIR__, "../stan/bernoulli/bernoulli_model.so")
bernoulli_data = joinpath(@__DIR__, "../stan/bernoulli/bernoulli.data.json")

smb = Bridgestan.StanModel(bernoulli_lib, bernoulli_data);
x = rand(Bridgestan.param_unc_num(smb));
smb = BS.StanModel(bernoulli_lib, bernoulli_data);
x = rand(BS.param_unc_num(smb));
q = @. log(x / (1 - x)); # unconstrained scale

lp, grad = Bridgestan.log_density_gradient(smb, q, jacobian = 0)
lp, grad = BS.log_density_gradient(smb, q, jacobian = 0)

println()
println("log_density and gradient of Bernoulli model:")
Expand All @@ -25,10 +27,10 @@ println()
multi_lib = joinpath(@__DIR__, "../stan/multi/multi_model.so")
multi_data = joinpath(@__DIR__, "../stan/multi/multi.data.json")

smm = Bridgestan.StanModel(multi_lib, multi_data)
x = randn(Bridgestan.param_unc_num(smm));
smm = BS.StanModel(multi_lib, multi_data)
x = randn(BS.param_unc_num(smm));

lp, grad = Bridgestan.log_density_gradient(smm, x)
lp, grad = BS.log_density_gradient(smm, x)

println("log_density and gradient of Multivariate Gaussian model:")
println((lp, grad))
Expand All @@ -39,15 +41,15 @@ println()
include("./MCMC.jl")
using Statistics

model = Bridgestan.StanModel(multi_lib, multi_data);
model = BS.StanModel(multi_lib, multi_data);

stepsize = 0.25
steps = 10
hmcd = HMCDiag(model, stepsize, steps);

M = 10_000
theta = zeros(M, Bridgestan.param_unc_num(model))
for m in 1:M
theta = zeros(M, BS.param_unc_num(model))
for m = 1:M
theta[m, :] .= sample(hmcd)
end

Expand Down
Loading

0 comments on commit b663d2c

Please sign in to comment.