Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support passing jax.random.PRNGKey inputs in jax #651

Merged
merged 4 commits into from
Aug 2, 2023

Conversation

GallagherCommaJack
Copy link
Contributor

No description provided.

if isinstance(seed, jax.Array):
return seed
else:
return draw_seed(seed)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we always convert the seed to a jax PRNGKey in case the backend is JAX? I am just thinking about the advantage of the passing any other seed type

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no particular performance advantages (or any other advantages) to passing a PRNGKey over a JAX array. The PRNGKey class is syntactic sugar.

Copy link
Collaborator

@AakashKumarNain AakashKumarNain Aug 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion. I was referring to the other types (else branch in this case), like a single integer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah the single integers are probably getting broadcast from x -> [x, x], which isn't obviously any better or worse than x -> [0, x] which is what PRNGKey seems to do in practice, though it might do something different with non-default rng implementations?

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Please add a unit test for this use case.

@GallagherCommaJack
Copy link
Contributor Author

added a unit test and made sure it passed

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good -- I'll take it from here.

@fchollet fchollet merged commit 701aea8 into keras-team:main Aug 2, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants