Skip to content

Commit

Permalink
Make MALA work with any PyTree
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jun 7, 2022
1 parent 07c345a commit 597657e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
11 changes: 9 additions & 2 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.mcmc.diffusion import overdamped_langevin
from blackjax.types import PRNGKey, PyTree
Expand Down Expand Up @@ -62,8 +63,14 @@ def kernel():

def transition_probability(state, new_state, step_size):
"""Transition probability to go from `state` to `new_state`"""
theta = new_state.position - state.position - step_size * state.logprob_grad
return -0.25 * (1.0 / step_size) * jnp.dot(theta, theta)
theta = jax.tree_util.tree_map(
lambda new_x, x, g: new_x - x - step_size * g,
new_state.position,
state.position,
state.logprob_grad,
)
theta_ravel, _ = ravel_pytree(theta)
return -0.25 * (1.0 / step_size) * jnp.dot(theta_ravel, theta_ravel)

def one_step(
rng_key: PRNGKey, state: MALAState, logprob_fn: Callable, step_size: float
Expand Down
23 changes: 23 additions & 0 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,29 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)

def test_mala(self):
"""Test the MALA kernel."""
rng_key, init_key0, init_key1 = jax.random.split(self.key, 3)
x_data = jax.random.normal(init_key0, shape=(1000, 1))
y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)

logposterior_fn_ = functools.partial(
self.regression_logprob, x=x_data, preds=y_data
)
logposterior_fn = lambda x: logposterior_fn_(**x)

warmup_key, inference_key = jax.random.split(rng_key, 2)

mala = blackjax.mala(logposterior_fn, 1e-5)
state = mala.init({"coefs": 1.0, "scale": 2.0})
states = inference_loop(mala.step, 10_000, inference_key, state)

coefs_samples = states.position["coefs"][3000:]
scale_samples = states.position["scale"][3000:]

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)

@parameterized.parameters(regresion_test_cases)
def test_pathfinder_adaptation(
self,
Expand Down

0 comments on commit 597657e

Please sign in to comment.