Implement column-pivoted QR via geqp3 (CPU lowering only) #20282
Conversation
See ccb3317 for a reference implementation of a primitive that uses MAGMA.
@dfm, would you be able to review this, and should I open a new draft PR for discussing a MAGMA-based GPU implementation?
Thanks for this, and sorry for the slow review! I wasn't actually working on JAX when you initially submitted the PR, and then didn't see it until you pinged me, so that was very useful :D
This looks great overall. I've added an initial round of small inline comments. Please take a look at those, and then we can fire up the full CI run.
I think it's probably worth getting the details of the CPU implementation and the frontend API sorted out before getting into the MAGMA implementation too deeply, but I'm very happy to review that PR once we get there too. Thanks again!!
jax/_src/lax/linalg.py
Outdated
else:
  q = operand
  r = operand
return q, r
p = operand.update(dtype=np.dtype(np.int32))
This won't work in general (`update` won't be defined if the input isn't a `ShapedArray`). The standard approach (e.g. see the abstract eval rule for LU) seems to be to just return the operand for `p` as well, even though that's nonsense. I wonder if @hawkinsp, @mattjj, or anyone else can chime in about what semantics we would want for an abstract eval called with non-`ShapedArray` inputs?
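To make the suggested pattern concrete, here is a minimal sketch of an abstract eval rule that follows the LU convention; the function name, output shapes, and the full-matrices-only shape handling are illustrative assumptions, not the PR's actual rule.

```python
import numpy as np
from jax._src.core import ShapedArray

def qr_pivot_abstract_eval(operand):
    # Hypothetical abstract eval for a pivoted QR primitive, shown only to
    # illustrate the pattern discussed above (full_matrices case only).
    if isinstance(operand, ShapedArray):
        *batch, m, n = operand.shape
        q = operand.update(shape=(*batch, m, m))
        r = operand.update(shape=(*batch, m, n))
        # Pivot indices are integer-valued, one entry per column.
        p = operand.update(shape=(*batch, n), dtype=np.dtype(np.int32))
    else:
        # Mirror the existing LU abstract eval: return the operand for every
        # output, including p, even though that is nonsense for pivot indices.
        q = r = p = operand
    return q, r, p
```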
Thanks for the review! In the interest of reducing code duplication, would it be preferred if …
I'm approving this PR to trigger the full test suite, but I think we still want to iterate on the JVP question!
I don't have a strong opinion either way! Please feel free to use whichever approach you prefer.
Thanks! This all looks good to me now! Can you rebase this PR onto the `main` branch?
PS. Please ping me once you've rebased, I don't get notified for pushes!
(force-pushed from 50c5766 to 8b32742)
@dfm I've rebased and re-added the type hint overloads for …
Thanks! It looks like this causes a couple of failures for tests that aren't run as part of the GitHub CI. Can you add the new primitive here: jax/experimental/jax2tf/jax2tf.py, line 1477 (at 640cb00)?

And then only run the new JVP tests when … After you make those changes, please squash the commits and rebase onto the current main. Then we should be good to go! Thanks again for this!!
tests/linalg_test.py
Outdated
return q @ r

m, n = shape
full_matrices = mode in ["full", "r"]
I'm still seeing failures here when `mode == "r"`. Perhaps we should just test for `mode == "full"`? Regardless, we might also want to "undo" the pivoting in `qr_and_mul` so that any changes in the pivoting caused by the finite difference used in the check don't lead to unexpected results. What do you think?
The mode=="full"
case should be sufficient on its own, but we can easily include the mode=="economic"
case as well so I'm happy to just ignore mode=="r"
(this is implicitly tested in the other modes anyway). I'm not sure about the pivoting, but I think you are probably right - I will think about it and let you know later.
I've updated `qr_and_mul` to "undo" the pivoting. It now expresses the identity function, like in the numpy tests (this should be more robust).
A pivoted QR factorization is available in `scipy.linalg.qr`, thanks to the `geqp3` routine of LAPACK. To provide the same functionality in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK routine via the FFI on CPU devices. Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support column pivoting on CPU devices. Providing a GPU implementation of `geqp3` may require MAGMA, due to the lack of a `geqp3` implementation in cuSolver - see ccb3317 (`jax.lax.linalg.eig`) for an example of using MAGMA in GPU lowerings. Such a GPU implementation can be considered in the future.
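For illustration, a minimal usage sketch of the CPU-only column-pivoted QR is below, assuming the scipy-compatible `pivoting` keyword described above; the array values and tolerance are arbitrary.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr

# Column-pivoted QR is lowered via LAPACK's geqp3, so this runs on CPU devices.
a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 10.0]])

q, r, p = qr(a, pivoting=True)

# The factorization satisfies a[:, p] == q @ r (up to floating point error),
# with the diagonal of r non-increasing in magnitude.
assert jnp.allclose(a[:, p], q @ r, atol=1e-5)
```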
Originally noted in jax-ml#20282, this commit provides a GPU-compatible implementation of `geqp3` via MAGMA.
Provides a CPU-only lowering of the column-pivoted QR decomposition in LAPACK (#12897). I would like to add GPU lowerings via MAGMA, but I am unsure of the best way to optionally include this dependency in the jaxlib build (as suggested in #1259 (comment)).