Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the function for generating Gaussian noise #377

Merged
merged 5 commits into from
Oct 18, 2022

Conversation

junpenglao
Copy link
Member

@junpenglao junpenglao commented Oct 10, 2022

... that is the same PyTree structure as the input.

relate to #376

@junpenglao junpenglao marked this pull request as ready for review October 11, 2022 05:07
... that is the same PyTree structure as the input.

relate to blackjax-devs#376
Fix formatting.
one more formatting
@codecov
Copy link

codecov bot commented Oct 13, 2022

Codecov Report

Merging #377 (5e49c8f) into main (f0a3922) will increase coverage by 0.03%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #377      +/-   ##
==========================================
+ Coverage   89.75%   89.79%   +0.03%     
==========================================
  Files          44       45       +1     
  Lines        2197     2194       -3     
==========================================
- Hits         1972     1970       -2     
+ Misses        225      224       -1     
Impacted Files Coverage Δ
blackjax/adaptation/window_adaptation.py 100.00% <100.00%> (ø)
blackjax/mcmc/diffusion.py 100.00% <100.00%> (ø)
blackjax/mcmc/elliptical_slice.py 94.91% <100.00%> (+1.36%) ⬆️
blackjax/mcmc/ghmc.py 100.00% <100.00%> (ø)
blackjax/mcmc/mala.py 100.00% <100.00%> (ø)
blackjax/mcmc/metrics.py 100.00% <100.00%> (ø)
blackjax/mcmc/rmh.py 96.00% <100.00%> (-0.56%) ⬇️
blackjax/sgmcmc/diffusion.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/sghmc.py 100.00% <100.00%> (ø)
blackjax/util.py 100.00% <100.00%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@junpenglao junpenglao requested a review from rlouf October 18, 2022 13:02
@rlouf rlouf enabled auto-merge (squash) October 18, 2022 14:04
Copy link
Member

@rlouf rlouf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -29,7 +22,7 @@ def one_step(
) -> SGLDState:

step, position, logprob_grad = state
momentum = sample_momentum(rng_key, position, step_size)
momentum = generate_gaussian_noise(rng_key, position, jnp.sqrt(step_size))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to somehow keep the sample_momentum mention here; noise is introduced somewhere else in the SgHMC algorithm.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think adding a code comment is sufficient.


# TODO(https://github.com/blackjax-devs/blackjax/issues/376)
# Refactor this function to not use ravel_pytree might be more performant.
def generate_gaussian_noise(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can give a better name than gaussian_noise (which I assume you took from the SgMCMC algorithms)? Like random_normal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to random_normal

from blackjax.types import Array, PRNGKey, PyTree


@partial(jit, static_argnames=("precision",), inline=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we rename this file pytrees.py which I find more informative that util so you would call pytrees.random_normal for instance ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about pytree_util.py?

@rlouf rlouf merged commit 21ab5c0 into blackjax-devs:main Oct 18, 2022
@junpenglao junpenglao deleted the flatten_refactor branch June 16, 2023 07:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants