Skip to content

Commit

Permalink
Resolve: Add progress bar to run_inference_algorithm (#614)
Browse files Browse the repository at this point in the history
* Add progress bar to run_inference_algorithm

* Add progress_bar as kwarg defaulting to False

* Fix bug in run_inference_algorithm and add test

---------

Co-authored-by: Paul Scemama <pscemama@mitre.org>
  • Loading branch information
PaulScemama and pscemama-mitre authored Dec 9, 2023
1 parent 540db41 commit 08e0d75
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
14 changes: 11 additions & 3 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jax.tree_util import tree_leaves

from blackjax.base import Info, State
from blackjax.progress_bar import progress_bar_scan
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


Expand Down Expand Up @@ -144,6 +145,7 @@ def run_inference_algorithm(
initial_state_or_position,
inference_algorithm,
num_steps,
progress_bar: bool = False,
) -> tuple[State, State, Info]:
"""Wrapper to run an inference algorithm.
Expand Down Expand Up @@ -171,14 +173,20 @@ def run_inference_algorithm(
except TypeError:
# We assume initial_state is already in the right format.
initial_state = initial_state_or_position
initial_state = initial_state_or_position

keys = split(rng_key, num_steps)

@jit
def one_step(state, rng_key):
def _one_step(state, xs):
_, rng_key = xs
state, info = inference_algorithm.step(rng_key, state)
return state, (state, info)

final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys)
if progress_bar:
one_step = progress_bar_scan(num_steps)(_one_step)
else:
one_step = _one_step

xs = (jnp.arange(num_steps), keys)
final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs)
return final_state, state_history, info_history
50 changes: 50 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import chex
import jax
import jax.numpy as jnp
from absl.testing import absltest, parameterized

from blackjax.mcmc.hmc import hmc
from blackjax.util import run_inference_algorithm


class RunInferenceAlgorithmTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.key(42)
self.algorithm = hmc(
logdensity_fn=self.logdensity_fn,
inverse_mass_matrix=jnp.eye(2),
step_size=1.0,
num_integration_steps=1000,
)
self.num_steps = 10

def check_compatible(self, initial_state_or_position, progress_bar):
"""
Runs 10 steps with `run_inference_algorithm` starting with
`initial_state_or_position` and potentially a progress bar.
"""
_ = run_inference_algorithm(
self.key,
initial_state_or_position,
self.algorithm,
self.num_steps,
progress_bar,
)

@parameterized.parameters([True, False])
def test_compatible_with_initial_pos(self, progress_bar):
self.check_compatible(jnp.array([1.0, 1.0]), progress_bar)

@parameterized.parameters([True, False])
def test_compatible_with_initial_state(self, progress_bar):
state = self.algorithm.init(jnp.array([1.0, 1.0]))
self.check_compatible(state, progress_bar)

@staticmethod
def logdensity_fn(x):
return -0.5 * jnp.sum(jnp.square(x))


if __name__ == "__main__":
absltest.main()

0 comments on commit 08e0d75

Please sign in to comment.