Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve: Add progress bar to run_inference_algorithm #614

Merged
merged 3 commits into from
Dec 9, 2023

Conversation

PaulScemama
Copy link
Contributor

@PaulScemama PaulScemama commented Dec 9, 2023

Closes #610. Adds progress bar to run_inference_algorithm.

A few important guidelines and requirements before we can merge your PR:

  • If I add a new sampler, there is an issue discussing it already;
  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;
  • If I add a new sampler* I added/updated related examples

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.

Copy link

codecov bot commented Dec 9, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (540db41) 99.18% compared to head (9b4daf7) 99.22%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #614      +/-   ##
==========================================
+ Coverage   99.18%   99.22%   +0.04%     
==========================================
  Files          57       57              
  Lines        2576     2581       +5     
==========================================
+ Hits         2555     2561       +6     
+ Misses         21       20       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a kwarg that default to False, similar to

progress_bar: bool = False,

@PaulScemama
Copy link
Contributor Author

PaulScemama commented Dec 9, 2023

@junpenglao I am trying to figure out why some of the tests are failing -- weird that the details say the test is failing in a file that does not use run_inference_algorithm.

Also, should I adapt some tests to use the progress_bar?

junpenglao
junpenglao previously approved these changes Dec 9, 2023
@junpenglao
Copy link
Member

Looks like there is a flaky test - could you change the rng_key of that test until it pass, and add a comment that it is a flaky test?

@junpenglao
Copy link
Member

@reubenharry found a bug in the function, cause by this line:

initial_state = initial_state_or_position

Could you remove this line and add a test? Something like:

import jax
import jax.numpy as jnp
from blackjax.mcmc.hmc import hmc
from blackjax.util import run_inference_algorithm

def logdensity_fn(x):
    return -0.5 * jnp.sum(jnp.square(x))

alg = hmc(
    logdensity_fn=logdensity_fn,
    inverse_mass_matrix=jnp.eye(2),
    step_size=1.0,
    num_integration_steps=1000,
)

_ = run_inference_algorithm(
    rng_key=jax.random.PRNGKey(0),
    initial_state_or_position=jnp.array([1.0, 1.0]),
    inference_algorithm=alg,
    num_steps=10,
    progress_bar=True)

@PaulScemama
Copy link
Contributor Author

PaulScemama commented Dec 9, 2023

@junpenglao I think you meant to tag me instead of @reubenharry? But yes definitely!, sorry about that. Should I start a new tests file, maybe test_util.py? Or where would you like me to put such a test?

@junpenglao
Copy link
Member

test_util.py sounds good.

@PaulScemama
Copy link
Contributor Author

@junpenglao I have to step away for the rest of the day. Just added a test for run_inference_loop but haven't done anything with the flaky tests yet. I'll be back tomorrow. Thanks for all the help!

@junpenglao
Copy link
Member

The flaky test is not failing now, so I am going to go ahead and merge it. Thanks!

@junpenglao junpenglao merged commit 08e0d75 into blackjax-devs:main Dec 9, 2023
7 checks passed
junpenglao pushed a commit that referenced this pull request Mar 12, 2024
* Add progress bar to run_inference_algorithm

* Add progress_bar as kwarg defaulting to False

* Fix bug in run_inference_algorithm and add test

---------

Co-authored-by: Paul Scemama <pscemama@mitre.org>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add progress bar to run_inference_loop
3 participants