-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the marginal gradient sampler for latent gaussian models (#247)
- Loading branch information
1 parent
79a32ae
commit 2da32c7
Showing
8 changed files
with
857 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,19 @@ | ||
from . import elliptical_slice, hmc, mala, nuts, periodic_orbital, rmh | ||
from . import ( | ||
elliptical_slice, | ||
hmc, | ||
mala, | ||
marginal_latent_gaussian, | ||
nuts, | ||
periodic_orbital, | ||
rmh, | ||
) | ||
|
||
__all__ = ["elliptical_slice", "hmc", "mala", "nuts", "periodic_orbital", "rmh"] | ||
__all__ = [ | ||
"elliptical_slice", | ||
"hmc", | ||
"mala", | ||
"nuts", | ||
"periodic_orbital", | ||
"rmh", | ||
"marginal_latent_gaussian", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""Public API for marginal latent Gaussian sampling.""" | ||
from typing import NamedTuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import jax.scipy.linalg as linalg | ||
|
||
from blackjax.types import Array, PRNGKey | ||
|
||
__all__ = ["MarginalState", "MarginalInfo", "init_and_kernel"] | ||
|
||
|
||
class MarginalState(NamedTuple):
    """State of the marginal gradient-based sampler's chain.

    Attributes
    ----------
    position
        Current position of the chain.
    logprob
        Current value of the log-likelihood of the model at ``position``.
    logprob_grad
        Current value of the gradient of the log-likelihood of the model
        at ``position``.
    U_x, U_grad_x
        Auxiliary quantities: the position and the gradient rotated into the
        eigenbasis of the prior covariance (``U_t @ position`` and
        ``U_t @ logprob_grad``), cached so the kernel does not recompute
        them at every step.
    """

    position: Array
    logprob: float
    logprob_grad: Array

    U_x: Array
    U_grad_x: Array
|
||
|
||
class MarginalInfo(NamedTuple):
    """Additional information on the chain's current transition.

    This additional information can be used for debugging or computing
    diagnostics.

    Attributes
    ----------
    acceptance_probability
        The acceptance probability of the transition, linked to the energy
        difference between the original and the proposed states.
    is_accepted
        Whether the proposed position was accepted or the original position
        was returned.
    proposal
        The state proposed by the proposal, kept even when it was rejected.
    """

    acceptance_probability: float
    is_accepted: bool
    proposal: MarginalState
|
||
|
||
def init_and_kernel(logprob_fn, covariance, mean=None):
    """Build the marginal version of the auxiliary gradient-based sampler
    for latent Gaussian models.

    Parameters
    ----------
    logprob_fn
        Log-likelihood of the model, evaluated at the latent position.
    covariance
        Covariance matrix of the Gaussian prior over the latent variable.
    mean
        Optional mean of the Gaussian prior. When provided, the prior mean is
        folded into the log-density via a linear shift so the algorithm can
        work with a centered prior.

    Returns
    -------
    init
        A function that takes a position and returns a fresh
        :class:`MarginalState`.
    step
        A kernel that takes a rng_key, the current :class:`MarginalState` and
        a step size ``delta``, and returns a new state of the chain along with
        information about the transition.
    """
    # Eigendecomposition of the (symmetric PSD) prior covariance via SVD.
    U, Gamma, U_t = jnp.linalg.svd(covariance, hermitian=True)

    if mean is not None:
        # `assume_a="pos"` replaces `sym_pos=True`, which was deprecated in
        # SciPy and subsequently removed from jax.scipy.linalg.solve.
        shift = linalg.solve(covariance, mean, assume_a="pos")
        val_and_grad = jax.value_and_grad(lambda x: logprob_fn(x) + jnp.dot(x, shift))
    else:
        val_and_grad = jax.value_and_grad(logprob_fn)

    def step(key: PRNGKey, state: MarginalState, delta):
        y_key, u_key = jax.random.split(key, 2)

        position, logprob, logprob_grad, U_x, U_grad_x = state

        # Update Gamma(delta)
        # TODO: Ideally, we could have a dichotomy, where we only update
        # Gamma(delta) if delta changes, but this is hardly the most expensive
        # part of the algorithm (the multiplication by U below is).
        Gamma_1 = Gamma * delta / (delta + 2 * Gamma)
        Gamma_3 = (delta + 2 * Gamma) / (delta + 4 * Gamma)
        Gamma_2 = Gamma_1 / Gamma_3

        # Propose a new y, working in the eigenbasis of the covariance.
        temp = Gamma_1 * (U_x / (0.5 * delta) + U_grad_x)
        temp = temp + jnp.sqrt(Gamma_2) * jax.random.normal(y_key, position.shape)
        y = U @ temp

        # Bookkeeping: evaluate the model at the proposal and cache the
        # rotated position/gradient for the next iteration.
        log_p_y, grad_y = val_and_grad(y)
        U_y = U_t @ y
        U_grad_y = U_t @ grad_y

        # Acceptance step (Metropolis-Hastings correction).
        temp_x = Gamma_1 * (U_x / (0.5 * delta) + 0.5 * U_grad_x)
        temp_y = Gamma_1 * (U_y / (0.5 * delta) + 0.5 * U_grad_y)

        hxy = jnp.dot(U_x - temp_y, Gamma_3 * U_grad_y)
        hyx = jnp.dot(U_y - temp_x, Gamma_3 * U_grad_x)

        alpha = jnp.minimum(1, jnp.exp(log_p_y - logprob + hxy - hyx))
        accept = jax.random.uniform(u_key) < alpha

        # Keep the proposal in the info even when it is rejected, so that
        # downstream diagnostics can inspect it.
        proposed_state = MarginalState(y, log_p_y, grad_y, U_y, U_grad_y)
        state = jax.lax.cond(accept, lambda _: proposed_state, lambda _: state, None)
        info = MarginalInfo(alpha, accept, proposed_state)
        return state, info

    def init(position):
        # Evaluate the model once and pre-rotate position/gradient into the
        # covariance eigenbasis.
        logprob, logprob_grad = val_and_grad(position)
        return MarginalState(
            position, logprob, logprob_grad, U_t @ position, U_t @ logprob_grad
        )

    return init, step
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.