
Change the default matmul precision in JAX to highest precision. #7859

Closed

Conversation

copybara-service[bot]

Change the default matmul precision in JAX to highest precision.

On CPU and GPU, this change has no effect.

On TPU, this PR changes the default matmul algorithm from a fast, low-quality algorithm to a slower, high-precision algorithm that uses multiple passes. Many users have reported the low-quality-by-default behavior to be a footgun, especially when performing non-neural network computations.

The old behavior can be restored either by passing an explicit Precision option to operators such as `dot`, or by changing the default precision, e.g.,
jax.config.update('jax_default_matmul_precision', 'fastest')

#7010

PiperOrigin-RevId: 395549544
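
To illustrate the two options mentioned above, here is a minimal sketch (not part of this change; the array shapes are arbitrary):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

# Option 1: request the old fast algorithm explicitly on a single operator.
z_fast = jnp.dot(x, y, precision=lax.Precision.DEFAULT)

# Option 2: restore the old behavior process-wide.
jax.config.update('jax_default_matmul_precision', 'fastest')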
@patrickvonplaten

patrickvonplaten commented Feb 22, 2022

Big +1 from the Hugging Face team for this PR. We're constantly running into problems with the fast, low-precision TPU default, as shown for example in huggingface/transformers#15754. The example there shows how quickly hidden states diverge from the original PyTorch implementation of one of the most important speech recognition models (Wav2Vec2), which is very similar to BERT.

@hawkinsp
Collaborator

hawkinsp commented Feb 22, 2022

I want to note that this PR has stalled because we haven't been able to come to an agreement about what the best default is. No choice makes everyone happy.

Would the default_matmul_precision config option solve your problem?

It can be set globally or as a decorator, e.g.

@jax.default_matmul_precision("float32")
def f(...): ...
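
In case it helps, the same option can (as far as I know) also be applied as a context manager or via an environment variable:

import jax
import jax.numpy as jnp

# Scoped override: only matmuls inside the block use float32 precision.
with jax.default_matmul_precision("float32"):
    out = jnp.dot(jnp.ones((8, 8)), jnp.ones((8, 8)))

# Or process-wide, without touching the code:
#   JAX_DEFAULT_MATMUL_PRECISION=float32 python train.py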

@mattjj
Collaborator

mattjj commented Feb 22, 2022

@patrickvonplaten thanks so much for that input. The feedback is really helpful, especially from HuggingFace.

@patrickvonplaten

Thanks a lot for your feedback @hawkinsp and @mattjj! The JAX_DEFAULT_MATMUL_PRECISION=float32 global flag (default_matmul_precision), as @mattjj nicely explains here, solves the problem, but you have to know that it exists and that the default precision is bf16 on TPU. For our community, and even for people within Hugging Face, this really isn't intuitive.

To give a bit more background, most of our Transformers users use PyTorch. More and more are starting to look into Flax/JAX + TPU now though, because:

  • performance
  • compared to PyTorch there are arguably fewer gnarly bugs, the API is more intuitive for multi-TPU/multi-GPU settings, and multi-device model/data parallelism works very well and is very intuitive (arguably much more so than in PyTorch)

Now IMO the reasons it's not intuitive are:

  1. People come from PyTorch using GPUs. They associate Flax/JAX with TPU and transfer their PyTorch/GPU knowledge to the Flax/JAX TPU use case. So training scripts are run, the gradients explode, the loss is unstable -> if you don't know the difference between GPU/TPU and float32/bf16/fp16, it's not very straightforward to find out why the loss is unstable. I do think lots of people who start using TPUs expect that the precision is as high as possible and that the code trades precision for speed only when the user explicitly asks for it. E.g., bfloat16 is not the default for PyTorch on GPU; people have to enable it themselves. (See the short sketch after this list for how visible the difference can be.)

  2. The global flag JAX_DEFAULT_MATMUL_PRECISION is great, but it's pretty hard to find out about it, e.g. when googling. For comparison, if I google "set GPU device for PyTorch", it's trivial to find the CUDA_VISIBLE_DEVICES flag. However, if I google "set precision for JAX TPU", I don't find the JAX_DEFAULT_MATMUL_PRECISION flag at all. I quickly see this issue and https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm, but nowhere is this flag mentioned. I think there are people who at some point come to the conclusion that they have to specify the precision manually in every https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html# module because they just don't know that this flag exists. So somehow making this flag more prominent in the docs/issues is really important, I think.
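
To make point 1 concrete, here is a rough sketch of how the effect can be observed on a TPU (matrix sizes are arbitrary and the exact numbers will vary by device):

import jax
import jax.numpy as jnp
from jax import lax

a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (1024, 1024), dtype=jnp.float32)

# Default precision: on TPU the inputs are effectively rounded to bfloat16.
fast = jnp.dot(a, b)
# Highest precision: slower multi-pass algorithm, close to true float32.
exact = jnp.dot(a, b, precision=lax.Precision.HIGHEST)

print(jnp.max(jnp.abs(fast - exact)))  # noticeably non-zero on TPU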

@patrickvonplaten

Also pinging some community contributors who frequently use JAX in combination with Flax on TPU to give some feedback: @borisdayma @stefan-it

@borisdayma

borisdayma commented Feb 23, 2022

I was not aware of the matmul precision flag, nor of its default and whether it differs from PyTorch.
I've been using the dtype argument of the flax.linen modules; not sure if it's related…
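
For what it's worth, dtype and matmul precision are separate knobs; as I understand it, flax.linen layers also expose a precision attribute that is forwarded to the underlying dot. A rough sketch:

import jax.numpy as jnp
from jax import lax
import flax.linen as nn

# dtype controls the dtype used for the computation; precision controls which
# matmul algorithm the TPU uses even when the inputs are float32.
layer = nn.Dense(
    features=256,
    dtype=jnp.float32,               # computation dtype
    precision=lax.Precision.HIGHEST  # matmul precision passed to dot_general
)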

@evanatyourservice

Hello! I've been following this discussion off and on and actually ran into this problem again today so I figured I'd add my +1.

My MuZero implementation kept failing the past couple of days, and even though I knew about default_matmul_precision and its default being bfloat16, it still took me forever to figure out that this was the problem. It ended up needing tensorfloat32 or higher. I also ran into this problem a while ago training WaveNets, which also needed tensorfloat32 or higher. It's common for folks to say that there is no difference between bfloat16 and float32, but in my experiments there has been a huge difference, and my belief that there was not much difference between the two only led to wasted time.

My opinion would be to at least set tensorfloat32 as the default, as this seems to be a nice compromise between speed and precision. If someone really wanted to eke more out of JAX one way or another, either more speed or more precision, the other options could be found in the docs and applied. Full float32 wouldn't be a bad choice either, as this would drastically lessen the chances of a problem being caused by the default precision. It may also help to mention in the docs that there is in fact a difference between the precisions and that it can affect model performance. The TF bfloat16 docs only mention positives and say it is a drop-in replacement for float32, but as I gain experience with TPUs I believe this is simply flat-out wrong.
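
For anyone landing here later, this is roughly how the tensorfloat32 compromise can be opted into per function today (precision names as I understand the current API; the function body is just a placeholder):

import jax
import jax.numpy as jnp

# Scope a single function to the tensorfloat32 matmul algorithm; everything
# outside this function keeps the global default.
@jax.default_matmul_precision("tensorfloat32")
def train_step(x, w):
    return jnp.dot(x, w)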

@MichaelHudgins
Collaborator

Closing Copybara-created PR due to inactivity.

@MichaelHudgins deleted the test_395549544 branch May 9, 2024 21:50