Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CoLA integration #370

Merged
merged 16 commits into from
Sep 6, 2023
Merged

CoLA integration #370

merged 16 commits into from
Sep 6, 2023

Conversation

daniel-dodd
Copy link
Member

@daniel-dodd daniel-dodd commented Aug 26, 2023

⚠️ NOTE: PYTHON REQUIREMENTS ARE BUMPED TO 3.10 ⚠️

Type of changes

Opening draft PR to get 🪩 rolling on migrating GPJax linops to CoLA 🥤.

⚠️ Minimal work thus far - only refactored to get things working. Expect bugs!

Major outstanding work left to do are (1) API considerations, (2) comprehensive testing.

  • Bug fix
  • [ x] New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • [ x] I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing. (Yes but there might be an issue with coverage).
  • [ x] I've added tests for new code. (Yes but really minimal).
  • [ x] I've added docstrings for the new code.

Description

The gpjax.linops module has been removed. All linear operators in GPJax have been replaced with their analogue in CoLA e.g.,:

  • gpjax.linops.DenseLinearOperator -> cola.ops.Dense (wrapped around cola.PSD where appropriate).
  • gpjax.linops.DiagonalLinearOperator -> cola.ops.Diagonal (wrapped around cola.PSD where appropriate).

With minimal modification to the code, such that tests pass locally.

Plum dispatch is dropped as a direct dependancy and we use singedispatch for citations to avoid clashes with CoLA.

Outstanding issues **[edit:] Have opened issues for these!

  • ⚠️ Bug identified: We should shape test/promote to the correct shape for the likelihood integrations and not just rely on beartype - this is currently dangerous for quadrature!
  • Future PR: API considerations regarding solve and root operations on linear operators.
  • Future PR: Migrate cross_covaraince be a LinearOperator rather than a dense array (this would match the signature of gram). This is beneficial in sparse situations.
  • Future PR: Refactor basis function computation to a low rank LinearOperator in this PR? So that we can do $O(mn^2)$ solves. So that now linear models and Fourier features are now efficient in GPJax.
  • Future PR: Resolve product and sum kernel computations. (We can do this lazily build the gram matrix for kernel, and then apply product and sum operations).

@daniel-dodd daniel-dodd added the enhancement New feature or request label Aug 29, 2023
@daniel-dodd daniel-dodd marked this pull request as ready for review August 29, 2023 12:21
Comment on lines -138 to +140
return self.scale.diagonal()
return cola.diag(self.scale)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check: This probably has edge case behaviour as diag switches between diagonal and (dense) diagonal matrix, while diagonal is strictly to a diagonal array.

Comment on lines +19 to +20
# TODO: Once this functionality is supported in CoLA, remove this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to open issue.

@daniel-dodd
Copy link
Member Author

Hey @mfinzi, would be grateful if you could glance over this integration!

For context its pretty minimal, and only aims to remove our ancient linear operators and revamp them to CoLA. Undoubtedly there may be more efficient things we could do e.g., within our GaussianDistribution object we could be a little more clever with KL divergences or with the quadratic appearing in the density -> any suggestions you may have would be much appreciated.

Nice work on wilson-labs/cola/#33, but perhaps really what I more would like is access to the (lower) Cholesky root itself like I have done in https://github.com/JaxGaussianProcesses/GPJax/blob/2c9ebce5a110b73a54ee9a38f408bd1950912026/gpjax/lower_cholesky.py. I thought the cholesky_decomposed would return a Lazy product $[L, L^{T}]$ so I could extract this term. If something like this could be supported in CoLA then that would be great!

Copy link
Collaborator

@thomaspinder thomaspinder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all looks good to me! Very nice :D

@daniel-dodd daniel-dodd merged commit c61ff2f into main Sep 6, 2023
@daniel-dodd daniel-dodd deleted the cola branch October 28, 2023 15:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants