Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates HMC sampler to recent TFP version #46

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax_lensing/samplers/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import jax.numpy as jnp
import jax
import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax
import tensorflow_probability as tfp; tfp = tfp.substrates.jax

def tempered_HMC(init_image,
total_score_fn,
Expand Down
18 changes: 10 additions & 8 deletions jax_lensing/samplers/score_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import jax
import jax.numpy as jnp
import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax
from tensorflow_probability.python.mcmc.internal._jax import util as mcmc_util
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
from tensorflow_probability.python.mcmc.internal import util as mcmc_util

__all__ = [
'ScoreUncalibratedHamiltonianMonteCarlo',
Expand All @@ -24,8 +24,8 @@ def __init__(self,
num_delta_logp_steps,
target_log_prob_fn=None,
state_gradients_are_stopped=False,
seed=None,
store_parameters_in_results=False,
experimental_shard_axis_names=None,
name=None):

if target_log_prob_fn is None:
Expand All @@ -46,9 +46,10 @@ def fake_logp_jvp(primals, tangents):
super().__init__(target_log_prob_fn,
step_size,
num_leapfrog_steps,
state_gradients_are_stopped,
seed,
store_parameters_in_results, name)
state_gradients_are_stopped=state_gradients_are_stopped,
name=name,
experimental_shard_axis_names=experimental_shard_axis_names,
store_parameters_in_results=store_parameters_in_results)
self._parameters['target_score_fn'] = target_score_fn
self._parameters['num_delta_logp_steps'] = num_delta_logp_steps

Expand Down Expand Up @@ -140,6 +141,7 @@ def __init__(self,
target_log_prob_fn=None,
seed=None,
store_parameters_in_results=False,
experimental_shard_axis_names=None,
name=None):
"""Initializes this transition kernel.
Args:
Expand Down Expand Up @@ -194,7 +196,7 @@ def __init__(self,
name=name or 'hmc_kernel',
store_parameters_in_results=store_parameters_in_results,
**uhmc_kwargs),
**mh_kwargs)
**mh_kwargs).experimental_with_shard_axes(experimental_shard_axis_names)
self._parameters = self._impl.inner_kernel.parameters.copy()
self._parameters['step_size_update_fn'] = step_size_update_fn
self._parameters['seed'] = seed
Expand Down Expand Up @@ -412,4 +414,4 @@ def simps(f, a, b, N=128):
x = jnp.linspace(a, b, N + 1)
y = f(x)
S = dx / 3 * jnp.sum(y[0:-1:2] + 4 * y[1::2] + y[2::2], axis=0)
return S
return S