Patch RNG for better memory utilization in dropout layers #5710
Do we want to add a test to check that int8 is used for the RNG? I think we can check the generated HLO.
Hey Jack, yeah, good idea. Will add before merging.
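One way to approach the check discussed above is to scan the HLO text for rng-bit-generator instructions and assert on the element type of the random bits they produce. The following is a minimal sketch, not the test added in this PR: the helper name rng_output_dtypes and the sample HLO string are hypothetical, hand-written for illustration. In a real test the text would come from dumping the HLO of the traced graph.

```python
import re

# Hypothetical, hand-written HLO snippet for illustration only; it is not
# output captured from this PR.
SAMPLE_HLO = """
%rng = (u64[2]{0}, u8[128,256]{1,0}) rng-bit-generator(u64[2]{0} %state), algorithm=rng_default
"""


def rng_output_dtypes(hlo_text):
    """Return the element type of the random bits for each rng-bit-generator op."""
    dtypes = []
    for line in hlo_text.splitlines():
        if "rng-bit-generator(" not in line:
            continue
        # The result shape sits between the first '=' and the op name.
        result = line.split("=", 1)[1].split("rng-bit-generator(", 1)[0]
        # Collect array element types such as u64[2] or u8[128,256].
        found = re.findall(r"\b([a-z]+\d+)\[", result)
        # Assumption: the last array in the result holds the random bits,
        # preceded by the updated RNG state.
        if found:
            dtypes.append(found[-1])
    return dtypes


if __name__ == "__main__":
    print(rng_output_dtypes(SAMPLE_HLO))
```

A test built on this idea would assert that every returned type is u8 when the downcast option is enabled.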
Verified locally with the new unit test -- cc @JackCaoG
* Enable u8 rng-bit-generator with downcast
* Use BF16 values if downcast for uniform dist is set.
The patch has landed in OpenXLA (openxla/xla#6015), so we can enable this now to reduce memory pressure in the nn.Dropout layer.
DO_NOT_MERGE until the next pin update.
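To make the memory argument concrete: a dropout mask needs one random value per activation element, so drawing those values as 8-bit rather than wider integers shrinks the temporary buffer proportionally. The sketch below is a plain-Python illustration, not the XLA lowering. The names rng_buffer_bytes and dropout_mask_from_u8 are hypothetical, and the 32-bit baseline is an assumption made for the sake of comparison.

```python
def rng_buffer_bytes(num_elements, bits_per_element):
    """Bytes needed to hold one random integer per element."""
    return num_elements * bits_per_element // 8


def dropout_mask_from_u8(random_bytes, p):
    """Keep an element when its random byte is at or above the drop threshold.

    With 8-bit randomness the drop probability p is quantized to
    multiples of 1/256.
    """
    threshold = round(p * 256)
    return [b >= threshold for b in random_bytes]


if __name__ == "__main__":
    n = 1024 * 1024  # number of activation elements
    print("assumed 32-bit baseline:", rng_buffer_bytes(n, 32), "bytes")
    print("8-bit random bits:      ", rng_buffer_bytes(n, 8), "bytes")
    # Every possible byte value once, so exactly half are kept at p = 0.5.
    mask = dropout_mask_from_u8(bytes(range(256)), 0.5)
    print("kept:", sum(mask), "of", len(mask))
```

The trade-off is coarser granularity in the drop probability, which is usually acceptable for dropout rates such as 0.1 or 0.5.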