Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tutorial for training interface #983

Merged
merged 5 commits into from
Aug 23, 2024
Merged

Add tutorial for training interface #983

merged 5 commits into from
Aug 23, 2024

Conversation

michaeldeistler
Copy link
Contributor

@michaeldeistler michaeldeistler commented Mar 11, 2024

Below is the training interface for NLE:

from sbi.neural_nets.flow import build_nsf
from sbi.inference.posteriors import MCMCPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential


# Build neural density estimator.
density_estimator = build_nsf(x, theta)

# Training loop.
opt = Adam(list(density_estimator.parameters()), lr=5e-4)
for _ in range(100):
    opt.zero_grad()
    losses = density_estimator.loss(x, condition=theta)
    loss = torch.mean(losses)
    loss.backward()
    opt.step()

# Build posterior and sample with MCMC.
potential, tf = likelihood_estimator_based_potential(density_estimator, prior, x_o)
posterior = MCMCPosterior(
    potential,
    proposal=prior,
    theta_transform=tf,
    num_chains=100,
    thin=1,
    method="slice_np_vectorized"
)
samples = posterior.sample((1000,), x=x_o)
_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3))

Copy link

codecov bot commented Mar 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 76.87%. Comparing base (b3254ed) to head (827df2a).
Report is 11 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #983      +/-   ##
==========================================
- Coverage   85.44%   76.87%   -8.57%     
==========================================
  Files         101      101              
  Lines        7941     7945       +4     
==========================================
- Hits         6785     6108     -677     
- Misses       1156     1837     +681     
Flag Coverage Δ
unittests 76.87% <ø> (-8.57%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

see 25 files with indirect coverage changes

@Kojobu

This comment was marked as resolved.

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed small fixes directly. Thanks for adding this, looks great!

tutorials/02_training_interface.ipynb Outdated Show resolved Hide resolved
@janfb janfb linked an issue Aug 23, 2024 that may be closed by this pull request
@janfb janfb merged commit 67b2038 into main Aug 23, 2024
4 checks passed
@janfb janfb deleted the train branch August 23, 2024 16:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Training interface
3 participants