
Add Mean Field Variational Inference implementation #433

Merged · 1 commit · Jan 14, 2023

Conversation

@xidulu (Contributor) commented Dec 28, 2022

Some initial attempts at integrating MFVI into blackjax (#397).

TODO:
There is going to be a lot of boilerplate code shared between MFVI and full-rank VI.

It is also worth considering how to add the sticking-the-landing (STL) gradient estimator [1] and importance-weighted ELBO optimization [2].

[1] Roeder, Geoffrey, Yuhuai Wu, and David K. Duvenaud. "Sticking the landing: Simple, lower-variance gradient estimators for variational inference." Advances in Neural Information Processing Systems 30 (2017).

[2] Domke, Justin, and Daniel R. Sheldon. "Importance weighting and variational inference." Advances in Neural Information Processing Systems 31 (2018).
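For context, here is a minimal sketch of the kind of reparameterized ELBO estimator under discussion. The names (`elbo_estimate`, `mu`, `rho`) are illustrative, not this PR's actual code; the inline comment shows where the STL variant [1] would differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def elbo_estimate(rng_key, mu, rho, logdensity_fn, num_samples=5):
    """Monte Carlo ELBO for a diagonal-Gaussian variational family,
    using the reparameterization trick: z = mu + sigma * eps."""
    sigma = jax.nn.softplus(rho)  # map unconstrained rho to a positive scale
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    z = mu + sigma * eps
    # log q(z) under the mean-field Gaussian, summed over dimensions.
    logq = norm.logpdf(z, mu, sigma).sum(axis=-1)
    # Sticking-the-landing [1] would instead evaluate log q with
    # stop-gradient on the variational parameters:
    #   logq = norm.logpdf(z, jax.lax.stop_gradient(mu),
    #                      jax.lax.stop_gradient(sigma)).sum(axis=-1)
    logp = jax.vmap(logdensity_fn)(z)
    return jnp.mean(logp - logq)
```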

@rlouf marked this pull request as draft December 28, 2022 09:26
@rlouf changed the title from "[Draft] MFVI" to "Add Mean Field Variational Inference implementation" Dec 28, 2022
@junpenglao self-requested a review December 28, 2022 15:10
blackjax/base.py (outdated review thread, resolved)
```python
    return meanfield_logprob


def sample(rng_key, meanfield_param, num_samples: int):
```

@junpenglao (Member) commented:

Maybe you can call `generate_gaussian_noise` internally?

@xidulu (Contributor, author) replied:

Sounds good. But at this moment there's no internal function we can call to sample from a multivariate Gaussian, right (i.e., for the full-rank VI case)?

@junpenglao (Member) replied:

You can use the same function for the multivariate Gaussian because the `linear_map` util will dispatch it correctly.
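To illustrate the dispatch being described, here is a toy sketch; `scale_noise` is a hypothetical stand-in, not blackjax's actual `linear_map` code:

```python
import jax.numpy as jnp

def scale_noise(sigma, eps):
    # A 1-D sigma acts as a diagonal scale (mean-field VI), while a
    # 2-D sigma acts as a dense factor, e.g. a Cholesky factor of the
    # covariance (full-rank VI).
    if sigma.ndim == 1:
        return sigma * eps
    return jnp.dot(sigma, eps)
```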

@xidulu (Contributor, author) replied:

@junpenglao But how do you generate `num_samples` particles using `generate_gaussian_noise`?

@junpenglao (Member) replied:

You are right, the util function does not accept that yet. Let's keep your version here, but could you add a TODO?

# TODO: switch to using `generate_gaussian_noise` in util.py
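Once the util supports batched draws, the TODO could be resolved along these lines. This is a sketch only: it assumes `generate_gaussian_noise(rng_key, position, mu, sigma)` draws a single reparameterized sample shaped like `position`, which this thread does not guarantee.

```python
import jax
from blackjax.util import generate_gaussian_noise  # assumed location per the TODO

def sample(rng_key, meanfield_param, num_samples: int):
    mu, sigma = meanfield_param  # illustrative unpacking
    keys = jax.random.split(rng_key, num_samples)
    # One key per particle; vmap batches the single-sample helper.
    return jax.vmap(lambda k: generate_gaussian_noise(k, mu, mu, sigma))(keys)
```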

@rlouf (Member) left a comment:

Thanks for contributing! That's a very good start; we'll just need to reorganise a few things.

blackjax/vi/mfvi.py (outdated review thread, resolved)


```python
def approximate(
    rng, init_params, log_prob_fn, optimizer, sample_size=5, num_steps=200
```

@rlouf (Member) commented:

We need a default choice for the optimiser; does this assume we're using Optax?

@xidulu (Contributor, author) replied:

The usual practice is to use Adam, which is available in both Optax and Jaxopt. I will hold off here until you and @junpenglao decide which optimization library to use for VI.
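If Optax were chosen, a default Adam setup would look roughly like this; `optax.adam` is standard Optax API, and the `approximate` call simply mirrors the signature shown above, with `rng`, `init_params`, and `log_prob_fn` as placeholders:

```python
import optax

# Adam with a typical learning rate; the conventional default for black-box VI.
optimizer = optax.adam(learning_rate=1e-2)
result = approximate(rng, init_params, log_prob_fn, optimizer)
```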

blackjax/vi/mfvi.py (outdated review thread, resolved)

```python
import jax.scipy.stats as stats
import numpy as np
import blackjax
import optax
```
A maintainer (Member) commented:

I don't think optax is currently listed as a dependency, but jaxopt is. We need to discuss the pros and cons of adding a new dependency.

Another maintainer (Member) replied:

I thought optax was already used in SGMCMC? I think it is fine to add it as a dependency.

blackjax/kernels.py (review thread, resolved)
tests/test_mfvi.py (outdated review thread, resolved)
blackjax/vi/mfvi.py (outdated review thread, resolved)
tests/test_mfvi.py (outdated review thread, resolved)
blackjax/vi/mfvi.py (outdated review thread, resolved)
@rlouf (Member) commented Dec 29, 2022

I think that's enough comments for now :) Don't be discouraged, this is completely normal: we are still figuring out what the interface should be, and there is no template yet for VI as there is for MCMC algorithms.

@xidulu (Contributor, author) commented Dec 29, 2022

@rlouf @junpenglao Would you mind taking another look at the API?

I think the next thing we need to figure out is how to avoid boilerplate code, since the ELBO computation and the optimization loop are identical for most VI variants.

@rlouf (Member) commented Dec 30, 2022

> I think the next thing we need to figure out is how to avoid boilerplate code, since the ELBO computation and the optimization loop are identical for most VI variants.

I prefer to avoid generalising on the first implementation, even though it does seem like we will need an equivalent of base.py, as with the SMC algorithm.

I'm taking some time off and won't be able to take a good look at the PR until Jan 3, but will have plenty of time that day to go over it in depth.

@xidulu (Contributor, author) commented Jan 6, 2023

@rlouf Well, apparently I need better documentation and more detailed test cases. Other than that, do you have any further comments on the API design?

Also, I am thinking about moving forward with implementing full-rank VI. Do you think I should open a new PR, or should I just implement it in this one?

@rlouf (Member) commented Jan 6, 2023

I am making a pass on your implementation at the moment; you'll need to pull the changes once I am finished (`pull --rebase`), but I'm not quite done yet.

I now think your original proposal of implementing VI as a "kernel" was the right way to go, and we will probably need to "kernelize" (we need a better name) Pathfinder as well. I think it's nice that the user has fine-grained control over the optimization.

@rlouf (Member) commented Jan 6, 2023

I rearranged things a little; it was mostly cosmetic. I think the next step is to give it a kernel-like API, tests, and docs. If the full-rank version is going to share a lot of code, it's probably best to do it now rather than go through the hassle of opening a new PR.

@codecov bot commented Jan 6, 2023

Codecov Report

Merging #433 (c69e7f2) into main (27d981e) will decrease coverage by 0.07%.
The diff coverage is 97.22%.

@@            Coverage Diff             @@
##             main     #433      +/-   ##
==========================================
- Coverage   99.25%   99.17%   -0.08%     
==========================================
  Files          47       48       +1     
  Lines        1872     1938      +66     
==========================================
+ Hits         1858     1922      +64     
- Misses         14       16       +2     
| Impacted Files | Coverage | Δ |
| --- | --- | --- |
| blackjax/__init__.py | 100.00% <ø> | (ø) |
| blackjax/kernels.py | 99.20% <94.73%> | (-0.38%) ⬇️ |
| blackjax/vi/meanfield_vi.py | 97.95% <97.95%> | (ø) |
| blackjax/base.py | 100.00% <100.00%> | (ø) |
| blackjax/vi/__init__.py | 100.00% <100.00%> | (ø) |


@rlouf force-pushed the vi_dev branch 2 times, most recently from 068a0b5 to dd6438a, January 12, 2023 10:25
@rlouf marked this pull request as ready for review January 12, 2023 10:25
@rlouf (Member) commented Jan 12, 2023

I turned the algorithm into a "kernel" like you originally suggested; I think it is much more in the general spirit of Blackjax and gives the user more freedom. This is almost ready to merge: we need to expand the docstring of the step function a little more and add references.
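For a sense of what the kernel-like interface enables, here is a hypothetical usage sketch. The name `blackjax.meanfield_vi` matches the `blackjax/vi/meanfield_vi.py` and `blackjax/kernels.py` files touched by this PR, but the exact `init`/`step`/`sample` signatures are assumptions; `logdensity_fn` and `initial_position` are placeholders:

```python
import jax
import optax
import blackjax

optimizer = optax.adam(1e-2)
mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples=5)

state = mfvi.init(initial_position)
rng_key = jax.random.PRNGKey(0)
for _ in range(200):
    rng_key, subkey = jax.random.split(rng_key)
    state, info = mfvi.step(subkey, state)  # one stochastic ELBO gradient step

# Draw approximate posterior samples from the fitted variational family.
posterior_samples = mfvi.sample(rng_key, state, num_samples=1000)
```

The appeal of this design, per the discussion above, is that the user keeps fine-grained control over the optimization loop instead of calling a single opaque `approximate` routine.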

@xidulu (Contributor, author) commented Jan 13, 2023

@rlouf Thanks! I can take care of some of the documentation work. See you tomorrow in the meeting!

blackjax/base.py (outdated review thread, resolved)
@rlouf (Member) commented Jan 13, 2023

This looks great! I think it is ready to merge after we replace `logprob` with `logdensity` 😊

@xidulu (Contributor, author) commented Jan 13, 2023

@rlouf Got it, one second.

@xidulu (Contributor, author) commented Jan 13, 2023

@rlouf I just made some changes; is that what you suggested?

@rlouf (Member) commented Jan 13, 2023

I meant `logdensity` instead of `logdensity_fn`, sorry. Then it'd be great if you could squash your commits. Let me know, and we'll merge if the tests pass... and move on to full rank :)

Commits:
- add some doc
- remove extra tuple in base.py
- change log prob to logdensity
- remove fn from log density
@xidulu (Contributor, author) commented Jan 13, 2023

@rlouf Done

@rlouf (Member) commented Jan 14, 2023

LGTM, thank you for contributing!

@rlouf merged commit b6807b2 into blackjax-devs:main Jan 14, 2023
@rlouf linked an issue Jan 14, 2023 that may be closed by this pull request
Development

Successfully merging this pull request may close these issues:

- Adding some basic VI approximation and fitting routine