Implementing nested sampling #55

Open
renecotyfanboy opened this issue Aug 11, 2024 · 3 comments

Comments

@renecotyfanboy

Hi there,
Thank you very much for putting this package together; it's impressive! I was wondering if you would be interested in an implementation of nested sampling in pure JAX. The jaxns package provides an implementation of the Phantom-Powered nested sampling algorithm, and numpyro already has a compatibility layer for it. I think it would be a nice addition to your collection.

If you are interested, I can try to draft an implementation, although I would probably wrap the numpyro contrib code rather than work with jaxns directly. WDYT?
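For readers who haven't met the algorithm, here is a minimal sketch of basic nested sampling in JAX. It assumes a uniform prior on the unit hypercube and refills the live set by plain rejection sampling from that prior; this is Skilling's original scheme, not the Phantom-Powered algorithm that jaxns implements. `nested_sampling` and its parameters are hypothetical names for illustration only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def nested_sampling(log_likelihood, key, ndim, n_live=200, n_iter=1200):
    """Basic nested sampling with a uniform prior on [0, 1]^ndim (sketch).

    Replacement points come from plain rejection sampling against the prior,
    which is simple but slows down as the likelihood constraint tightens.
    Returns an estimate of log Z, where Z is the integral of L(u) over the cube.
    """
    key, init_key = jax.random.split(key)
    live = jax.random.uniform(init_key, (n_live, ndim))
    live_logl = jax.vmap(log_likelihood)(live)
    dtype = live.dtype

    def replace_point(key, logl_min):
        # Draw from the prior until a point beats the current threshold.
        def cond(state):
            return state[2] <= logl_min

        def body(state):
            key, _, _ = state
            key, sub = jax.random.split(key)
            u = jax.random.uniform(sub, (ndim,), dtype=dtype)
            return key, u, log_likelihood(u)

        init = (key, jnp.zeros(ndim, dtype), jnp.full((), -jnp.inf, dtype))
        _, u, logl = jax.lax.while_loop(cond, body, init)
        return u, logl

    def step(carry, i):
        key, live, live_logl, log_z = carry
        worst = jnp.argmin(live_logl)
        logl_min = live_logl[worst]
        # Prior volume shrinks geometrically, X_i = exp(-i / n_live), so the
        # dead point's weight is X_i - X_{i+1} (computed in log space).
        log_w = -i / n_live + jnp.log1p(-jnp.exp(-1.0 / n_live))
        log_z = jnp.logaddexp(log_z, logl_min + log_w)
        key, sub = jax.random.split(key)
        u, logl = replace_point(sub, logl_min)
        live = live.at[worst].set(u)
        live_logl = live_logl.at[worst].set(logl)
        return (key, live, live_logl, log_z), None

    init = (key, live, live_logl, jnp.full((), -jnp.inf, dtype))
    (_, live, live_logl, log_z), _ = jax.lax.scan(step, init, jnp.arange(n_iter))
    # Remaining live points share the final prior volume X = exp(-n_iter / n_live).
    log_z_live = -n_iter / n_live - jnp.log(n_live) + logsumexp(live_logl)
    return jnp.logaddexp(log_z, log_z_live)


def log_likelihood(u):
    # Gaussian bump at the centre of the square (sigma = 0.1); almost all of its
    # mass lies inside [0, 1]^2, so the true evidence is Z ≈ 1, i.e. log Z ≈ 0.
    return jnp.sum(norm.logpdf(u, loc=0.5, scale=0.1))


log_z = nested_sampling(log_likelihood, jax.random.PRNGKey(0), ndim=2)
print(float(log_z))
```

The expensive part is the refill step: the rejection rate grows as the prior volume shrinks, and replacing it with a smarter constrained sampler is, roughly, where more serious implementations such as jaxns spend their effort.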

@ColCarroll
Collaborator

Hi! I believe nested sampling relies on some structure in the model, i.e., the ability to factor a joint probability distribution into a prior and a likelihood. I think you could do some really interesting things with bayeux if I (we? you?) figured out some pleasant way of incorporating that into the library. In particular, some of the VI routines from Numpyro, or some of the SMC implementations in blackjax or TFP would become feasible.

I am not sure what that would look like! The input to bayeux is a Callable[PyTree[float], float] (where I'm abusing notation to suggest a PyTree whose leaves are floats). If we call that a LogDensity, I guess a "prior" would need to be some sort of PyTree[LogDensity], but maybe the log densities also need to know how to produce samples...
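One hypothetical shape for that structure, sketched in JAX: keep `LogDensity` as a function from a PyTree to a scalar, and represent the prior as a PyTree whose leaves are small objects that can both score and sample one parameter. `PriorLeaf`, `prior_log_prob`, `sample_prior` and `joint_log_density` are illustrative names, not part of bayeux's API.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# bayeux-style input: a function from a PyTree of floats to a scalar.
LogDensity = Callable[[Any], jax.Array]


class PriorLeaf(NamedTuple):
    """Hypothetical prior for one parameter: it can score and sample."""
    log_prob: Callable[[jax.Array], jax.Array]
    sample: Callable[[jax.Array], jax.Array]


def _is_prior(x):
    # NamedTuples are PyTree nodes, so tell tree utilities to stop here.
    return isinstance(x, PriorLeaf)


def prior_log_prob(prior, params):
    # Score each parameter leaf with its matching prior leaf, then sum.
    terms = jax.tree_util.tree_map(
        lambda p, x: p.log_prob(x), prior, params, is_leaf=_is_prior)
    return sum(jax.tree_util.tree_leaves(terms))


def sample_prior(prior, key):
    # One independent key per prior leaf.
    leaves, treedef = jax.tree_util.tree_flatten(prior, is_leaf=_is_prior)
    keys = jax.random.split(key, len(leaves))
    return jax.tree_util.tree_unflatten(
        treedef, [p.sample(k) for p, k in zip(leaves, keys)])


def joint_log_density(prior, log_likelihood) -> LogDensity:
    """Collapse the structured model back to the single LogDensity used today."""
    return lambda params: prior_log_prob(prior, params) + log_likelihood(params)


prior = {
    "mu": PriorLeaf(lambda x: norm.logpdf(x, 0.0, 10.0),
                    lambda k: 10.0 * jax.random.normal(k)),
    "log_sigma": PriorLeaf(lambda x: norm.logpdf(x), lambda k: jax.random.normal(k)),
}
data = jnp.array([1.2, 0.8, 1.5])


def log_likelihood(params):
    return jnp.sum(norm.logpdf(data, params["mu"], jnp.exp(params["log_sigma"])))


log_density = joint_log_density(prior, log_likelihood)
draw = sample_prior(prior, jax.random.PRNGKey(0))
print(log_density(draw))
```

The appeal of this shape is that existing samplers can keep consuming `joint_log_density(...)` unchanged, while prior-aware methods (nested sampling, SMC) could use `sample_prior` and the separate likelihood.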

Anyway, this would need some thought and some design. I'm happy to review a pull request if you have a clear vision, or to look at drafts (or you could make your own, more flexible library that supports structured inference!)

@renecotyfanboy
Author

Hi, apologies for the late answer, I am a bit busy at the moment...

> I believe nested sampling relies on some structure in the model, i.e., the ability to factor a joint probability distribution into a prior and a likelihood.

Even though jaxns model building requires a prior model and a likelihood, the numpyro wrapper uses an identity function as the prior and the posterior log-probability as the likelihood, and this seems to do the trick! The code is a bit convoluted, though, because jaxns requires explicit signatures for its functions.

So for nested sampling, it is not necessary to factor the log-prob function into a prior and a likelihood. I'll draft a PR eventually; just be patient, ahah
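A short derivation of why an identity-style prior can work, assuming the parameters live on (or have been mapped to) the unit hypercube $[0,1]^d$ and writing $\tilde p$ for the unnormalized posterior density. This is a general argument, not a description of what the numpyro wrapper does internally: take the prior to be uniform, $\pi(u) = 1$, and the likelihood to be $\mathcal{L}(u) = \tilde p(u)$.

```latex
Z = \int_{[0,1]^d} \pi(u)\,\mathcal{L}(u)\,du
  = \int_{[0,1]^d} \tilde{p}(u)\,du,
\qquad
\frac{\pi(u)\,\mathcal{L}(u)}{Z} = \frac{\tilde{p}(u)}{Z} = p(u).
```

The evidence reported by the sampler is then just the normalizing constant of $\tilde p$, and the weighted samples target the same posterior. If the log density is written over unconstrained parameters $\theta = T(u)$ instead, $\log \tilde p(u)$ has to include the change-of-variables term $\log\lvert\det \partial T/\partial u\rvert$ for this to hold.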

> In particular, some of the VI routines from Numpyro [...] would become feasible.

Which VI routines are not usable in the current situation? The prior is not readily factorable from a numpyro model anyway, so I would be curious to see those cases.

@ColCarroll
Collaborator

No problem! I'm mostly AFK for a week, but wanted to put some thoughts down:

It's interesting that using an identity function works well! You certainly understand the situation better than I do, and a draft would be welcome. This library has not had many contributors, so I'm sure some of the automation will be rocky, but I'm happy to spend some time getting it to work if you can provide a starting point (and perhaps a Colab of the function working?)

For VI, I'm not a heavy numpyro user, but I was looking at guide generation (https://num.pyro.ai/en/stable/autoguide.html) and thinking it was a little silly to go from, say, a PyMC model to a PyTree and then just guess that every parameter is a Normal. Maybe it would work better than I expect, though! (In particular, I guess this is mean-field VI?)
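For concreteness, this is roughly what "every parameter is a Normal" amounts to when all you have is a log density: mean-field Gaussian VI, where each coordinate of a flat parameter vector gets an independent Normal with its own location and scale, fitted by maximizing a reparameterized Monte Carlo estimate of the ELBO. This is a minimal sketch under those assumptions; `fit_mean_field` is a hypothetical helper, plain gradient ascent stands in for a real optimizer, and it is not numpyro's `AutoNormal` implementation.

```python
import jax
import jax.numpy as jnp


def fit_mean_field(log_density, dim, key, num_steps=2000, num_samples=64, lr=0.01):
    """Mean-field Gaussian VI on a flat parameter vector of length `dim`."""
    params = {"loc": jnp.zeros(dim), "log_scale": jnp.zeros(dim)}

    def neg_elbo(params, key):
        eps = jax.random.normal(key, (num_samples, dim))
        scale = jnp.exp(params["log_scale"])
        # Reparameterization trick: z = loc + scale * eps, so gradients flow
        # through the samples to loc and log_scale.
        z = params["loc"] + scale * eps
        expected_log_p = jnp.mean(jax.vmap(log_density)(z))
        # Entropy of a diagonal Gaussian, up to an additive constant.
        entropy = jnp.sum(params["log_scale"])
        return -(expected_log_p + entropy)

    @jax.jit
    def step(params, key):
        grads = jax.grad(neg_elbo)(params, key)
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

    for _ in range(num_steps):
        key, sub = jax.random.split(key)
        params = step(params, sub)
    return params


target_loc = jnp.array([3.0, -1.0])
target_scale = jnp.array([2.0, 0.5])


def log_density(z):
    # Unnormalized log density of an independent Gaussian target.
    return -0.5 * jnp.sum(((z - target_loc) / target_scale) ** 2)


params = fit_mean_field(log_density, dim=2, key=jax.random.PRNGKey(0))
print(params["loc"], jnp.exp(params["log_scale"]))
```

Because this toy target is itself a diagonal Gaussian, the mean-field family contains it and the fit recovers it; for a correlated or multimodal posterior the independent Normals can only approximate it, which is the concern raised above.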

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants