Add Mean Field Variational Inference implementation #433

Merged 1 commit on Jan 14, 2023
4 changes: 3 additions & 1 deletion blackjax/__init__.py
@@ -11,6 +11,7 @@
irmh,
mala,
meads,
meanfield_vi,
mgrad_gaussian,
nuts,
orbital_hmc,
@@ -45,7 +46,8 @@
"pathfinder_adaptation",
"adaptive_tempered_smc", # smc
"tempered_smc",
"pathfinder", # variational inference
"meanfield_vi", # variational inference
"pathfinder",
"ess", # diagnostics
"rhat",
]
4 changes: 2 additions & 2 deletions blackjax/base.py
@@ -1,5 +1,4 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -136,7 +135,8 @@ class VIAlgorithm(NamedTuple):

"""

approximate: Callable
init: Callable
step: Callable
sample: Callable


39 changes: 36 additions & 3 deletions blackjax/kernels.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blackjax high-level interface with sampling algorithms."""
from typing import Callable, Dict, NamedTuple, Optional, Union
from typing import Callable, Dict, NamedTuple, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from optax import GradientTransformation

import blackjax.adaptation as adaptation
import blackjax.mcmc as mcmc
@@ -1251,6 +1252,11 @@ def step_fn(rng_key: PRNGKey, state):
# -----------------------------------------------------------------------------


class PathFinderAlgorithm(NamedTuple):
approximate: Callable
sample: Callable


class pathfinder:
"""Implements the (basic) user interface for the pathfinder kernel.

@@ -1273,7 +1279,7 @@ class pathfinder:
approximate = staticmethod(vi.pathfinder.approximate)
sample = staticmethod(vi.pathfinder.sample)

def __new__(cls, logdensity_fn: Callable) -> VIAlgorithm: # type: ignore[misc]
def __new__(cls, logdensity_fn: Callable) -> PathFinderAlgorithm: # type: ignore[misc]
def approximate_fn(
rng_key: PRNGKey,
position: PyTree,
@@ -1289,7 +1295,7 @@ def sample_fn(
):
return cls.sample(rng_key, state, num_samples)

return VIAlgorithm(approximate_fn, sample_fn)
return PathFinderAlgorithm(approximate_fn, sample_fn)


def pathfinder_adaptation(
@@ -1385,3 +1391,30 @@ def kernel(rng_key, state):
return AdaptationResults(last_chain_state, kernel, parameters)

return AdaptationAlgorithm(run)


class meanfield_vi:
init = staticmethod(vi.meanfield_vi.init)
step = staticmethod(vi.meanfield_vi.step)
sample = staticmethod(vi.meanfield_vi.sample)

def __new__(
cls,
logdensity_fn: Callable,
optimizer: GradientTransformation,
num_samples: int = 100,
): # type: ignore[misc]
def init_fn(position: PyTree):
return cls.init(position, optimizer)

def step_fn(
rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState
) -> Tuple[vi.meanfield_vi.MFVIState, vi.meanfield_vi.MFVIInfo]:
return cls.step(rng_key, state, logdensity_fn, optimizer, num_samples)

def sample_fn(
rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState, num_samples: int
):
return cls.sample(rng_key, state, num_samples)

return VIAlgorithm(init_fn, step_fn, sample_fn)
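
A minimal usage sketch of the new meanfield_vi interface follows; the toy target, the optax.adam(1e-2) optimizer, and the step and sample counts below are illustrative assumptions, not anything prescribed by this change.

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import optax

import blackjax


def logdensity_fn(position):
    # Toy target: an isotropic standard normal over a flat vector.
    return jnp.sum(stats.norm.logpdf(position))


optimizer = optax.adam(1e-2)
mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples=20)

rng_key = jax.random.PRNGKey(0)
state = mfvi.init(jnp.zeros(3))

step = jax.jit(mfvi.step)
for _ in range(1_000):
    rng_key, subkey = jax.random.split(rng_key)
    state, info = step(subkey, state)

# Draw from the fitted approximation; state.mu and exp(state.rho) hold the
# variational means and scales.
posterior_samples = mfvi.sample(rng_key, state, num_samples=100)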
4 changes: 2 additions & 2 deletions blackjax/vi/__init__.py
@@ -1,3 +1,3 @@
from . import pathfinder
from . import meanfield_vi, pathfinder

__all__ = ["pathfinder"]
__all__ = ["pathfinder", "meanfield_vi"]
131 changes: 131 additions & 0 deletions blackjax/vi/meanfield_vi.py
@@ -0,0 +1,131 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from optax import GradientTransformation, OptState

from blackjax.types import PRNGKey, PyTree

__all__ = ["MFVIState", "MFVIInfo", "sample", "generate_meanfield_logdensity", "step"]


class MFVIState(NamedTuple):
mu: PyTree
rho: PyTree
opt_state: OptState


class MFVIInfo(NamedTuple):
elbo: float


def init(
position: PyTree,
optimizer: GradientTransformation,
*optimizer_args,
**optimizer_kwargs
) -> MFVIState:
"""Initialize the mean-field VI state."""
mu = jax.tree_map(jnp.zeros_like, position)
rho = jax.tree_map(lambda x: -2.0 * jnp.ones_like(x), position)
opt_state = optimizer.init((mu, rho))
return MFVIState(mu, rho, opt_state)


def step(
rng_key: PRNGKey,
state: MFVIState,
logdensity_fn: Callable,
optimizer: GradientTransformation,
num_samples: int = 5,
stl_estimator: bool = True,
) -> Tuple[MFVIState, MFVIInfo]:
"""Approximate the target density using the mean-field approximation.

Parameters
----------
rng_key
Key for JAX's pseudo-random number generator.
state
Current state of the mean-field approximation.
logdensity_fn
Function that represents the target log-density to approximate.
optimizer
Optax `GradientTransformation` to be used for optimization.
num_samples
The number of samples that are taken from the approximation
at each step to compute the Kullback-Leibler divergence between
the approximation and the target log-density.
stl_estimator
Whether to use the stick-the-landing (STL) gradient estimator [1].
The STL estimator reduces gradient variance by removing the score function term
from the gradient; [2] suggests keeping it enabled for better results.

References
----------
.. [1]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017).
Sticking the landing: Simple, lower-variance gradient estimators for variational inference.
Advances in Neural Information Processing Systems, 30.
.. [2]: Agrawal, A., Sheldon, D. R., & Domke, J. (2020).
Advances in black-box VI: Normalizing flows, importance weighting, and optimization.
Advances in Neural Information Processing Systems, 33.
"""

parameters = (state.mu, state.rho)

def kl_divergence_fn(parameters):
mu, rho = parameters
z = _sample(rng_key, mu, rho, num_samples)
if stl_estimator:
mu = jax.lax.stop_gradient(mu)
rho = jax.lax.stop_gradient(rho)
logq = jax.vmap(generate_meanfield_logdensity(mu, rho))(z)
logp = jax.vmap(logdensity_fn)(z)
return (logq - logp).mean()

elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters)
new_parameters = jax.tree_map(lambda p, u: p + u, parameters, updates)
new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state)
return new_state, MFVIInfo(elbo)


def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1):
"""Sample from the mean-field approximation."""
return _sample(rng_key, state.mu, state.rho, num_samples)


def _sample(rng_key, mu, rho, num_samples):
sigma = jax.tree_map(jnp.exp, rho)
mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma)
flatten_sample = (
jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) * sigma_flat
+ mu_flatten
)
return jax.vmap(unravel_fn)(flatten_sample)


def generate_meanfield_logdensity(mu, rho):
sigma_param = jax.tree_map(jnp.exp, rho)

def meanfield_logdensity(position):
logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param)
logq = jax.tree_map(jnp.sum, logq_pytree)
return jax.tree_util.tree_reduce(jnp.add, logq)

return meanfield_logdensity
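
To make the gradient estimator in step easier to follow, here is a sketch of the same Monte Carlo objective written for a single flat parameter vector (no pytrees); target_logdensity and the default sample count are placeholder assumptions, and this restates rather than replaces what step does.

import jax
import jax.numpy as jnp
import jax.scipy as jsp


def kl_estimate(params, rng_key, target_logdensity, num_samples=5, stl=True):
    mu, rho = params
    sigma = jnp.exp(rho)
    # Reparameterized draws: gradients reach mu and sigma through z.
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    z = mu + sigma * eps
    if stl:
        # Stick-the-landing: stop the gradient of the variational parameters
        # inside log q(z) so the score-function term drops out.
        mu = jax.lax.stop_gradient(mu)
        sigma = jax.lax.stop_gradient(sigma)
    logq = jsp.stats.norm.logpdf(z, mu, sigma).sum(axis=-1)
    logp = jax.vmap(target_logdensity)(z)
    # Monte Carlo estimate of KL(q || p), up to the target's normalizing constant.
    return jnp.mean(logq - logp)


# Differentiating this estimate is what step does before applying the optax update:
# value, grad = jax.value_and_grad(kl_estimate)(params, rng_key, target_logdensity)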
6 changes: 6 additions & 0 deletions docs/vi.rst
@@ -7,9 +7,15 @@ Variational Inference
:nosignatures:

pathfinder
meanfield_vi


Pathfinder
~~~~~~~~~~

.. autoclass:: blackjax.pathfinder

Mean-field VI
~~~~~~~~~~~~~

.. autoclass:: blackjax.meanfield_vi
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"jax>=0.3.13",
"jaxlib>=0.3.10",
"jaxopt>=0.5.5",
"optax",
"typing-extensions>=4.4.0",
]
dynamic = ["version"]
1 change: 0 additions & 1 deletion pytest.ini
@@ -1,5 +1,4 @@
[pytest]
addopts = -n auto
testpaths= "tests"
filterwarnings =
error
53 changes: 53 additions & 0 deletions tests/test_meanfield_vi.py
@@ -0,0 +1,53 @@
import chex
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import optax
from absl.testing import absltest

import blackjax


class MFVITest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)

def test_recover_posterior(self):
ground_truth = [
# loc, scale
(2, 4),
(3, 5),
]

def logdensity_fn(x):
logpdf = stats.norm.logpdf(x["x_1"], *ground_truth[0]) + stats.norm.logpdf(
x["x_2"], *ground_truth[1]
)
return jnp.sum(logpdf)

initial_position = {"x_1": 0.0, "x_2": 0.0}

num_steps = 50_000
num_samples = 500

optimizer = optax.sgd(1e-2)
mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples)
state = mfvi.init(initial_position)

rng_key = self.key
for _ in range(num_steps):
rng_key, subkey = jax.random.split(rng_key)
state, _ = jax.jit(mfvi.step)(subkey, state)

loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
scale = jax.tree_map(jnp.exp, state.rho)
scale_1, scale_2 = scale["x_1"], scale["x_2"]
self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01)
self.assertAlmostEqual(scale_1, ground_truth[0][1], delta=0.01)
self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01)
self.assertAlmostEqual(scale_2, ground_truth[1][1], delta=0.01)


if __name__ == "__main__":
absltest.main()
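
The test checks the variational parameters directly and never uses the MFVIInfo returned by step. Below is a sketch of monitoring that value during fitting with jax.lax.scan, assuming an arbitrary logdensity_fn, initial position, and optimizer; per the code above, info.elbo stores the Monte Carlo objective that step minimizes (the KL estimate, i.e. the negative of the ELBO up to a constant).

import jax
import optax

import blackjax


def fit(rng_key, logdensity_fn, initial_position, num_steps=10_000):
    optimizer = optax.sgd(1e-2)
    mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples=100)
    state = mfvi.init(initial_position)

    def one_step(state, key):
        state, info = mfvi.step(key, state)
        # info.elbo holds the objective value at this step.
        return state, info.elbo

    keys = jax.random.split(rng_key, num_steps)
    state, objective_trace = jax.lax.scan(one_step, state, keys)
    # A flat or decreasing objective_trace suggests the fit has converged.
    return state, objective_trace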