Add Mean-Field Variational Inference
xidulu authored and rlouf committed Jan 12, 2023
1 parent 6159f51 commit 2c38d7a
Showing 9 changed files with 225 additions and 9 deletions.
4 changes: 3 additions & 1 deletion blackjax/__init__.py
@@ -11,6 +11,7 @@
irmh,
mala,
meads,
meanfield_vi,
mgrad_gaussian,
nuts,
orbital_hmc,
@@ -45,7 +46,8 @@
"pathfinder_adaptation",
"adaptive_tempered_smc", # smc
"tempered_smc",
"pathfinder", # variational inference
"meanfield_vi", # variational inference
"pathfinder",
"ess", # diagnostics
"rhat",
]
5 changes: 3 additions & 2 deletions blackjax/base.py
@@ -1,5 +1,5 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -136,7 +136,8 @@ class VIAlgorithm(NamedTuple):
"""

approximate: Callable
init: Callable
step: Callable
sample: Callable


39 changes: 36 additions & 3 deletions blackjax/kernels.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blackjax high-level interface with sampling algorithms."""
from typing import Callable, Dict, NamedTuple, Optional, Union
from typing import Callable, Dict, NamedTuple, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from optax import GradientTransformation

import blackjax.adaptation as adaptation
import blackjax.mcmc as mcmc
@@ -1251,6 +1252,11 @@ def step_fn(rng_key: PRNGKey, state):
# -----------------------------------------------------------------------------


class PathFinderAlgorithm(NamedTuple):
approximate: Callable
sample: Callable


class pathfinder:
"""Implements the (basic) user interface for the pathfinder kernel.
@@ -1273,7 +1279,7 @@ class pathfinder:
approximate = staticmethod(vi.pathfinder.approximate)
sample = staticmethod(vi.pathfinder.sample)

def __new__(cls, logdensity_fn: Callable) -> VIAlgorithm: # type: ignore[misc]
def __new__(cls, logdensity_fn: Callable) -> PathFinderAlgorithm: # type: ignore[misc]
def approximate_fn(
rng_key: PRNGKey,
position: PyTree,
@@ -1289,7 +1295,7 @@ def sample_fn(
):
return cls.sample(rng_key, state, num_samples)

return VIAlgorithm(approximate_fn, sample_fn)
return PathFinderAlgorithm(approximate_fn, sample_fn)


def pathfinder_adaptation(
@@ -1385,3 +1391,30 @@ def kernel(rng_key, state):
return AdaptationResults(last_chain_state, kernel, parameters)

return AdaptationAlgorithm(run)


class meanfield_vi:
init = staticmethod(vi.meanfield_vi.init)
step = staticmethod(vi.meanfield_vi.step)
sample = staticmethod(vi.meanfield_vi.sample)

def __new__(
cls,
logdensity_fn: Callable,
optimizer: GradientTransformation,
num_samples: int = 100,
): # type: ignore[misc]
def init_fn(position: PyTree):
return cls.init(position, optimizer)

def step_fn(
rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState
) -> Tuple[vi.meanfield_vi.MFVIState, vi.meanfield_vi.MFVIInfo]:
return cls.step(rng_key, state, logdensity_fn, optimizer, num_samples)

def sample_fn(
rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState, num_samples: int
):
return cls.sample(rng_key, state, num_samples)

return VIAlgorithm(init_fn, step_fn, sample_fn)
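
For orientation, the new interface can be exercised end to end roughly as follows. This is a minimal sketch based on the wrapper above, not part of the commit; the toy target, the optax.sgd learning rate, and the step count are illustrative.

import jax
import jax.numpy as jnp
import optax
import blackjax

# Illustrative target: a standard normal log-density.
def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)

optimizer = optax.sgd(1e-2)
mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples=100)

state = mfvi.init(jnp.zeros(2))
rng_key = jax.random.PRNGKey(0)
for _ in range(1_000):
    rng_key, subkey = jax.random.split(rng_key)
    state, info = mfvi.step(subkey, state)  # info.elbo tracks the objective

approx_samples = mfvi.sample(rng_key, state, num_samples=500)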
4 changes: 2 additions & 2 deletions blackjax/vi/__init__.py
@@ -1,3 +1,3 @@
from . import pathfinder
from . import meanfield_vi, pathfinder

__all__ = ["pathfinder"]
__all__ = ["pathfinder", "meanfield_vi"]
121 changes: 121 additions & 0 deletions blackjax/vi/meanfield_vi.py
@@ -0,0 +1,121 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple

import jax
import jax.flatten_util  # make jax.flatten_util.ravel_pytree available below
import jax.numpy as jnp
import jax.scipy as jsp
from optax import GradientTransformation, OptState

from blackjax.types import PRNGKey, PyTree

__all__ = ["MFVIState", "MFVIInfo", "sample", "logprob", "step"]


class MFVIState(NamedTuple):
mu: PyTree
rho: PyTree
opt_state: OptState


class MFVIInfo(NamedTuple):
elbo: float


def init(
position: PyTree,
optimizer: GradientTransformation,
*optimizer_args,
**optimizer_kwargs
) -> MFVIState:
"""Initialize the mean-field VI state."""
mu = jax.tree_map(jnp.zeros_like, position)
rho = jax.tree_map(lambda x: -2.0 * jnp.ones_like(x), position)
opt_state = optimizer.init((mu, rho))
return MFVIState(mu, rho, opt_state)


def step(
rng_key: PRNGKey,
state: MFVIState,
logdensity_fn: Callable,
optimizer: GradientTransformation,
num_samples: int = 5,
stl_estimator: bool = True,
) -> Tuple[MFVIState, MFVIInfo]:
"""Approximate the target density using the mean-field approximation.
Parameters
----------
rng_key
Key for JAX's pseudo-random number generator.
init_state
Initial state of the mean-field approximation.
logdensity_fn
Function that represents the target log-density to approximate.
optimizer
Optax `GradientTransformation` to be used for optimization.
num_samples
The number of samples that are taken from the approximation
at each step to compute the Kullback-Leibler divergence between
the approximation and the target log-density.
TODO: Document `stl_estimator`;
TODO: Add reference(s)
"""

parameters = (state.mu, state.rho)

def kl_divergence_fn(parameters):
    mu, rho = parameters
    z = _sample(rng_key, mu, rho, num_samples)
    if stl_estimator:
        # "Sticks the landing": stop gradients through the parameters of
        # log q so only the pathwise (reparameterization) term remains.
        mu = jax.lax.stop_gradient(mu)
        rho = jax.lax.stop_gradient(rho)
    logq = jax.vmap(logprob(mu, rho))(z)
    logp = jax.vmap(logdensity_fn)(z)
    # Monte Carlo estimate of KL(q || p) up to an additive constant,
    # i.e. the negative ELBO; this is the quantity being minimized.
    return (logq - logp).mean()

elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters)
new_parameters = jax.tree_map(lambda p, u: p + u, parameters, updates)
new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state)
return new_state, MFVIInfo(elbo)


def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1):
"""Sample from the mean-field approximation."""
return _sample(rng_key, state.mu, state.rho, num_samples)


def _sample(rng_key, mu, rho, num_samples):
    sigma = jax.tree_map(jnp.exp, rho)
    mu_flat, unravel_fn = jax.flatten_util.ravel_pytree(mu)
    sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma)
    # Reparameterized draws: z = mu + exp(rho) * eps, with eps ~ N(0, I).
    flat_samples = (
        jax.random.normal(rng_key, (num_samples,) + mu_flat.shape) * sigma_flat
        + mu_flat
    )
    return jax.vmap(unravel_fn)(flat_samples)


def logprob(mu, rho):
sigma_param = jax.tree_map(jnp.exp, rho)

def meanfield_logprob(position):
logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param)
logq = jax.tree_map(jnp.sum, logq_pytree)
return jax.tree_util.tree_reduce(jnp.add, logq)

return meanfield_logprob
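
A note on the parameterization used throughout this file: the approximation is a diagonal Gaussian whose standard deviation is sigma = exp(rho), and _sample draws from it via the reparameterization z = mu + exp(rho) * eps with eps ~ N(0, I), which is what keeps the KL estimate in step differentiable with respect to (mu, rho). The following self-contained sketch (illustrative names and shapes, not part of the commit) shows the same draw and density side by side:

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

key = jax.random.PRNGKey(0)
mu = jnp.array([0.5, -1.0])
rho = jnp.array([-2.0, -2.0])  # sigma = exp(rho) ~= 0.135

# Reparameterized draws: a deterministic transform of base noise.
eps = jax.random.normal(key, (4,) + mu.shape)
z = mu + jnp.exp(rho) * eps

# Density of the same diagonal Gaussian, as computed by `logprob`.
logq = stats.norm.logpdf(z, mu, jnp.exp(rho)).sum(axis=-1)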
6 changes: 6 additions & 0 deletions docs/vi.rst
@@ -7,9 +7,15 @@ Variational Inference
:nosignatures:

pathfinder
meanfield_vi


Pathfinder
~~~~~~~~~~

.. autoclass:: blackjax.pathfinder

Mean-field VI
~~~~~~~~~~~~~

.. autoclass:: blackjax.meanfield_vi
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"jax>=0.3.13",
"jaxlib>=0.3.10",
"jaxopt>=0.5.5",
"optax",
"typing-extensions>=4.4.0",
]
dynamic = ["version"]
1 change: 0 additions & 1 deletion pytest.ini
@@ -1,5 +1,4 @@
[pytest]
addopts = -n auto
testpaths= "tests"
filterwarnings =
error
53 changes: 53 additions & 0 deletions tests/test_meanfield_vi.py
@@ -0,0 +1,53 @@
import chex
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import optax
from absl.testing import absltest

import blackjax


class MFVITest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)

def test_recover_posterior(self):
ground_truth = [
# loc, scale
(2, 4),
(3, 5),
]

def logdensity_fn(x):
logpdf = stats.norm.logpdf(x["x_1"], *ground_truth[0]) + stats.norm.logpdf(
x["x_2"], *ground_truth[1]
)
return jnp.sum(logpdf)

initial_position = {"x_1": 0.0, "x_2": 0.0}

num_steps = 50_000
num_samples = 500

optimizer = optax.sgd(1e-2)
mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples)
state = mfvi.init(initial_position)

rng_key = self.key
step_fn = jax.jit(mfvi.step)
for _ in range(num_steps):
    rng_key, subkey = jax.random.split(rng_key)
    state, _ = step_fn(subkey, state)

loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
scale = jax.tree_map(jnp.exp, state.rho)
scale_1, scale_2 = scale["x_1"], scale["x_2"]
self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01)
self.assertAlmostEqual(scale_1, ground_truth[0][1], delta=0.01)
self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01)
self.assertAlmostEqual(scale_2, ground_truth[1][1], delta=0.01)


if __name__ == "__main__":
absltest.main()
