I'm implementing my JAX model on a TPU, and the model involves a step that computes the inverse of a matrix. I use jnp.linalg.solve() for this and found that the result is not precise enough. Here is an example to reproduce the issue:
import jax.numpy as jnp
X = jnp.array([[280., 361., 145.],
               [361., 657., 247.],
               [145., 247., 99.]])
print(X @ jnp.linalg.solve(X, jnp.eye(3)))
The output is relatively far from the identity matrix (and not symmetric), which makes the following steps of my model incorrect.
Given that the matrix to be inverted is neither large nor extreme, I expected its inverse to be computed precisely. (It is precise when executed on CPU.) @murphyk
What jax/jaxlib version are you using?
jax v0.3.24
Which accelerator(s) are you using?
TPU
Additional system info
Ubuntu 20.04.5 LTS
NVIDIA GPU info
No response
i.e., the result of the solve is actually pretty good.
By default, the TPU performs float32 matrix multiplications by truncating the inputs to bfloat16. Something similar happens on modern NVIDIA GPUs (Ampere and newer), which use "TF32" math; it is not exactly the same, but it is similar in spirit.
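To make this concrete, here is a minimal sketch rather than a definitive fix. It assumes the `precision` argument of `jnp.matmul` and `jax.lax.Precision.HIGHEST` are available in your JAX version. It requests full float32 precision for the final product, and it also imitates the reduced-precision inputs by rounding both operands to bfloat16 before multiplying. The helper name `to_bf16_and_back` is hypothetical, and the round trip only approximates what the hardware does.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[280., 361., 145.],
               [361., 657., 247.],
               [145., 247., 99.]])
eye = jnp.eye(3)

# Compute the inverse the same way as in the original report.
X_inv = jnp.linalg.solve(X, eye)

# Ask for the highest available matmul precision for this one product.
precise_product = jnp.matmul(X, X_inv, precision=jax.lax.Precision.HIGHEST)
precise_err = jnp.max(jnp.abs(precise_product - eye))


def to_bf16_and_back(a):
    # Hypothetical helper: round float32 values to bfloat16 and cast back.
    # This only approximates the reduced-precision inputs used on TPU.
    return a.astype(jnp.bfloat16).astype(jnp.float32)


# Multiply the bfloat16-rounded operands at full precision, so the only
# error introduced here comes from rounding the inputs.
emulated_product = jnp.matmul(to_bf16_and_back(X), to_bf16_and_back(X_inv),
                              precision=jax.lax.Precision.HIGHEST)
emulated_err = jnp.max(jnp.abs(emulated_product - eye))

print("max |X @ X_inv - I| at highest precision:", precise_err)
print("max |X @ X_inv - I| with bfloat16-rounded inputs:", emulated_err)
```

If raising the precision call by call is inconvenient, the `jax.default_matmul_precision("float32")` context manager should change the default for everything inside its scope, at the cost of slower matrix multiplications on TPU.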