Currently, matrix multiplication on TPUs (with float32 dtypes) defaults to bfloat16 multiplication with float32 accumulation. This gives excellent performance for neural nets.

Using higher precision requires explicitly requesting it, e.g., jax.numpy.matmul(x, y, precision=jax.lax.Precision.HIGHEST). (We could conceivably shorten this by supporting strings, e.g., jax.numpy.matmul(x, y, precision='highest').)
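For reference, a minimal sketch of what the two paths look like today (shapes and values here are arbitrary; on CPU/GPU backends both calls already use full float32, so the difference only shows up on TPU):

```python
import jax
import jax.numpy as jnp

# Arbitrary float32 inputs for illustration.
x = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)
y = jax.random.normal(jax.random.PRNGKey(1), (1024, 1024), dtype=jnp.float32)

# Default on TPU: bfloat16 multiplication with float32 accumulation.
fast = jnp.matmul(x, y)

# Explicitly request full float32 multiplication.
exact = jnp.matmul(x, y, precision=jax.lax.Precision.HIGHEST)

# The gap between the two reflects the precision lost on the default path.
print(jnp.max(jnp.abs(fast - exact)))
```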
Is this the right default behavior for functions like jax.numpy.matmul and the @ matmul operator? Or should we switch the default precision to full float32, at the price of extra matrix-multiplication passes? This would probably be a little more annoying for neural network users, but it is arguably less surprising, especially for users applying matrix multiplication to other purposes.