Add Mean-Field Variational Inference
xidulu authored and rlouf committed Jan 6, 2023
1 parent 3649ef5 commit 7760b18
Showing 6 changed files with 209 additions and 3 deletions.
4 changes: 3 additions & 1 deletion blackjax/__init__.py
@@ -11,6 +11,7 @@
     irmh,
     mala,
     meads,
+    meanfield_vi,
     mgrad_gaussian,
     nuts,
     orbital_hmc,
@@ -45,7 +46,8 @@
"pathfinder_adaptation",
"adaptive_tempered_smc", # smc
"tempered_smc",
"pathfinder", # variational inference
"meanfield_vi", # variational inference
"pathfinder",
"ess", # diagnostics
"rhat",
]
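
This re-export makes the new algorithm importable from the package root alongside pathfinder. A quick smoke check (a sketch, assuming blackjax is installed at this commit):

import blackjax

# Both variational inference algorithms are now exposed at the top level.
assert hasattr(blackjax, "meanfield_vi")
assert hasattr(blackjax, "pathfinder")
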
25 changes: 25 additions & 0 deletions blackjax/kernels.py
@@ -16,6 +16,7 @@

 import jax
 import jax.numpy as jnp
+from optax import GradientTransformation

 import blackjax.adaptation as adaptation
 import blackjax.mcmc as mcmc
@@ -1385,3 +1386,27 @@ def kernel(rng_key, state):
         return AdaptationResults(last_chain_state, kernel, parameters)

     return AdaptationAlgorithm(run)
+
+
+class meanfield_vi:
+    approximate = staticmethod(vi.meanfield_vi.approximate)
+    sample = staticmethod(vi.meanfield_vi.sample)
+    init = staticmethod(vi.meanfield_vi.init)
+
+    def __new__(cls, logprob_fn: Callable, optimizer: GradientTransformation) -> VIAlgorithm:  # type: ignore[misc]
+        def approximate_fn(
+            rng_key: PRNGKey,
+            position: PyTree,
+            sample_size: int = 200,
+            num_steps: int = 1000,
+        ):
+            return cls.approximate(
+                rng_key, position, logprob_fn, optimizer, sample_size, num_steps
+            )
+
+        def sample_fn(
+            rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState, num_samples: int
+        ):
+            return cls.sample(rng_key, state, num_samples)
+
+        return VIAlgorithm(approximate_fn, sample_fn)
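
The class above is a thin wrapper: it binds `logprob_fn` and `optimizer` once and returns a `VIAlgorithm` whose two callables close over them. Note that `approximate_fn`'s `position` argument actually receives the `MFVIState` produced by `init` (see the test below). A minimal end-to-end sketch; the target density, optimizer, and step counts here are illustrative, not part of this commit:

import jax
import jax.numpy as jnp
import optax

import blackjax

def logdensity_fn(x):
    # Illustrative target: a standard normal over a 2-vector.
    return -0.5 * jnp.sum(x**2)

optimizer = optax.adam(1e-2)
mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer)

fit_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
init_state = blackjax.meanfield_vi.init(jnp.zeros(2), optimizer)
state, info = mfvi.approximate(fit_key, init_state, sample_size=100, num_steps=2000)
samples = mfvi.sample(sample_key, state, num_samples=1000)
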
4 changes: 2 additions & 2 deletions blackjax/vi/__init__.py
@@ -1,3 +1,3 @@
-from . import pathfinder
+from . import meanfield_vi, pathfinder

-__all__ = ["pathfinder"]
+__all__ = ["pathfinder", "meanfield_vi"]
131 changes: 131 additions & 0 deletions blackjax/vi/meanfield_vi.py
@@ -0,0 +1,131 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from optax import GradientTransformation, OptState

from blackjax.types import Array, PRNGKey, PyTree

__all__ = ["MFVIState", "MFVIInfo", "sample", "logprob", "approximate"]


class MFVIState(NamedTuple):
    mu: PyTree
    rho: PyTree
    opt_state: OptState


class MFVIInfo(NamedTuple):
    mu_trace: Array
    rho_trace: Array
    elbo_trace: Array


def init(initial_position: PyTree, optimizer: GradientTransformation) -> MFVIState:
    """Initialize the mean-field VI state."""
    mu = jax.tree_map(jnp.zeros_like, initial_position)
    rho = jax.tree_map(lambda x: -2.0 * jnp.ones_like(x), initial_position)
    opt_state = optimizer.init((mu, rho))
    return MFVIState(mu, rho, opt_state)


def approximate(
    rng_key: PRNGKey,
    init_state: MFVIState,
    logdensity_fn: Callable,
    optimizer: GradientTransformation,
    num_samples: int = 5,
    num_steps: int = 200,
    stl_estimator: bool = True,
) -> Tuple[MFVIState, MFVIInfo]:
    """Approximate the target density using the mean-field approximation.

    Parameters
    ----------
    rng_key
        Key for JAX's pseudo-random number generator.
    init_state
        Initial state of the mean-field approximation.
    logdensity_fn
        Function that represents the target log-density to approximate.
    optimizer
        Optax `GradientTransformation` to be used for optimization.
    num_samples
        The number of samples drawn from the approximation at each step
        to estimate the Kullback-Leibler divergence between the
        approximation and the target log-density.
    num_steps
        The number of optimization steps to perform.
    stl_estimator
        Whether to use the "sticking the landing" gradient estimator
        [Roeder et al., 2017], which stops gradients through the
        variational parameters when evaluating the log-density of the
        approximation and thereby lowers the variance of the gradient
        estimate.

    Returns
    -------
    The last `MFVIState` and an `MFVIInfo` with the trace of the
    variational parameters and of the negative ELBO.

    References
    ----------
    Roeder, G., Wu, Y., & Duvenaud, D. (2017). Sticking the landing: Simple,
    lower-variance gradient estimators for variational inference. Advances in
    Neural Information Processing Systems 30.

    TODO: I think we should actually implement this using the Kernel API.
    """

    def meanfield_approximate_step(state: MFVIState, rng_key: PRNGKey):
        parameters = (state.mu, state.rho)

        def kl_divergence_fn(parameters):
            mu, rho = parameters
            z = _sample(rng_key, mu, rho, num_samples)
            if stl_estimator:
                # "Sticking the landing": stop the gradient through the
                # variational parameters inside log q; the samples z above
                # still depend on them via the reparameterization.
                mu = jax.lax.stop_gradient(mu)
                rho = jax.lax.stop_gradient(rho)
            logq = jax.vmap(logprob(mu, rho))(z)
            logp = jax.vmap(logdensity_fn)(z)
            # Monte Carlo estimate of KL(q || p), i.e. the negative ELBO.
            return (logq - logp).mean()

        elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
        updates, new_opt_state = optimizer.update(
            elbo_grad, state.opt_state, parameters
        )
        new_parameters = jax.tree_map(lambda p, u: p + u, parameters, updates)
        new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state)
        return new_state, (new_state, elbo)

    keys = jax.random.split(rng_key, num_steps)
    last_state, (states, elbo_values) = jax.lax.scan(
        meanfield_approximate_step, init_state, keys
    )
    return last_state, MFVIInfo(states.mu, states.rho, elbo_values)


def logprob(mu, rho):
    """Return the log-density of the mean-field (diagonal Gaussian) approximation."""
    sigma_param = jax.tree_map(jnp.exp, rho)

    def meanfield_logprob(position):
        logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param)
        logq = jax.tree_map(jnp.sum, logq_pytree)
        return jax.tree_util.tree_reduce(jnp.add, logq)

    return meanfield_logprob


def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1):
    """Sample from the mean-field approximation."""
    return _sample(rng_key, state.mu, state.rho, num_samples)


def _sample(rng_key, mu, rho, num_samples):
    # Reparameterized sampling on the flattened pytree:
    # z = mu + exp(rho) * eps with eps ~ N(0, I).
    sigma = jax.tree_map(jnp.exp, rho)
    mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
    sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma)
    flatten_sample = (
        jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) * sigma_flat
        + mu_flatten
    )
    return jax.vmap(unravel_fn)(flatten_sample)
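
`_sample` implements the reparameterization trick on the flattened pytree: draws are `mu + exp(rho) * eps` with `eps ~ N(0, I)`, so the Monte Carlo objective stays differentiable in `mu` and `rho`. The following standalone sketch replays the estimator from `kl_divergence_fn` on a 1-D problem to show what the `stl_estimator` flag changes; names and values here are illustrative, not part of this commit:

import jax
import jax.numpy as jnp
import jax.scipy as jsp

def logdensity_fn(z):
    # Illustrative 1-D target.
    return jsp.stats.norm.logpdf(z, loc=2.0, scale=4.0)

def neg_elbo(params, rng_key, num_samples=100, stl=True):
    mu, rho = params
    sigma = jnp.exp(rho)
    eps = jax.random.normal(rng_key, (num_samples,))
    z = mu + sigma * eps  # reparameterized draws, as in `_sample`
    if stl:
        # Stop the gradient through the variational parameters inside
        # log q only; z above still depends on them.
        mu = jax.lax.stop_gradient(mu)
        sigma = jax.lax.stop_gradient(sigma)
    logq = jsp.stats.norm.logpdf(z, mu, sigma)
    logp = logdensity_fn(z)
    return jnp.mean(logq - logp)  # estimate of KL(q || p), i.e. -ELBO

value, grads = jax.value_and_grad(neg_elbo)((0.0, -2.0), jax.random.PRNGKey(0))
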
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"jax>=0.3.13",
"jaxlib>=0.3.10",
"jaxopt>=0.5.5",
"optax",
"typing-extensions>=4.4.0",
]
dynamic = ["version"]
47 changes: 47 additions & 0 deletions tests/test_meanfield_vi.py
@@ -0,0 +1,47 @@
import chex
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import optax
from absl.testing import absltest

import blackjax


class MFVITest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.key = jax.random.PRNGKey(42)

    def test_recover_posterior(self):
        ground_truth = [
            # loc, scale
            (2, 4),
            (3, 5),
        ]

        def logdensity_fn(x):
            logpdf = stats.norm.logpdf(x["x_1"], *ground_truth[0]) + stats.norm.logpdf(
                x["x_2"], *ground_truth[1]
            )
            return jnp.sum(logpdf)

        initial_position = {"x_1": 0.0, "x_2": 0.0}
        optimizer = optax.sgd(1e-2)
        init_params = blackjax.kernels.meanfield_vi.init(initial_position, optimizer)

        kernel = blackjax.meanfield_vi(logdensity_fn, optimizer)
        state, _ = kernel.approximate(
            self.key, init_params, num_steps=50000, sample_size=500
        )
        loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
        scale = jax.tree_map(jnp.exp, state.rho)
        scale_1, scale_2 = scale["x_1"], scale["x_2"]
        self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01)
        self.assertAlmostEqual(scale_1, ground_truth[0][1], delta=0.01)
        self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01)
        self.assertAlmostEqual(scale_2, ground_truth[1][1], delta=0.01)


if __name__ == "__main__":
    absltest.main()
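
The test checks only the fitted parameters; drawing from the fitted approximation would go through the kernel's `sample` method. A short hypothetical follow-up, continuing from `state` above:

samples = kernel.sample(jax.random.PRNGKey(0), state, num_samples=1000)
# `samples` mirrors `initial_position`: a dict whose "x_1" and "x_2"
# entries are arrays of shape (1000,).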
