Default matrix-multiplication precision on TPUs #1856

Closed
shoyer opened this issue Dec 13, 2019 · 2 comments

Comments

shoyer (Collaborator) commented Dec 13, 2019

Currently, matrix multiplication on TPUs (with float32 dtypes) defaults to bfloat16 multiplication with float32 accumulation. This allows for really fantastic performance for neural nets.

Using higher precision requires explicitly setting it, e.g., jax.numpy.matmul(x, y, precision=jax.lax.Precision.HIGHEST). (We could conceivably shorten this by supporting strings, e.g., jax.numpy.matmul(x, y, precision='highest'))
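For concreteness, a minimal sketch of both the default and the explicit-precision call (array shapes and values are arbitrary, and the precision difference only materializes on a TPU backend):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

# Default on TPU: bfloat16 multiplication with float32 accumulation.
z_default = jnp.matmul(x, y)

# Explicitly request full float32 precision (extra matmul passes on TPU).
z_full = jnp.matmul(x, y, precision=jax.lax.Precision.HIGHEST)
```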

Is this the right default behavior for functions like jax.numpy.matmul and the @ matmul operator? Or should we switch the default to full float32 precision, at the price of extra matrix-multiplication passes? This would probably be a little more annoying for neural network users, but it is arguably less surprising, especially for users applying matrix multiplication outside of neural networks.

shoyer (Collaborator, Author) commented Feb 22, 2020

Closing this in favor of #2161 (apparently I'm bad at remembering that I already opened an issue on something!)

patrickvonplaten commented

Note that the precision can also be set globally with the JAX_DEFAULT_MATMUL_PRECISION flag, as explained here.
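A sketch of the ways this flag can be set (the script name is a placeholder, and the context-manager form assumes a reasonably recent JAX version):

```python
import jax

# 1. Environment variable, set before the process starts:
#    JAX_DEFAULT_MATMUL_PRECISION=float32 python train.py

# 2. Programmatically, via the config object:
jax.config.update("jax_default_matmul_precision", "float32")

# 3. Scoped, via a context manager:
with jax.default_matmul_precision("float32"):
    ...  # matmuls in this block use full float32 precision
```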
