Change the default matmul precision in JAX to highest precision. #7859
Conversation
On CPU and GPU, this change has no effect. On TPU, this PR changes the default matmul algorithm from a fast, low-quality algorithm to a slower, high-precision algorithm that uses multiple passes. Many users have reported the low-quality-by-default behavior to be a footgun, especially when performing non-neural-network computations.

The old behavior can be restored either by passing an explicit `Precision` option to operators such as `dot`, or by changing the default precision, e.g. `jax.config.update('jax_default_matmul_precision', 'fastest')`.

#7010

PiperOrigin-RevId: 395549544
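For concreteness, a minimal sketch of the two workarounds named in the description; the array shapes and values are illustrative, only the `precision=` keyword and the config flag come from the PR text:

```python
# Sketch of the two workarounds described above (illustrative shapes/values).
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

# Option 1: request the high-precision algorithm for a single operation.
z = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)

# Option 2: restore the old fast, low-precision default globally.
jax.config.update('jax_default_matmul_precision', 'fastest')
```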
Big +1 from the Hugging Face team for this PR. We're constantly running into problems with the fast-but-low-precision TPU default, as shown here for example: huggingface/transformers#15754. The example there shows how quickly hidden states diverge from the original PyTorch implementation of one of the most important speech recognition models (Wav2Vec2), which is very similar to BERT.
I want to note that this PR has stalled because we haven't been able to come to an agreement about what the best default is. No choice makes everyone happy. Would the `jax.default_matmul_precision` decorator/context manager work for you? It can be set globally or as a decorator, e.g.
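A minimal sketch of what "globally or as a decorator" could look like, assuming the `jax.default_matmul_precision` context manager (which also works as a decorator) and the `jax_default_matmul_precision` config flag; the precision strings shown are illustrative, not recommendations:

```python
# Sketch: setting matmul precision globally, per-function, or per-block.
# Assumes the jax.default_matmul_precision context manager / decorator.
import jax
import jax.numpy as jnp

# Globally, via the config flag:
jax.config.update('jax_default_matmul_precision', 'float32')

# Scoped to one function, as a decorator:
@jax.default_matmul_precision('bfloat16')
def fast_matmul(a, b):
    return jnp.dot(a, b)

# Scoped to a block, as a context manager:
with jax.default_matmul_precision('tensorfloat32'):
    out = jnp.dot(jnp.ones((8, 8)), jnp.ones((8, 8)))
```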
@patrickvonplaten thanks so much for that input. The feedback is really helpful, especially from HuggingFace.
Thanks a lot for your feedback @hawkinsp and @mattjj! To give a bit more background, most of our Transformers users use PyTorch, but more and more are starting to look into Flax/JAX + TPU now because:
Now, IMO, the reason it's not intuitive is that:
Also pinging some community contributors here who frequently use JAX in combination with Flax on TPU, in case they'd like to give some feedback: @borisdayma @stefan-it
I was not aware of the matmul precision flag, nor of its default or whether it differs from PyTorch.
Hello! I've been following this discussion off and on, and I actually ran into this problem again today, so I figured I'd add my +1. My MuZero implementation kept failing the past couple of days, and even though I knew about default_matmul_precision and its default being bfloat16, it still took me forever to figure out that this was the problem. It ended up needing tensorfloat32 or higher. I also ran into this problem a while ago training WaveNets, which also needed tensorfloat32 or higher.

It's common for folks to say that there is no difference between bfloat16 and float32, but in my experiments there has been a huge difference, and my belief that there was not much difference between the two only led to wasted time.

My opinion would be to at least set tensorfloat32 as the default, as this seems to be a nice compromise between speed and precision. If someone really wanted to eke more out of JAX in either direction, more speed or more precision, the other options could be found in the docs. Full float32 wouldn't be a bad choice either, as this would drastically lessen the chances of a problem being caused by the default precision.

It may also help people to mention in the docs that there is in fact a difference between the precisions and that they can affect model performance. The TF bfloat16 docs only mention positives and say it is a drop-in replacement for float32, but as I gain experience with TPUs I believe this is simply flat-out wrong.
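A small sketch of how one might measure the kind of gap described above, comparing a matmul under different default precision settings against a `Precision.HIGHEST` reference; the random inputs are arbitrary and the observed error will depend on the data and hardware:

```python
# Sketch: comparing matmul error across default precision settings on TPU.
# Inputs are arbitrary; the measured difference depends on data and hardware.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (1024, 1024), dtype=jnp.float32)

# Multi-pass, high-precision result used as the reference.
reference = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

for prec in ('bfloat16', 'tensorfloat32', 'float32'):
    with jax.default_matmul_precision(prec):
        result = jnp.dot(a, b)
    max_err = jnp.max(jnp.abs(result - reference))
    print(f'{prec}: max abs difference vs HIGHEST = {max_err}')
```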
Closing Copybara-created PR due to inactivity.