-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor the function for generating Gaussian noise #377
Conversation
... that is the same PyTree structure as the input. relate to blackjax-devs#376
Fix formatting.
one more formatting
d4e0dae
to
5e49c8f
Compare
Codecov Report
@@ Coverage Diff @@
## main #377 +/- ##
==========================================
+ Coverage 89.75% 89.79% +0.03%
==========================================
Files 44 45 +1
Lines 2197 2194 -3
==========================================
- Hits 1972 1970 -2
+ Misses 225 224 -1
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -29,7 +22,7 @@ def one_step( | |||
) -> SGLDState: | |||
|
|||
step, position, logprob_grad = state | |||
momentum = sample_momentum(rng_key, position, step_size) | |||
momentum = generate_gaussian_noise(rng_key, position, jnp.sqrt(step_size)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to somehow keep the sample_momentum
mention here; noise is introduced somewhere else in the SgHMC algorithm.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think adding a code comment is sufficient.
|
||
# TODO(https://github.com/blackjax-devs/blackjax/issues/376) | ||
# Refactor this function to not use ravel_pytree might be more performant. | ||
def generate_gaussian_noise( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can give a better name than gaussian_noise
(which I assume you took from the SgMCMC algorithms)? Like random_normal
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to random_normal
from blackjax.types import Array, PRNGKey, PyTree | ||
|
||
|
||
@partial(jit, static_argnames=("precision",), inline=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we rename this file pytrees.py
which I find more informative that util
so you would call pytrees.random_normal
for instance ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about pytree_util.py?
... that is the same PyTree structure as the input.
relate to #376