From 597657e332ab13a4882eebdd425843f053ff8c56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 7 Jun 2022 13:54:55 +0200 Subject: [PATCH] Make MALA work with any PyTree --- blackjax/mcmc/mala.py | 11 +++++++++-- tests/test_sampling.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index e85211288..f809c2d4e 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp +from jax.flatten_util import ravel_pytree from blackjax.mcmc.diffusion import overdamped_langevin from blackjax.types import PRNGKey, PyTree @@ -62,8 +63,14 @@ def kernel(): def transition_probability(state, new_state, step_size): """Transition probability to go from `state` to `new_state`""" - theta = new_state.position - state.position - step_size * state.logprob_grad - return -0.25 * (1.0 / step_size) * jnp.dot(theta, theta) + theta = jax.tree_util.tree_map( + lambda new_x, x, g: new_x - x - step_size * g, + new_state.position, + state.position, + state.logprob_grad, + ) + theta_ravel, _ = ravel_pytree(theta) + return -0.25 * (1.0 / step_size) * jnp.dot(theta_ravel, theta_ravel) def one_step( rng_key: PRNGKey, state: MALAState, logprob_fn: Callable, step_size: float diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 556c614ba..efdb9c54a 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -106,6 +106,29 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) + def test_mala(self): + """Test the MALA kernel.""" + rng_key, init_key0, init_key1 = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logposterior_fn = lambda x: logposterior_fn_(**x) + + warmup_key, inference_key = jax.random.split(rng_key, 2) + + mala = blackjax.mala(logposterior_fn, 1e-5) + state = mala.init({"coefs": 1.0, "scale": 2.0}) + states = inference_loop(mala.step, 10_000, inference_key, state) + + coefs_samples = states.position["coefs"][3000:] + scale_samples = states.position["scale"][3000:] + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) + @parameterized.parameters(regresion_test_cases) def test_pathfinder_adaptation( self,