Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding example for custom DataLoader to tutorial 18 #1256

Merged
merged 3 commits into from
Sep 4, 2024

Conversation

psteinb
Copy link
Contributor

@psteinb psteinb commented Sep 3, 2024

What does this implement/fix? Explain your changes

I was happy to see the recent exposure of more control for the training loop. As discussed with @janfb, I added a small example on how rewrite the training loop to exploit such a DataLoader.

Does this close any currently open issues?

I didn't create an issue for this.

Any relevant code examples, logs, error output, etc?

I basically added a section which documents this:

optw = AdamW(list(maf_estimator.parameters()), lr=5e-4)
nepochs = 50

for ep in range(nepochs):
    for idx, (theta_batch, x_batch) in enumerate(train_loader):
        optw.zero_grad()
        losses = maf_estimator.loss(theta_batch, condition=x_batch)
        loss = torch.mean(losses)
        loss.backward()
        optw.step()
    if ep % 10 == 0:
        print("last loss", loss.item())

Any other comments?

Can you please let me know, how to clean a notebook before I commit it? I recall from the hackathon, that there were a couple of steps for this. But maybe I am mistkane.

Checklist

Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • [] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • [ ] New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)
  • For reviewer: The continuous deployment (CD) workflow are passing.

@janfb
Copy link
Contributor

janfb commented Sep 3, 2024

Great, thanks!

Can you make this edit in tutorial and mention the use case with a large dataset, i.e., how this control over the training loop solve the problem with high-D data?
Please also add how you created the maf_estimator using build_maf or so, using only parts of the training data.

Thanks! 🙏

@psteinb
Copy link
Contributor Author

psteinb commented Sep 3, 2024

Yes, will do. So I don't need to clean the notebook before committing it?

- fix import statements
- removed plots
- added section to illustrate how to use a custom DataLoader
@psteinb
Copy link
Contributor Author

psteinb commented Sep 3, 2024

@janfb here is the notebook in question. I removed the plots for the time being. Let me know if I should regenerate them.

@psteinb psteinb marked this pull request as ready for review September 3, 2024 12:49
@psteinb
Copy link
Contributor Author

psteinb commented Sep 3, 2024

ready for review.
@michaeldeistler I fixed some imports that were not matching the 0.23.1 API.

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thanks Peter! Jan, any ideas why this was not caught by our tests that run the notebooks?

@janfb
Copy link
Contributor

janfb commented Sep 3, 2024

I had a quick look and, yikes, there was a testing leftover that restricted the tests to tutorial 05. 🤦‍♂️

I will make a PR.

Thanks for leading us to this @psteinb 😃

@janfb janfb mentioned this pull request Sep 3, 2024
@michaeldeistler
Copy link
Contributor

Let's make a new release after this PR to update the website?

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks a lot @psteinb

Added a couple of comments and suggestions.

tutorials/18_training_interface.ipynb Outdated Show resolved Hide resolved
tutorials/18_training_interface.ipynb Show resolved Hide resolved
tutorials/18_training_interface.ipynb Show resolved Hide resolved
tutorials/18_training_interface.ipynb Show resolved Hide resolved
tutorials/18_training_interface.ipynb Show resolved Hide resolved
tutorials/18_training_interface.ipynb Show resolved Hide resolved
tutorials/18_training_interface.ipynb Show resolved Hide resolved
Co-authored-by: Jan <janfb@users.noreply.github.com>
@janfb janfb mentioned this pull request Sep 4, 2024
Copy link

codecov bot commented Sep 4, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 78.31%. Comparing base (8afd985) to head (1d3a8f5).
Report is 6 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1256      +/-   ##
==========================================
- Coverage   86.05%   78.31%   -7.75%     
==========================================
  Files         118      119       +1     
  Lines        8672     8697      +25     
==========================================
- Hits         7463     6811     -652     
- Misses       1209     1886     +677     
Flag Coverage Δ
unittests 78.31% <ø> (-7.75%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

see 31 files with indirect coverage changes

@janfb
Copy link
Contributor

janfb commented Sep 4, 2024

Awesome, thanks a lot @psteinb for adding this! 🎉

@janfb janfb merged commit 829817b into sbi-dev:main Sep 4, 2024
6 checks passed
@janfb janfb deleted the updates_tutorial_18 branch September 4, 2024 12:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants