-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
define a loop-free untrue batching rule for rng_bit_generator
#20094
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
froystig
force-pushed
the
vmap-rbg
branch
4 times, most recently
from
March 6, 2024 22:43
f89b528
to
2cf3d97
Compare
mattjj
approved these changes
Mar 7, 2024
froystig
force-pushed
the
vmap-rbg
branch
4 times, most recently
from
March 8, 2024 19:02
d30329f
to
b045eff
Compare
This solves my issue with per-example independent sampling (before it was 100X slower compared to dependent sampling), thanks for the fix! |
Merged
ruomingp
added a commit
to ruomingp/axlearn
that referenced
this pull request
May 10, 2024
jax-ml/jax#20094 changes the behavior of RNG in vmap, so we can no longer rely on identical layer param initialization when using vmap vs. not. This affects RepeatedTransformerLayer and fused QKV layers. The fix is to convert layer params from the reference layer to the test layer instead of relying on identical initialization.
github-merge-queue bot
pushed a commit
to apple/axlearn
that referenced
this pull request
May 11, 2024
* Upgrades jax from 0.4.25 to 0.4.27. * Fixes attention_test. jax-ml/jax#20094 changes the behavior of RNG in vmap, so we can no longer rely on identical layer param initialization when using vmap vs. not. This affects RepeatedTransformerLayer and fused QKV layers. The fix is to convert layer params from the reference layer to the test layer instead of relying on identical initialization. * Fixes rnn_test. * Fixes test_split_prng_key. * Fixes test_split_prng_key. * Fixes test_parent_children. * Upgrades to jax 0.4.28.
devcat722
added a commit
to devcat722/axlearn
that referenced
this pull request
Sep 6, 2024
* Upgrades jax from 0.4.25 to 0.4.27. * Fixes attention_test. jax-ml/jax#20094 changes the behavior of RNG in vmap, so we can no longer rely on identical layer param initialization when using vmap vs. not. This affects RepeatedTransformerLayer and fused QKV layers. The fix is to convert layer params from the reference layer to the test layer instead of relying on identical initialization. * Fixes rnn_test. * Fixes test_split_prng_key. * Fixes test_split_prng_key. * Fixes test_parent_children. * Upgrades to jax 0.4.28.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
fixes #19085, #16792
See #19085 for details.
Note that this is a random-bits-altering change, though only for vmapped random generation.