Add the marginal gradient sampler for latent gaussian models (#247)
AdrienCorenflos authored and junpenglao committed Mar 12, 2024
1 parent 79a32ae commit 2da32c7
Showing 8 changed files with 857 additions and 10 deletions.
6 changes: 4 additions & 2 deletions blackjax/__init__.py
@@ -5,6 +5,7 @@
     elliptical_slice,
     hmc,
     mala,
+    mgrad_gaussian,
     nuts,
     orbital_hmc,
     pathfinder_adaptation,
@@ -18,19 +19,20 @@
 __version__ = "0.8.2"
 
 __all__ = [
-    "dual_averaging",
+    "dual_averaging",  # optimizers
     "lbfgs",
     "hmc",  # mcmc
     "mala",
+    "mgrad_gaussian",
     "nuts",
     "orbital_hmc",
     "rmh",
     "elliptical_slice",
     "sgld",  # stochastic gradient mcmc
     "window_adaptation",  # mcmc adaptation
+    "pathfinder_adaptation",
     "adaptive_tempered_smc",  # smc
     "tempered_smc",
     "ess",  # diagnostics
     "rhat",
-    "pathfinder_adaptation",
 ]
71 changes: 65 additions & 6 deletions blackjax/kernels.py
@@ -24,6 +24,7 @@
     "window_adaptation",
     "pathfinder",
     "pathfinder_adaptation",
+    "mgrad_gaussian",
 ]


@@ -213,7 +214,6 @@ def __new__(  # type: ignore[misc]
         integrator: Callable = mcmc.integrators.velocity_verlet,
         logprob_grad_fn: Optional[Callable] = None,
     ) -> SamplingAlgorithm:
-
         step = cls.kernel(integrator, divergence_threshold)
 
         def init_fn(position: PyTree):
@@ -292,7 +292,6 @@ def __new__(  # type: ignore[misc]
         logprob_fn: Callable,
         step_size: float,
     ) -> SamplingAlgorithm:
-
         step = cls.kernel()
 
         def init_fn(position: PyTree):
@@ -375,7 +374,6 @@ def __new__(  # type: ignore[misc]
         integrator: Callable = mcmc.integrators.velocity_verlet,
         logprob_grad_fn: Optional[Callable] = None,
     ) -> SamplingAlgorithm:
-
         step = cls.kernel(integrator, divergence_threshold, max_num_doublings)
 
         def init_fn(position: PyTree):
@@ -394,6 +392,70 @@ def step_fn(rng_key: PRNGKey, state):
         return SamplingAlgorithm(init_fn, step_fn)


class mgrad_gaussian:
    """Implements the marginal sampler for latent Gaussian models of [1].

    It uses a first-order approximation to the log-likelihood of a model with a
    Gaussian prior. Interestingly, the only parameter that needs calibrating is
    the "step size" delta, which can be done very efficiently; calibrating it to
    reach an acceptance rate of roughly 50% is a good starting point.

    Examples
    --------
    A new marginal latent Gaussian MCMC kernel for a model q(x) ∝ exp(f(x)) N(x; m, C)
    can be initialized and used for a given "step size" delta with the following code:

    .. code::

        mgrad_gaussian = blackjax.mgrad_gaussian(f, C, mean=m)
        state = mgrad_gaussian.init(zeros)  # Starting at the mean of the prior
        new_state, info = mgrad_gaussian.step(rng_key, state, delta)

    We can JIT-compile the step function for better performance:

    .. code::

        step = jax.jit(mgrad_gaussian.step)
        new_state, info = step(rng_key, state, delta)

    Parameters
    ----------
    logprob_fn
        The logarithm of the likelihood function for the latent Gaussian model.
    covariance
        The covariance of the prior Gaussian density.
    mean: optional
        Mean of the prior Gaussian density. Default is zero.

    Returns
    -------
    A ``SamplingAlgorithm``.

    References
    ----------
    [1]: Titsias, M.K. and Papaspiliopoulos, O. (2018), Auxiliary gradient-based
         sampling algorithms. J. R. Stat. Soc. B, 80: 749-767.
         https://doi.org/10.1111/rssb.12269
    """

    def __new__(  # type: ignore[misc]
        cls,
        logprob_fn: Callable,
        covariance: Array,
        mean: Optional[Array] = None,
    ) -> SamplingAlgorithm:
        init, step = mcmc.marginal_latent_gaussian.init_and_kernel(
            logprob_fn, covariance, mean
        )

        def init_fn(position: Array):
            return init(position)

        def step_fn(rng_key: PRNGKey, state, delta: float):
            return step(
                rng_key,
                state,
                delta,
            )

        return SamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]


# -----------------------------------------------------------------------------
# STOCHASTIC GRADIENT MCMC
# -----------------------------------------------------------------------------
@@ -631,7 +693,6 @@ def __new__(  # type: ignore[misc]
         logprob_fn: Callable,
         sigma: Array,
     ) -> SamplingAlgorithm:
-
         step = cls.kernel()
 
         def init_fn(position: PyTree):
@@ -696,7 +757,6 @@ def __new__(  # type: ignore[misc]
         *,
         bijection: Callable = mcmc.integrators.velocity_verlet,
     ) -> SamplingAlgorithm:
-
         step = cls.kernel(bijection)
 
         def init_fn(position: PyTree):
@@ -749,7 +809,6 @@ def __new__(  # type: ignore[misc]
         mean: Array,
         cov: Array,
     ) -> SamplingAlgorithm:
-
         step = cls.kernel(cov, mean)
 
         def init_fn(position: PyTree):
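As a worked illustration of the new kernel's interface, here is a self-contained sketch that builds a toy latent Gaussian model and runs the sampler. It assumes only what this commit adds (``blackjax.mgrad_gaussian`` and the returned ``SamplingAlgorithm``'s ``init``/``step``); the model ``f``, covariance ``C``, prior mean ``m``, and the value of ``delta`` are illustrative, not part of the commit:

.. code::

    import jax
    import jax.numpy as jnp

    import blackjax

    dim = 10
    C = jnp.eye(dim)    # toy prior covariance
    m = jnp.zeros(dim)  # toy prior mean

    def f(x):
        # Illustrative non-Gaussian log-likelihood term.
        return -jnp.sum(jnp.log1p(x**2))

    algo = blackjax.mgrad_gaussian(f, C, mean=m)
    state = algo.init(m)  # start the chain at the prior mean

    step = jax.jit(algo.step)

    rng_key = jax.random.PRNGKey(0)
    delta = 0.5  # the "step size"; tune for roughly 50% acceptance
    for _ in range(1_000):
        rng_key, subkey = jax.random.split(rng_key)
        state, info = step(subkey, state, delta)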
20 changes: 18 additions & 2 deletions blackjax/mcmc/__init__.py
@@ -1,3 +1,19 @@
-from . import elliptical_slice, hmc, mala, nuts, periodic_orbital, rmh
+from . import (
+    elliptical_slice,
+    hmc,
+    mala,
+    marginal_latent_gaussian,
+    nuts,
+    periodic_orbital,
+    rmh,
+)
 
-__all__ = ["elliptical_slice", "hmc", "mala", "nuts", "periodic_orbital", "rmh"]
+__all__ = [
+    "elliptical_slice",
+    "hmc",
+    "mala",
+    "nuts",
+    "periodic_orbital",
+    "rmh",
+    "marginal_latent_gaussian",
+]
122 changes: 122 additions & 0 deletions blackjax/mcmc/marginal_latent_gaussian.py
@@ -0,0 +1,122 @@
"""Public API for marginal latent Gaussian sampling."""
from typing import NamedTuple

import jax
import jax.numpy as jnp
import jax.scipy.linalg as linalg

from blackjax.types import Array, PRNGKey

__all__ = ["MarginalState", "MarginalInfo", "init_and_kernel"]


class MarginalState(NamedTuple):
    """State of the marginal sampler's chain.

    position
        Current position of the chain.
    logprob
        Current value of the log-likelihood of the model.
    logprob_grad
        Current value of the gradient of the log-likelihood of the model.
    U_x, U_grad_x
        Auxiliary attributes: the position and the gradient expressed in the
        eigenbasis U of the prior covariance, precomputed to avoid repeated
        multiplications by U.
    """

    position: Array
    logprob: float
    logprob_grad: Array

    U_x: Array
    U_grad_x: Array


class MarginalInfo(NamedTuple):
    """Additional information on the marginal sampler's chain.

    This additional information can be used for debugging or computing
    diagnostics.

    acceptance_probability
        The acceptance probability of the transition, linked to the energy
        difference between the original and the proposed states.
    is_accepted
        Whether the proposed position was accepted or the original position
        was returned.
    proposal
        The state proposed by the proposal.
    """

    acceptance_probability: float
    is_accepted: bool
    proposal: MarginalState
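# --- Not part of this file: a hypothetical usage sketch. --------------------
# MarginalInfo supports chain diagnostics. For instance, with `kernel` and
# `state` obtained from `init_and_kernel` below, the acceptance probability
# can be averaged over a scanned chain (a fixed, illustrative `delta` assumed):
#
#     def one_step(state, key):
#         state, info = kernel(key, state, delta)
#         return state, info.acceptance_probability
#
#     keys = jax.random.split(rng_key, 1_000)
#     state, accept_probs = jax.lax.scan(one_step, state, keys)
#     mean_acceptance = accept_probs.mean()
# -----------------------------------------------------------------------------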


def init_and_kernel(logprob_fn, covariance, mean=None):
    """Build the marginal version of the auxiliary gradient-based sampler.

    Returns
    -------
    A kernel that takes a rng_key, the current state of the chain and the
    "step size" delta, and returns a new state of the chain along with
    information about the transition.
    An init function.
    """

    U, Gamma, U_t = jnp.linalg.svd(covariance, hermitian=True)

    if mean is not None:
        shift = linalg.solve(covariance, mean, sym_pos=True)
        val_and_grad = jax.value_and_grad(lambda x: logprob_fn(x) + jnp.dot(x, shift))
    else:
        val_and_grad = jax.value_and_grad(logprob_fn)

    def step(key: PRNGKey, state: MarginalState, delta):
        y_key, u_key = jax.random.split(key, 2)

        position, logprob, logprob_grad, U_x, U_grad_x = state

        # Update Gamma(delta). Gamma_1 contains the eigenvalues of
        # (2/delta * I + C^{-1})^{-1} expressed in the eigenbasis U of C.
        # TODO: Ideally we would only recompute these quantities when delta
        # changes, but this is hardly the most expensive part of the algorithm
        # (the multiplications by U below are).
        Gamma_1 = Gamma * delta / (delta + 2 * Gamma)
        Gamma_3 = (delta + 2 * Gamma) / (delta + 4 * Gamma)
        Gamma_2 = Gamma_1 / Gamma_3

        # Propose a new y
        temp = Gamma_1 * (U_x / (0.5 * delta) + U_grad_x)
        temp = temp + jnp.sqrt(Gamma_2) * jax.random.normal(y_key, position.shape)
        y = U @ temp

        # Bookkeeping
        log_p_y, grad_y = val_and_grad(y)
        U_y = U_t @ y
        U_grad_y = U_t @ grad_y

        # Acceptance step
        temp_x = Gamma_1 * (U_x / (0.5 * delta) + 0.5 * U_grad_x)
        temp_y = Gamma_1 * (U_y / (0.5 * delta) + 0.5 * U_grad_y)

        hxy = jnp.dot(U_x - temp_y, Gamma_3 * U_grad_y)
        hyx = jnp.dot(U_y - temp_x, Gamma_3 * U_grad_x)

        alpha = jnp.minimum(1, jnp.exp(log_p_y - logprob + hxy - hyx))
        accept = jax.random.uniform(u_key) < alpha

        proposed_state = MarginalState(y, log_p_y, grad_y, U_y, U_grad_y)
        new_state = jax.lax.cond(
            accept, lambda _: proposed_state, lambda _: state, None
        )
        info = MarginalInfo(alpha, accept, proposed_state)
        return new_state, info

    def init(position):
        logprob, logprob_grad = val_and_grad(position)
        return MarginalState(
            position, logprob, logprob_grad, U_t @ position, U_t @ logprob_grad
        )

    return init, step
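To illustrate the docstring's advice of calibrating delta toward a roughly 50% acceptance rate, here is a hypothetical tuning loop. It is a minimal sketch assuming only the ``init_and_kernel`` API above; the toy model and the stochastic-approximation schedule for ``delta`` are illustrative assumptions, not part of the commit:

.. code::

    import jax
    import jax.numpy as jnp

    from blackjax.mcmc.marginal_latent_gaussian import init_and_kernel

    dim = 10
    covariance = jnp.eye(dim)  # toy prior covariance

    def logprob_fn(x):
        # Illustrative non-Gaussian log-likelihood term.
        return -jnp.sum(jnp.log1p(x**2))

    init, kernel = init_and_kernel(logprob_fn, covariance)
    state = init(jnp.zeros(dim))

    key = jax.random.PRNGKey(42)
    delta = 1.0
    for i in range(500):
        key, subkey = jax.random.split(key)
        state, info = kernel(subkey, state, delta)
        # Nudge log(delta) so the acceptance rate drifts toward 50%.
        delta = float(
            jnp.exp(jnp.log(delta) + 0.1 * (info.acceptance_probability - 0.5))
        )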
1 change: 1 addition & 0 deletions docs/examples.rst
@@ -6,6 +6,7 @@ Examples
 
    examples/Introduction.ipynb
    examples/LogisticRegression.ipynb
+   examples/LogisticRegressionWithLatentGaussianSampler.ipynb
    examples/TemperedSMC.ipynb
    examples/use_with_numpyro.ipynb
    examples/use_with_pymc3.ipynb