Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change the default matmul precision in JAX to highest precision. #7859

Closed
wants to merge 1 commit into from

Commits on Sep 8, 2021

  1. Change the default matmul precision in JAX to highest precision.

    On CPU and GPU, this change has no effect.
    
    On TPU, this PR changes the default matmul algorithm from a fast, low-quality algorithm to a slower, high-precision algorithm that uses multiple passes. Many users have reported the low-quality-by-default behavior to be a footgun, especially when performing non-neural network computations.
    
    The old behavior can be restored either by passing an explicit Precision option to operators such as `dot`, or by changing the default precision, e.g.,
    jax.config.update('jax_default_matmul_precision', 'fastest')
    
    #7010
    
    PiperOrigin-RevId: 395549544
    hawkinsp authored and jax authors committed Sep 8, 2021
    Configuration menu
    Copy the full SHA
    52485d6 View commit details
    Browse the repository at this point in the history