Skip to content

Commit

Permalink
Merge pull request #2 from Sahel13/mcmc_bs
Browse files Browse the repository at this point in the history
MCMC-based backward sampling
  • Loading branch information
Sahel13 authored Sep 10, 2024
2 parents 2369728 + ac1ccbe commit 6acd2ce
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 23 deletions.
6 changes: 4 additions & 2 deletions experiments/pendulum/linear/io_csmc_sysid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ nb_csmc_moves = 1

nb_iter = 25
opt_state = Flux.setup(Flux.Optimise.Adam(1e-3), learner_loop)
batch_size = 64
batch_size = 16

Flux.reset!(learner_loop.ctl)
state_struct, param_struct = smc_with_ibis_marginal_dynamics(
Expand All @@ -124,6 +124,8 @@ reference = IBISReference(
param_struct.log_likelihoods[:, :, idx]
)

backward_sample = true

learner_loop, _ = markovian_score_climbing_with_ibis_marginal_dynamics(
nb_iter,
opt_state,
Expand All @@ -140,7 +142,7 @@ learner_loop, _ = markovian_score_climbing_with_ibis_marginal_dynamics(
tempering,
reference,
nb_csmc_moves,
true,
backward_sample,
param_proposal,
nb_ibis_moves,
true
Expand Down
126 changes: 105 additions & 21 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,28 @@ function design_logpdfs(
state_struct::StateStruct,
future_trajectory::Matrix{Float64},
closed_loop::IBISClosedLoop,
time_idx::Int
time_idx::Int,
indices::Union{Nothing, Vector{Int}} = nothing
)
"""
Computes the logpdfs of all future designs given the current trajectories.
"""
if indices === nothing
indices = 1:state_struct.nb_trajectories
end

# Construct the trajectories.
trajectories = Array{Float64}(undef, state_struct.state_dim, state_struct.nb_steps + 1, state_struct.nb_trajectories)
for traj_idx = 1:state_struct.nb_trajectories
trajectories = Array{Float64}(undef, state_struct.state_dim, state_struct.nb_steps + 1, length(indices))
for (i, traj_idx) in enumerate(indices)
# Get $z_{0:t}^n$ by tracing genealogies.
trajectories[:, 1:time_idx, traj_idx] = genealogy_tracer(
trajectories[:, 1:time_idx, i] = genealogy_tracer(
state_struct.unresampled_trajectories,
state_struct.resampled_idx,
time_idx,
traj_idx
)
# $\bar{z}_{t+1:T}$ is the same for all trajectories.
trajectories[:, time_idx + 1:end, traj_idx] = future_trajectory
trajectories[:, time_idx + 1:end, i] = future_trajectory
end

# Feed the states $z_{0:t-1}^n$ to the policy.
Expand All @@ -70,7 +75,7 @@ function design_logpdfs(
# Compute the densities $\pi(\bar{\xi}_s \mid z_{0:t}^n, \bar{z}_{t+1:s-1})$
# for all $s = t+1, \dots, T$.
xdim = closed_loop.dyn.xdim
logpdfs = zeros(state_struct.nb_trajectories)
logpdfs = zeros(length(indices))
for t = time_idx:state_struct.nb_steps
zs = trajectories[:, t, :]
us = trajectories[xdim+1:end, t+1, :]
Expand Down Expand Up @@ -126,15 +131,16 @@ function theta_transition_density(
end


function backward_sampling_single(
function backward_sampling(
state_struct::StateStruct,
param_struct::IBISParamStruct,
closed_loop::IBISClosedLoop,
eta::Float64,
idx::Union{Int, Nothing} = nothing
)
"""
Backward sampling function that returns a single trajectory.
Standard backward sampling from Godsill et al. (2004).
Returns a single trajectory.
"""
trajectory = Matrix{Float64}(undef, state_struct.state_dim, state_struct.nb_steps + 1)
traj_indices = Vector{Int}(undef, state_struct.nb_steps + 1)
Expand Down Expand Up @@ -190,20 +196,94 @@ function backward_sampling_single(
end


function backward_sampling(
function backward_sampling_mcmc(
state_struct::StateStruct,
param_struct::IBISParamStruct,
closed_loop::IBISClosedLoop,
eta::Float64,
num_trajs::Int = 16
idx::Union{Int, Nothing} = nothing
)
"""
MCMC-based backward sampling from Bunch and Godsill (2013).
This is the recommended smoothing algorithm by Dau and Chopin (2023).
Returns a single trajectory.
Reference for implementation:
https://github.com/nchopin/particles/blob/841cf363b3f1dee0faa77f6a0349ace3477917ab/particles/smoothing.py#L313
"""
trajectory = Matrix{Float64}(undef, state_struct.state_dim, state_struct.nb_steps + 1)
traj_indices = Vector{Int}(undef, state_struct.nb_steps + 1)

# Sample a particle at the final time step.
if idx === nothing
idx = rand(Categorical(state_struct.weights[:, end]))
end
traj_indices[end] = idx
trajectory[:, end] = state_struct.unresampled_trajectories[:, end, idx]

# Work our way backwards.
for t = state_struct.nb_steps:-1:1
ancestor_idx = state_struct.resampled_idx[idx, t]
proposed_idx = rand(Categorical(state_struct.weights[:, t]))
indices = [ancestor_idx, proposed_idx]

# Compute the probability of all future designs.
design_densities = design_logpdfs(
state_struct,
trajectory[:, t + 1:end],
closed_loop,
t,
indices
)
# Compute the potential function and the state transition density.
dynamics = closed_loop.dyn
zs = state_struct.unresampled_trajectories[:, t, indices]
zn = trajectory[:, t + 1]
x_prob_and_pot = map(indices, eachcol(zs)) do anc, z
ps = param_struct.raw_particles[:, t, :, anc]
state_transition_and_potential(
z, zn, ps, dynamics, view(param_struct.scratch, :, :, anc), eta
)
end
## Compute the transition probability of the theta particles.
theta_prob = map(indices, eachcol(zs)) do anc, z
theta_transition_density(
param_struct.raw_particles[:, t, :, anc],
param_struct.raw_particles[:, t + 1, :, idx],
dynamics,
z,
zn,
param_struct.nb_particles,
view(param_struct.scratch, :, :, anc)
)
end

reweighting_ratio = x_prob_and_pot .+ design_densities .+ theta_prob
lpr_acc = reweighting_ratio[2] - reweighting_ratio[1]
lu = log(rand())
idx = lpr_acc > lu ? proposed_idx : ancestor_idx

traj_indices[t] = idx
trajectory[:, t] = state_struct.unresampled_trajectories[:, t, idx]
end
return trajectory, traj_indices
end


function backward_sampling_batched(
state_struct::StateStruct,
param_struct::IBISParamStruct,
closed_loop::IBISClosedLoop,
eta::Float64,
num_trajs::Int
)
"""
Backward sampling function that returns `num_trajs` trajectories.
Current default is to use MCMC-based backward sampling.
"""
indices = rand(Categorical(state_struct.weights[:, end]), num_trajs)
trajectories = Array{Float64}(undef, state_struct.state_dim, state_struct.nb_steps + 1, num_trajs)
for traj_idx = 1:num_trajs
trajectories[:, :, traj_idx], _ = backward_sampling_single(
Threads.@threads for traj_idx = 1:num_trajs
trajectories[:, :, traj_idx], _ = backward_sampling_mcmc(
state_struct,
param_struct,
closed_loop,
Expand Down Expand Up @@ -235,7 +315,10 @@ function markovian_score_climbing_with_ibis_marginal_dynamics(
backward_sample::Bool = false,
param_proposal::Union{T, Nothing} = nothing,
nb_ibis_moves::Union{Int, Nothing} = nothing,
verbose::Bool = false
verbose::Bool = false,
nb_trajectories_eval::Int = 256,
nb_particles_eval::Int = 256,
nb_backward_samples::Int = 32
) where {T<:Function}

if !backward_sample
Expand All @@ -251,8 +334,8 @@ function markovian_score_climbing_with_ibis_marginal_dynamics(
Flux.reset!(evaluator.ctl)
state_struct, param_struct = smc_with_ibis_marginal_dynamics(
nb_steps,
nb_trajectories,
nb_particles,
nb_trajectories_eval,
nb_particles_eval,
init_state,
evaluator,
param_prior,
Expand Down Expand Up @@ -293,7 +376,7 @@ function markovian_score_climbing_with_ibis_marginal_dynamics(
!backward_sample
)
if backward_sample
trajectory, traj_indices = backward_sampling_single(
trajectory, traj_indices = backward_sampling_mcmc(
state_struct,
param_struct,
learner,
Expand Down Expand Up @@ -326,15 +409,16 @@ function markovian_score_climbing_with_ibis_marginal_dynamics(
end

if backward_sample
trajectories = backward_sampling(
trajectories = backward_sampling_batched(
state_struct,
param_struct,
learner,
tempering
tempering,
nb_backward_samples
)
samples = trajectories
else
idx = rand(Categorical(state_struct.weights[:, end]), nb_trajectories)
idx = rand(Categorical(state_struct.weights[:, end]), nb_backward_samples)
samples = state_struct.trajectories[:, :, idx]
end

Expand All @@ -357,8 +441,8 @@ function markovian_score_climbing_with_ibis_marginal_dynamics(
Flux.reset!(evaluator.ctl)
state_struct, _ = smc_with_ibis_marginal_dynamics(
nb_steps,
nb_trajectories,
nb_particles,
nb_trajectories_eval,
nb_particles_eval,
init_state,
evaluator,
param_prior,
Expand Down

0 comments on commit 6acd2ce

Please sign in to comment.