Skip to content

Commit

Permalink
NaN Handling (#727)
Browse files Browse the repository at this point in the history
* bug fix; first part

* bug fix; first part

* further debug

* remove print statements

* handle logdensity nans. mask -> 1 - mask.
  • Loading branch information
reubenharry committed Aug 26, 2024
1 parent b02b60b commit 8a9b546
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
from blackjax.util import incremental_value_update, pytree_size
from blackjax.util import generate_unit_vector, incremental_value_update, pytree_size


class MCLMCAdaptationState(NamedTuple):
Expand Down Expand Up @@ -147,6 +147,8 @@ def predictor(previous_state, params, adaptive_state, rng_key):

time, x_average, step_size_max = adaptive_state

rng_key, nan_key = jax.random.split(rng_key)

# dynamics
next_state, info = kernel(params.sqrt_diag_cov)(
rng_key=rng_key,
Expand All @@ -162,6 +164,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
params.step_size,
step_size_max,
info.energy_change,
nan_key,
)

# Warning: var = 0 if there were nans, but we will give it a very small weight
Expand Down Expand Up @@ -203,7 +206,7 @@ def step(iteration_state, weight_and_key):
streaming_avg = incremental_value_update(
expectation=jnp.array([x, jnp.square(x)]),
incremental_val=streaming_avg,
weight=(1 - mask) * success * params.step_size,
weight=mask * success * params.step_size,
)

return (state, params, adaptive_state, streaming_avg), None
Expand Down Expand Up @@ -233,7 +236,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
)

# we use the last num_steps2 to compute the diagonal preconditioner
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

# run the steps
state, params, _, (_, average) = run_steps(
Expand Down Expand Up @@ -298,7 +301,9 @@ def step(state, key):
return adaptation_L


def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
def handle_nans(
previous_state, next_state, step_size, step_size_max, kinetic_change, key
):
"""if there are nans, let's reduce the stepsize, and not update the state. The
function returns the old state in this case."""

Expand All @@ -311,4 +316,13 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch
(next_state, step_size_max, kinetic_change),
(previous_state, step_size * reduced_step_size, 0.0),
)

state = jax.lax.cond(
jnp.isnan(next_state.logdensity),
lambda: state._replace(
momentum=generate_unit_vector(key, previous_state.position)
),
lambda: state,
)

return nonans, state, step_size, kinetic_change

0 comments on commit 8a9b546

Please sign in to comment.