How to do JAX mixed precision properly? #25434
Unanswered
FirstQuadrantSam asked this question in Q&A

Hi all!
I am trying to use mixed precision (say, BF16 on a V100) to accelerate my training. I did something like this:
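In outline, something like the following, where loss_fn, params, and batch are placeholders for the real model and data (a minimal sketch, not the actual training code):

```python
import jax
import jax.numpy as jnp

def to_bf16(tree):
    # Cast every array in a pytree down to bfloat16.
    return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), tree)

def loss_fn(params, batch):
    # Placeholder for the real model's loss.
    preds = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((preds - batch["y"]) ** 2)

@jax.jit
def grad_step(params, batch):
    # Differentiate the loss with everything cast to bfloat16.
    return jax.grad(loss_fn)(to_bf16(params), to_bf16(batch))
```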
and then passed the computed gradients on to the optimizer. It turns out that this actually gives a slowdown compared to the full-precision version. Note that I used jax.block_until_ready and timed only the gradient calculation. Then I tried the following:
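Roughly this, as a standalone microbenchmark on random inputs (the shapes and the function f here are illustrative, not the real model):

```python
import time
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4096, 4096), dtype=jnp.bfloat16)

def f(a):
    return jnp.sum(a @ a)  # a bfloat16 matmul reduced to a scalar

grad_f = jax.jit(jax.grad(f))
grad_f(x).block_until_ready()  # warm-up call so compilation is not timed

t0 = time.perf_counter()
grad_f(x).block_until_ready()  # time only the gradient computation
print(f"grad time: {time.perf_counter() - t0:.4f} s")
```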
and it works well on random data, giving the desired 2x speedup. Any ideas or suggestions about what could be wrong?

Replies: 1 comment
-
Not a full answer, but this depends on a number of choices / conventions. Some operations need to be kept in float32, such as the attention softmax.
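A sketch of that pattern (illustrative; the attention_probs helper is made up for the example): keep the matmul in bfloat16 but upcast the logits to float32 for the softmax itself.

```python
import jax
import jax.numpy as jnp

def attention_probs(q, k):
    # q and k stay in bfloat16; only the numerically sensitive
    # softmax runs in float32.
    logits = (q @ k.T) * (q.shape[-1] ** -0.5)
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
    return probs.astype(q.dtype)  # cast back for the bfloat16 value matmul
```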