@derifatives indicated that tfp.TruncatedNormal.sample wraps jax.random.truncated_normal, here. (We may be misunderstanding when this function is called.)
However, note that jax.random.truncated_normal can be used to sample from the above truncated normal distribution without clear issues:
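For comparison, here is a minimal sketch of sampling a non-standard truncated normal with `jax.random.truncated_normal` directly. It assumes the usual standardize-then-rescale approach. `jax.random.truncated_normal` draws from a standard normal truncated to `[lower, upper]`, so the bounds are standardized first and the draws are then shifted and scaled back. The helper name `sample_truncated_normal` and the parameter values are hypothetical. They are not the configuration from this report, and this is not necessarily how TFP implements `sample`.

```python
import jax
import jax.numpy as jnp

def sample_truncated_normal(key, loc, scale, low, high, shape=()):
    # Hypothetical helper: standardize the bounds so they apply to N(0, 1).
    lower = (low - loc) / scale
    upper = (high - loc) / scale
    # jax.random.truncated_normal samples N(0, 1) truncated to (lower, upper).
    z = jax.random.truncated_normal(key, lower, upper, shape=shape)
    # Map the standard draws back to the original location and scale.
    return loc + scale * z

key = jax.random.PRNGKey(0)
# Hypothetical parameters for illustration only.
samples = sample_truncated_normal(
    key, loc=0.5, scale=2.0, low=-1.0, high=3.0, shape=(10_000,)
)
print(bool(jnp.isnan(samples).any()), float(samples.min()), float(samples.max()))
```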
(I am finding these strange-seeming combinations of numbers by running a fairly complex probabilistic inference program that samples millions of times from TruncatedNormals, and then filtering the results to find where NaNs were generated.)
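The NaN-hunting workflow described above can be sketched roughly as follows. It assumes a batch of `(loc, scale, low, high)` configurations. One sample is drawn per configuration with `jax.vmap`, and `jnp.isnan` then pulls out the configurations that produced NaN. Both helper names are hypothetical. The third row is deliberately degenerate (`scale = 0`) only so that the filter has something to find. It is not the configuration behind the reported NaN.

```python
import jax
import jax.numpy as jnp

def sample_truncated_normal(key, loc, scale, low, high):
    # Hypothetical helper: standardize bounds, sample N(0, 1), then rescale.
    z = jax.random.truncated_normal(key, (low - loc) / scale, (high - loc) / scale)
    return loc + scale * z

def find_nan_configs(key, params):
    """Draw one sample per row of params and return the rows that produced NaN.

    params: float array of shape (n, 4) with columns (loc, scale, low, high).
    """
    keys = jax.random.split(key, params.shape[0])
    # vmap draws one sample per configuration, each with its own PRNG key.
    samples = jax.vmap(sample_truncated_normal)(
        keys, params[:, 0], params[:, 1], params[:, 2], params[:, 3]
    )
    nan_mask = jnp.isnan(samples)
    # Boolean-mask indexing works here because we are outside jit (concrete arrays).
    return params[nan_mask], samples

params = jnp.array([
    [0.5, 2.0, -1.0, 3.0],
    [0.0, 1.0, -2.0, 2.0],
    [0.0, 0.0, 0.0, 1.0],  # deliberately degenerate (scale = 0) to exercise the filter
])
bad_params, samples = find_nan_configs(jax.random.PRNGKey(0), params)
print(bad_params)
```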
Example:
returns NaN.
JAX version: 0.4.33. TFP version: 0.23.0.