-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support passing jax.random.PRNGKey
inputs in jax
#651
Conversation
if isinstance(seed, jax.Array): | ||
return seed | ||
else: | ||
return draw_seed(seed) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we always convert the seed to a jax PRNGKey in case the backend is JAX? I am just thinking about the advantage of the passing any other seed type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are no particular performance advantages (or any other advantages) to passing a PRNGKey over a JAX array. The PRNGKey class is syntactic sugar.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the confusion. I was referring to the other types (else branch in this case), like a single integer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yeah the single integers are probably getting broadcast from x -> [x, x]
, which isn't obviously any better or worse than x -> [0, x]
which is what PRNGKey
seems to do in practice, though it might do something different with non-default rng implementations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! Please add a unit test for this use case.
added a unit test and made sure it passed |
result will already be an rng
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good -- I'll take it from here.
No description provided.