Skip to content

Commit

Permalink
Fix bug in run_inference_algorithm and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulScemama committed Dec 9, 2023
1 parent f150f35 commit 9b4daf7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
1 change: 0 additions & 1 deletion blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def run_inference_algorithm(
except TypeError:
# We assume initial_state is already in the right format.
initial_state = initial_state_or_position
initial_state = initial_state_or_position

keys = split(rng_key, num_steps)

Expand Down
50 changes: 50 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import chex
import jax
import jax.numpy as jnp
from absl.testing import absltest, parameterized

from blackjax.mcmc.hmc import hmc
from blackjax.util import run_inference_algorithm


class RunInferenceAlgorithmTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.key(42)
self.algorithm = hmc(
logdensity_fn=self.logdensity_fn,
inverse_mass_matrix=jnp.eye(2),
step_size=1.0,
num_integration_steps=1000,
)
self.num_steps = 10

def check_compatible(self, initial_state_or_position, progress_bar):
"""
Runs 10 steps with `run_inference_algorithm` starting with
`initial_state_or_position` and potentially a progress bar.
"""
_ = run_inference_algorithm(
self.key,
initial_state_or_position,
self.algorithm,
self.num_steps,
progress_bar,
)

@parameterized.parameters([True, False])
def test_compatible_with_initial_pos(self, progress_bar):
self.check_compatible(jnp.array([1.0, 1.0]), progress_bar)

@parameterized.parameters([True, False])
def test_compatible_with_initial_state(self, progress_bar):
state = self.algorithm.init(jnp.array([1.0, 1.0]))
self.check_compatible(state, progress_bar)

@staticmethod
def logdensity_fn(x):
return -0.5 * jnp.sum(jnp.square(x))


if __name__ == "__main__":
absltest.main()

0 comments on commit 9b4daf7

Please sign in to comment.