From a650f9b7d25abcdd710f6b422fa638e4e964c907 Mon Sep 17 00:00:00 2001
From: Alberto Cabezas
Date: Mon, 17 Apr 2023 10:23:06 +0100
Subject: [PATCH] Refactor MALA so that it uses the MH component in
 proposals.py (#523)

* refactor MALA so that it uses the MH component in proposals.py

* Update blackjax/mcmc/mala.py

Co-authored-by: Junpeng Lao

---------

Co-authored-by: Junpeng Lao
---
 blackjax/mcmc/mala.py     | 34 +++++++++++++++-------------------
 blackjax/mcmc/proposal.py | 44 ++++++++++++++++++++----------------------
 2 files changed, 37 insertions(+), 41 deletions(-)

diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py
index 21cdf6656..8ca7521fb 100644
--- a/blackjax/mcmc/mala.py
+++ b/blackjax/mcmc/mala.py
@@ -19,6 +19,7 @@
 import jax.numpy as jnp
 
 import blackjax.mcmc.diffusions as diffusions
+import blackjax.mcmc.proposal as proposal
 from blackjax.types import PRNGKey, PyTree
 
 __all__ = ["MALAState", "MALAInfo", "init", "kernel"]
@@ -74,8 +75,8 @@ def kernel():
 
     """
 
-    def transition_probability(state, new_state, step_size):
-        """Transition probability to go from `state` to `new_state`"""
+    def transition_energy(state, new_state, step_size):
+        """Transition energy to go from `state` to `new_state`"""
         theta = jax.tree_util.tree_map(
             lambda new_x, x, g: new_x - x - step_size * g,
             new_state.position,
@@ -85,7 +86,12 @@ def transition_probability(state, new_state, step_size):
         theta_dot = jax.tree_util.tree_reduce(
             operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta)
         )
-        return -0.25 * (1.0 / step_size) * theta_dot
+        return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot
+
+    init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
+        transition_energy, divergence_threshold=jnp.inf
+    )
+    sample_proposal = proposal.static_binomial_sampling
 
     def one_step(
         rng_key: PRNGKey, state: MALAState, logdensity_fn: Callable, step_size: float
@@ -97,26 +103,16 @@ def one_step(
         key_integrator, key_rmh = jax.random.split(rng_key)
 
         new_state = integrator(key_integrator, state, step_size)
+        new_state = MALAState(*new_state)
 
-        delta = (
-            new_state.logdensity
-            - state.logdensity
-            + transition_probability(new_state, state, step_size)
-            - transition_probability(state, new_state, step_size)
+        proposal = init_proposal(state)
+        new_proposal, _ = generate_proposal(state, new_state, step_size=step_size)
+        sampled_proposal, do_accept, p_accept = sample_proposal(
+            key_rmh, proposal, new_proposal
         )
-        delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta)
-        p_accept = jnp.clip(jnp.exp(delta), a_max=1)
-
-        do_accept = jax.random.bernoulli(key_rmh, p_accept)
-        new_state = MALAState(*new_state)
 
         info = MALAInfo(p_accept, do_accept)
 
-        return jax.lax.cond(
-            do_accept,
-            lambda _: (new_state, info),
-            lambda _: (state, info),
-            operand=None,
-        )
+        return sampled_proposal.state, info
 
     return one_step
diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py
index 8231e4d21..3549def79 100644
--- a/blackjax/mcmc/proposal.py
+++ b/blackjax/mcmc/proposal.py
@@ -15,7 +15,6 @@
 
 import jax
 import jax.numpy as jnp
-import numpy as np
 
 TrajectoryState = NamedTuple
 
@@ -49,18 +48,18 @@ def proposal_generator(
     Parameters
     ----------
     energy
-        A callable that computes the energy associated to a given state
+        A function that computes the energy associated to a given state
     divergence_threshold
-        max value allowed for the difference in energies not to be considered a divergence
+        max value allowed for the difference in energies not to be considered a divergence
 
     Returns
     -------
-    Two callables, to generate an initial proposal when no step has been taken,
-    and to generate proposals after each step.
+    Two functions, one to generate an initial proposal when no step has been taken,
+    another to generate proposals after each step.
 
     """
     def new(state: TrajectoryState) -> Proposal:
-        return Proposal(state, energy(state), 0.0, -np.inf)
+        return Proposal(state, energy(state), 0.0, -jnp.inf)
 
     def update(initial_energy: float, state: TrajectoryState) -> Tuple[Proposal, bool]:
         """Generate a new proposal from a trajectory state.
@@ -103,13 +102,13 @@ def proposal_from_energy_diff(
     Parameters
     ----------
     initial_energy
-        the energy from the previous state
+        the energy from the initial state
     new_energy
-        the energy at the new state
+        the energy at the proposed state
     divergence_threshold
-        max value allowed for the difference in energies not to be considered a divergence
+        max value allowed for the difference in energies not to be considered a divergence
     state
-        the state to propose
+        the proposed state
 
     Returns
     -------
@@ -139,7 +138,7 @@ def proposal_from_energy_diff(
 def asymmetric_proposal_generator(
     transition_energy_fn: Callable,
     divergence_threshold: float,
-    proposal_factory=proposal_from_energy_diff,
+    proposal_factory: Callable = proposal_from_energy_diff,
 ) -> Tuple[Callable, Callable]:
     """A proposal generator that takes into account the transition between
     two states to compute a new proposal. In particular, both states are
@@ -147,28 +146,29 @@ def asymmetric_proposal_generator(
     used to compute the energies to consider in weighting the proposal,
     to account for asymmetries.
     ----------
     transition_energy_fn
-        A Callable that computes the energy of a associated with a transition
-        from one state to another
+        A function that computes the energy of a transition from an initial state
+        to a new state, given some optional keyword arguments.
     divergence_threshold
-        A max number to will be used by the proposal_factory to flag a Proposal
-        as a divergence.
+        The maximum value allowed for the difference in energies not to be considered a divergence.
     proposal_factory
-        A callable that builds a proposal from the transitions energies
+        A function that builds a proposal from the transition energies.
 
     Returns
     -------
-    Two callables, to generate an initial proposal when no step has been taken,
-    and to generate proposals after each step.
+    Two functions, one to generate an initial proposal when no step has been taken,
+    another to generate proposals after each step.
     """
     def new(state: TrajectoryState) -> Proposal:
-        return Proposal(state, 0.0, 0.0, -np.inf)
+        return Proposal(state, 0.0, 0.0, -jnp.inf)
 
     def update(
-        initial_state: TrajectoryState, state: TrajectoryState
+        initial_state: TrajectoryState,
+        state: TrajectoryState,
+        **energy_params,
     ) -> Tuple[Proposal, bool]:
-        new_energy = transition_energy_fn(initial_state, state)
-        prev_energy = transition_energy_fn(state, initial_state)
+        new_energy = transition_energy_fn(initial_state, state, **energy_params)
+        prev_energy = transition_energy_fn(state, initial_state, **energy_params)
         return proposal_factory(prev_energy, new_energy, divergence_threshold, state)
 
     return new, update
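
For reference, a minimal usage sketch of the refactored kernel (not part of the patch): `mala.kernel` and the `one_step(rng_key, state, logdensity_fn, step_size)` signature are taken from the patched module, `mala.init`'s `(position, logdensity_fn)` signature is assumed from the same module, and the Gaussian log-density, PRNG seed, and step size are illustrative choices only.

import jax
import jax.numpy as jnp

import blackjax.mcmc.mala as mala


def logdensity_fn(x):
    # Illustrative target: standard normal log-density.
    return -0.5 * jnp.sum(x**2)


rng_key = jax.random.PRNGKey(0)
state = mala.init(jnp.zeros(2), logdensity_fn)
one_step = mala.kernel()

# A single MALA transition; after this patch the accept/reject step runs through
# proposal.asymmetric_proposal_generator and proposal.static_binomial_sampling
# instead of a hand-rolled Metropolis-Hastings correction inside mala.py.
new_state, info = one_step(rng_key, state, logdensity_fn, step_size=1e-2)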