Make the default matmul precision float32 even on TPUs #7010
I would argue this is also a footgun for some neural net use cases. Both times that I’ve carefully re-implemented a model in JAX, I’ve found that performance is worse than expected with the default precision, and it took me some time to realise that precision was the cause. As a (temporary) improvement on the current situation, we could at least add some information on this issue to point 5 of the ‘Current Gotchas’ in the readme.
#6143 added a config flag, so this should in principle be easier to change now.
As @juliuskunze just pointed out to me, this is the only difference in semantics (that we can think of) between the GPU, CPU and TPU backends. 'Backend transparency' is a valuable property in Julius's and my opinion, and given how close JAX is to achieving it (if this really is the only case where semantics differ significantly), it's surely worth changing the TPU default for all matmul-like ops (including conv) to closely approximate GPU and CPU behaviour. I understand this will harm 'default' speed, but I think we can mitigate that by making clear the availability of the bfloat16 option, as sketched below.
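To make "the availability of the bfloat16 option" concrete, here is a minimal sketch of the per-op escape hatch: matmul-like ops such as `jax.lax.conv` accept a `precision` argument, so the fast low-precision path stays one keyword away even if the TPU default changes. Shapes and values here are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 3, 32, 32), dtype=jnp.float32)   # NCHW activations
k = jnp.ones((16, 3, 3, 3), dtype=jnp.float32)    # OIHW kernel

# Fast path: low-precision accumulation on TPU.
fast = jax.lax.conv(x, k, window_strides=(1, 1), padding='SAME',
                    precision=jax.lax.Precision.DEFAULT)

# Precise path: multi-pass, closely approximating float32 results.
exact = jax.lax.conv(x, k, window_strides=(1, 1), padding='SAME',
                     precision=jax.lax.Precision.HIGHEST)
```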
My team shot itself in the foot last week for the ~fourth time due to matmul defaulting to bfloat16. This issue continues to be my biggest/only grievance with JAX.
Just got another +1 for this issue. Another foot lost.
On CPU and GPU, this change has no effect. On TPU, this PR changes the default matmul algorithm from a fast, low-quality algorithm to a slower, high-precision algorithm that uses multiple passes. Many users have reported the low-quality-by-default behavior to be a footgun, especially when performing non-neural network computations. The old behavior can be restored either by passing an explicit Precision option to operators such as `dot`, or by changing the default precision, e.g., `jax.config.update('jax_default_matmul_precision', 'fastest')`. #7010 PiperOrigin-RevId: 395549544
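A minimal sketch of the two opt-outs named in the commit message above (the flag name and the `Precision` option come from that message; the array shapes are illustrative only):

```python
import jax
import jax.numpy as jnp

# Option 1: restore the pre-change fast-but-low-precision behaviour globally.
jax.config.update('jax_default_matmul_precision', 'fastest')

# Option 2: choose precision per call on dot-like operators.
a = jnp.ones((1024, 1024), dtype=jnp.float32)
b = jnp.ones((1024, 1024), dtype=jnp.float32)
fast = jnp.dot(a, b, precision=jax.lax.Precision.DEFAULT)   # fast path on TPU
full = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)   # full float32 result
```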
+1 for this issue from the Transformers team for one of the most popular architectures for speech recognition (Wav2Vec2) - see: huggingface/transformers#15754
+1 to setting the default to f32. Having had the pleasure of shooting both my feet, with both performance gotchas and numerics gotchas, I'd wholeheartedly prefer debugging the performance ones as they're much easier to spot.
Similar discussion in PyTorch: pytorch/pytorch#67384. PyTorch once enabled TensorFloat32 by default for a few ops, and then had to revert the decision due to similar complaints. Enabling bfloat16 by default is presumably even worse.
If anyone here is collecting feet 🦶🔫, I lost one to something related to this too. #12008 (comment)
What about printing a warning/info message once per process when the low precision is used by default? Printing too much stuff by default, like TF does, isn't great, but I think this one is worth it.
It does seem a poor user-experience choice that the system does the wrong thing by default and forces a developer to debug its unexpected behaviour, rather than doing the right thing by default and letting the developer feel good about optimising performance by selecting lower-precision maths, an action which they will then find easy to undo should the system perform badly due to the limited precision.
+1 for this. My benchmarks show that models in low precision do not always match the results of high precision, so the correct choice (i.e. high precision) should be the default.
My personal point of view is that this is more complicated than it seems. When a piece of software targets, or is used predominantly by, one community, you just pick that community's favourite default. Raising awareness of this issue would be a good first step that I guess people would agree on.
Solving neural ODEs with diffrax is also affected by the unexpected default choice of TensorFloat32; see patrick-kidger/diffrax#213.
+1 another foot lost (kvablack/ddpo-pytorch#3 (comment)), this time in the form of forcing me to update 1/4 of the results for a paper that I already submitted and released. I do think using high precision by default and allowing users to opt in for better performance is much easier to debug.
Follow-up from #2161:
The default low precision is a bit of a footgun, at least when doing anything that isn't implementing a neural net layer. In my opinion, it would be much safer to use "highest" precision by default (which isn't that much slower) on float32 data. Neural net libraries, of course, can default to lower precision, so this really only affects users who directly use NumPy APIs or the `@` infix operator.
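As an illustration of who is affected: the `@` operator and the NumPy-style APIs dispatch to the same underlying matmul and pick up whatever default precision is in effect. A hedged sketch, assuming the `jax.default_matmul_precision` context manager that accompanies the config flag added in #6143:

```python
import jax
import jax.numpy as jnp

a = jnp.arange(9.0).reshape(3, 3)
b = jnp.eye(3)

# Uses the process-wide default matmul precision (bfloat16 passes on TPU today).
c = a @ b

# Scope a different choice locally without touching the global config.
with jax.default_matmul_precision('float32'):
    d = a @ b  # full float32 inside this block
```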