Implement column-pivoted QR via geqp3 (CPU lowering only) #20282

Merged
merged 1 commit into from
Jan 10, 2025

Conversation

tttc3
Contributor

@tttc3 tttc3 commented Mar 16, 2024

Provides a CPU-only lowering of the column-pivoted QR decomposition from LAPACK (#12897). I would like to add GPU lowerings via MAGMA, but I am unsure of the best way to optionally include this dependency in the jaxlib build (as suggested in #1259 (comment)).

@tttc3
Contributor Author

tttc3 commented Dec 3, 2024

See ccb3317 for a reference implementation of a primitive that uses MAGMA.

@tttc3
Contributor Author

tttc3 commented Dec 18, 2024

@dfm, would you be able to review this, and should I open a new draft PR for discussing a MAGMA based GPU implementation?

@dfm dfm self-assigned this Dec 18, 2024
Collaborator

@dfm dfm left a comment

Thanks for this, and sorry for the slow review! I wasn't actually working on JAX when you initially submitted the PR, and then didn't see it until you pinged me, so that was very useful :D

This looks great overall. I've added an initial round of small inline comments. Please take a look at those, and then we can fire up the full CI run.

I think it's probably worth getting the details of the CPU implementation and the frontend API sorted out before getting into the MAGMA implementation too deeply, but I'm very happy to review that PR once we get there too. Thanks again!!

jax/_src/lax/linalg.py (outdated, resolved)
else:
q = operand
r = operand
return q, r
p = operand.update(dtype=np.dtype(np.int32))
Collaborator

This won't work in general (update won't be defined if the input isn't a ShapedArray). The standard approach (e.g. see the abstract eval rule for LU) seems to be to just return the operand for p as well, even though that's nonsense. I wonder if @hawkinsp, @mattjj, or anyone else can chime in about what semantics we would want for an abstract eval called with non-ShapedArray inputs?

jax/_src/lax/linalg.py (outdated, resolved)
jaxlib/cpu/lapack_kernels.cc (outdated, resolved)
jaxlib/cpu/lapack_kernels.cc (resolved)
tests/linalg_test.py (resolved)
@tttc3
Contributor Author

tttc3 commented Jan 2, 2025

Thanks for the review! To reduce code duplication, would you prefer that PivotingQrFactorizationComplex be removed and if constexpr be used to handle the complex case (where the only difference is allocating and passing rwork)? That is, something similar to what is done for the SVD.

@tttc3 tttc3 requested a review from dfm January 3, 2025 15:10
Collaborator

@dfm dfm left a comment

I'm approving this PR to trigger the full test suite, but I think we still want to iterate on the JVP question!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jan 3, 2025
@dfm
Collaborator

dfm commented Jan 3, 2025

In the interest of reducing code duplication, would it be preferred if PivotingQrFactorizationComplex was removed, and if constexpr used to handle the complex case (where the only difference is to allocate and pass rwork)?

I don't have a strong opinion either way! Please feel free to use whichever approach you prefer.

@tttc3 tttc3 requested a review from dfm January 6, 2025 14:21
@dfm
Collaborator

dfm commented Jan 6, 2025

Thanks! This all looks good to me now! Can you rebase this PR onto the main branch, squashing your commits into one? Then I'll work on getting it merged. Thanks again!

@dfm
Collaborator

dfm commented Jan 6, 2025

PS. Please ping me once you've rebased, I don't get notified for pushes!

@tttc3 tttc3 force-pushed the pivoted-qr branch 2 times, most recently from 50c5766 to 8b32742 Compare January 6, 2025 17:05
@tttc3
Contributor Author

tttc3 commented Jan 6, 2025

@dfm I've rebased and re-added the type hint overloads for jax.scipy.linalg.qr.

@dfm
Collaborator

dfm commented Jan 9, 2025

Thanks! It looks like this causes a couple of failures for tests that aren't run as part of the GitHub CI. Can you add the new geqp3 primitive to this list:

tf_not_yet_impl = [

And then only run the new JVP tests when pivoting is True?

After you make those changes, please squash the commits and rebase onto the current main. Then we should be good to go!

Thanks again for this!!

return q @ r

m, n = shape
full_matrices = mode in ["full", "r"]
Collaborator

I'm still seeing failures here when mode == "r". Perhaps we should just test for mode == "full"? Regardless, we might also want to "undo" the pivoting in qr_and_mul so that any change in the pivot order caused by the finite-difference perturbation used in the check doesn't lead to unexpected results. What do you think?

Contributor Author

The mode=="full" case should be sufficient on its own, but we can easily include the mode=="economic" case as well so I'm happy to just ignore mode=="r" (this is implicitly tested in the other modes anyway). I'm not sure about the pivoting, but I think you are probably right - I will think about it and let you know later.

Contributor Author

@tttc3 tttc3 Jan 9, 2025

I've updated qr_and_mul to "undo" the pivoting - it now expresses the identity function, like in the numpy tests (this should be more robust).
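
The "undo the pivoting" idea can be sketched as follows. This is a hypothetical reconstruction, not the helper from the PR's test file: it uses `scipy.linalg.qr` and NumPy rather than JAX so that it runs anywhere SciPy is installed, and the body of `qr_and_mul` is an assumption based only on the description above. The point is that multiplying the factors and applying the inverse permutation returns the input itself, so the result does not depend on which pivot order LAPACK happens to pick.

```python
import numpy as np
from scipy.linalg import qr


def qr_and_mul(a):
    """Hypothetical test helper: factor with pivoting, multiply, un-pivot."""
    q, r, p = qr(a, mode="economic", pivoting=True)
    # q @ r equals a[:, p]; argsort(p) is the inverse of the permutation p.
    inverse_p = np.argsort(p)
    return (q @ r)[:, inverse_p]


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 3))
# A small perturbation, standing in for a finite-difference step.
a_perturbed = a + 1e-6 * rng.standard_normal(a.shape)

err = np.max(np.abs(qr_and_mul(a) - a))
err_perturbed = np.max(np.abs(qr_and_mul(a_perturbed) - a_perturbed))
```

Because the helper is the identity map up to rounding, a numerical derivative check on it stays well defined even if a perturbation changes the pivot order.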

A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

A GPU implementation of `geqp3` may require MAGMA, because cuSOLVER
does not provide a `geqp3` routine; see ccb3317
(`jax.lax.linalg.eig`) for an example of using MAGMA in GPU
lowerings. Such a GPU implementation can be considered in the
future.
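
For reference, the semantics of a column-pivoted QR can be illustrated with the SciPy function that the commit message names as the model. This is a minimal sketch, assuming SciPy and NumPy are installed; the matrix and variable names are illustrative only. Per the commit message, `jax.scipy.linalg.qr` is meant to provide the same functionality on CPU, but the sketch below exercises only SciPy.

```python
import numpy as np
from scipy.linalg import qr

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3))

# With pivoting=True, a third output p holds the column permutation.
q, r, p = qr(a, mode="economic", pivoting=True)

# The factors reproduce the column-permuted input: a[:, p] == q @ r.
reconstruction_error = np.max(np.abs(a[:, p] - q @ r))

# Pivoting orders the diagonal of r by non-increasing magnitude.
diag = np.abs(np.diag(r))
```

The permutation `p` is returned as an index array rather than a permutation matrix, so the factorization is checked against `a[:, p]` instead of `a @ P`.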
@tttc3 tttc3 requested a review from dfm January 9, 2025 20:48
@copybara-service copybara-service bot merged commit 564b6b0 into jax-ml:main Jan 10, 2025
20 of 24 checks passed
tttc3 added a commit to tttc3/jax that referenced this pull request Jan 17, 2025
Originally noted in jax-ml#20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
@tttc3 tttc3 deleted the pivoted-qr branch January 17, 2025 17:02
tttc3 added a commit to tttc3/jax that referenced this pull request Jan 18, 2025
Originally noted in jax-ml#20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.