Sampling from TruncatedNormal can yield NaN #1844

Open
georgematheos opened this issue Sep 30, 2024 · 2 comments

Comments

@georgematheos

Example:

import jax
from tensorflow_probability.substrates import jax as tfp

tfp.distributions.TruncatedNormal(
    loc=0.5382424, scale=0.05, low=0.80921564, high=0.86921564
).sample(seed=jax.random.PRNGKey(2))

returns NaN.

JAX version: 0.4.33. TFP version: 0.23.0.
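One possible contributor, offered as a hypothesis and not a confirmed diagnosis: with these parameters the standardized bounds are roughly 5.42 and 6.62 standard deviations above loc. At both points the normal CDF lies within about 3e-8 of 1.0, while float32 (JAX's default dtype unless 64-bit mode is enabled) has a spacing of 2^-24 (about 6e-8) just below 1.0. Any inverse-CDF style sampler that works with Φ(a) and Φ(b) directly therefore sees an interval whose probability mass is essentially zero or badly wrong. The stdlib-only sketch below (the helper names `to_float32` and `std_normal_cdf` are illustrative, not part of TFP or JAX) computes that mass accurately in float64 and compares it with the best float32 can do, namely exact CDF values rounded once to float32:

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def std_normal_cdf(z: float) -> float:
    """Standard normal CDF Phi(z), evaluated in float64."""
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


# Parameters from the example above: loc, scale, low, high.
loc, scale, low, high = 0.5382424, 0.05, 0.80921564, 0.86921564
a = (low - loc) / scale   # standardized lower bound, about 5.42
b = (high - loc) / scale  # standardized upper bound, about 6.62

# Probability mass of [a, b], computed accurately as a difference of upper tails.
mass64 = 0.5 * (math.erfc(a / math.sqrt(2.0)) - math.erfc(b / math.sqrt(2.0)))

# Best case for float32: Phi evaluated exactly, then rounded once to float32.
mass32 = to_float32(std_normal_cdf(b)) - to_float32(std_normal_cdf(a))

rel_err = abs(mass32 - mass64) / mass64
print(f"a = {a:.4f}, b = {b:.4f}")
print(f"interval mass, float64: {mass64:.3e}")
print(f"interval mass, float32: {mass32:.3e} (relative error {rel_err:.2f})")
```

A relative error near 1 means the float32 CDF values carry essentially no information about where inside [low, high] a sample should land.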

@georgematheos
Author

georgematheos commented Sep 30, 2024

@derifatives indicated that tfp.distributions.TruncatedNormal.sample wraps jax.random.truncated_normal. (We may be misunderstanding when that function is called.)

However, jax.random.truncated_normal can be called directly on the standardized bounds to sample from the same truncated normal distribution, without obvious problems:

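import jax

mean, std, minval, maxval = 0.5382424, 0.05, 0.80921564, 0.86921564
# Standardize the bounds so we can sample a truncated *standard* normal.
minval_centered, maxval_centered = (minval - mean) / std, (maxval - mean) / std
centered_sample = jax.random.truncated_normal(jax.random.PRNGKey(2), minval_centered, maxval_centered)
# Map the standard-normal sample back to the original location and scale.
sample = centered_sample * std + mean
sample # = 0.80921566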

@georgematheos
Author

However, this does not always work; with other parameters the direct call also returns NaN:

import jax

mean, std, minval, maxval = 0.09121108, 0.1, 0.62490195, 0.6849019
minval_centered, maxval_centered = (minval - mean) / std, (maxval - mean) / std
centered_sample = jax.random.truncated_normal(jax.random.PRNGKey(2), minval_centered, maxval_centered)
sample = centered_sample * std + mean
sample # = NaN

(I found these odd-looking parameter combinations by running a fairly complex probabilistic inference program that samples millions of times from TruncatedNormal distributions, then filtering the results for the cases that produced NaN.)
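For comparison, here is a hypothetical pure-Python (float64) inverse-CDF sampler that stays accurate when both bounds sit far out in one tail. Instead of Φ(z), which rounds toward 1.0 there, it works with the upper-tail probability Q(z) = 1 - Φ(z) computed via `math.erfc`, and inverts it with `statistics.NormalDist.inv_cdf` using Q(z) = Φ(-z). This is a sketch of a standard tail-inversion technique, not what TFP or JAX actually implement; `upper_tail` and `sample_truncated_normal` are illustrative names, and accuracy is assumed to degrade once the tail masses underflow (standardized bounds beyond roughly 37).

```python
import math
import random
import sys
from statistics import NormalDist

_STD_NORMAL = NormalDist(0.0, 1.0)


def upper_tail(z: float) -> float:
    """Q(z) = 1 - Phi(z) via erfc, which keeps full relative precision for large z."""
    return 0.5 * math.erfc(z / math.sqrt(2.0))


def sample_truncated_normal(loc: float, scale: float, low: float, high: float,
                            rng: random.Random) -> float:
    """Hypothetical helper: one inverse-CDF draw from Normal(loc, scale) on [low, high].

    Assumes scale > 0 and low < high.
    """
    a = (low - loc) / scale   # standardized lower bound
    b = (high - loc) / scale  # standardized upper bound
    if a >= 0.0:
        # Upper tail: Q decreases from Q(a) to Q(b), and Q(z) = Phi(-z),
        # so z = -Phi^{-1}(u) for u drawn between Q(b) and Q(a).
        u = rng.uniform(upper_tail(b), upper_tail(a))
        sign = -1.0
    elif b <= 0.0:
        # Lower tail: Phi(z) = Q(-z) is tiny but stored with full relative precision.
        u = rng.uniform(upper_tail(-a), upper_tail(-b))
        sign = 1.0
    else:
        # Interval contains the mean: plain Phi and its inverse are well conditioned.
        u = rng.uniform(_STD_NORMAL.cdf(a), _STD_NORMAL.cdf(b))
        sign = 1.0
    # inv_cdf requires 0 < u < 1; clamp away endpoints that rounding can produce.
    u = min(max(u, sys.float_info.min), 1.0 - sys.float_info.epsilon)
    z = sign * _STD_NORMAL.inv_cdf(u)
    # Clamp in the original units so rounding cannot push the result outside [low, high].
    return min(max(loc + scale * z, low), high)


rng = random.Random(2)
x = sample_truncated_normal(0.09121108, 0.1, 0.62490195, 0.6849019, rng)
print(x)
```

With the second set of parameters from this thread, the call above returns a finite value inside [low, high]. The design choice is simply to invert whichever tail keeps the relevant probabilities away from 1.0, so their relative precision is not lost to rounding.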
