Correct healpix_forward derivatives and add support for forward and higher order autodiff #244

Merged
merged 4 commits into main from mmg/healpix-gradient-fix
Nov 26, 2024

@matt-graham (Collaborator) commented Nov 18, 2024

Resolves #243

Corrects derivative rule for healpy_forward function in s2fft.transforms.c_backend_spherical and refactors both healpy_forward and healpy_inverse in terms of two new custom JAX primitives healpy_map2alm and healpy_alm2map that wrap the corresponding healpy map2alm and alm2map functions. The jax.interpreters.ad.deflinear function is used to mark the primitives as linear operators with custom transposition rules (defined in terms of the primitives themselves). This gives us support for both forward and reverse mode automatic differentiation and also higher order derivatives.
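As a rough illustration of why knowing that an operator is linear, together with its transpose, is enough for both differentiation modes, here is a minimal sketch using only public JAX APIs. It does not use jax.interpreters.ad.deflinear or healpy: TOY_MATRIX and toy_linear_transform are hypothetical stand-ins for a linear transform such as map2alm with iter=0.

```python
import jax
import jax.numpy as jnp

# Use double precision so the comparisons below are tight.
jax.config.update("jax_enable_x64", True)

# Hypothetical dense matrix standing in for a linear transform.
TOY_MATRIX = jnp.arange(12.0).reshape(4, 3)


def toy_linear_transform(x):
    """Toy linear operator mapping length-3 inputs to length-4 outputs."""
    return TOY_MATRIX @ x


x = jnp.array([1.0, -2.0, 0.5])
tangent = jnp.array([0.1, 0.2, 0.3])
cotangent = jnp.array([1.0, 0.0, -1.0, 2.0])

# Forward mode: the JVP of a linear map is the map applied to the tangent.
_, tangent_out = jax.jvp(toy_linear_transform, (x,), (tangent,))

# Reverse mode: the VJP of a linear map is its transpose applied to the
# cotangent. jax.linear_transpose derives the transpose here because the toy
# operator is traceable; for a black-box call it would be supplied by hand.
toy_transpose = jax.linear_transpose(toy_linear_transform, x)
(cotangent_out,) = toy_transpose(cotangent)

# Adjoint identity <A t, c> = <t, A^T c>, a standard check for a transpose rule.
lhs = jnp.vdot(toy_linear_transform(tangent), cotangent)
rhs = jnp.vdot(tangent, cotangent_out)
```

Neither rule needs any values saved from the forward pass, which is what makes a transpose-only definition sufficient for a linear operator.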

The corresponding tests for gradient correctness in tests/test_spherical_custom_grads.py are also updated to be stricter. The kludge that relaxed the rtol tolerance for test_healpix_c_backend_forward_custom_gradients to iter-dependent values is removed, both forward and reverse mode derivatives are tested up to second order, and the derivatives of the forward / inverse transforms themselves are checked rather than those of a scalar function using them (jax.test_util.check_grads is able to check gradients of vector valued functions).
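For reference, a check of this kind on a vector valued function might look like the sketch below. The function toy_vector_valued and the matrix TOY_MATRIX are hypothetical and are not the transforms tested in tests/test_spherical_custom_grads.py; only jax.test_util.check_grads is taken from the description above.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# Double precision keeps the finite-difference comparison well conditioned.
jax.config.update("jax_enable_x64", True)

# Hypothetical matrix used to build a small vector valued test function.
TOY_MATRIX = jnp.arange(12.0).reshape(4, 3) / 10.0


def toy_vector_valued(x):
    """Nonlinear map from length-3 inputs to length-4 outputs."""
    return jnp.sin(TOY_MATRIX @ x)


x = jnp.array([0.3, -0.7, 1.1])

# Compares forward and reverse mode derivatives against numerical estimates,
# up to second order. Raises an AssertionError on mismatch, returns None if
# all checks pass.
result = check_grads(toy_vector_valued, (x,), order=2, modes=("fwd", "rev"))
```

Because the function is passed in directly, no scalar reduction of the output is needed before checking.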

There are a few possible further changes / thoughts:

  • Currently this PR exposes healpy_map2alm and healpy_alm2map as new public functions in the API. This may be useful in some applications where users want to use harmonic coefficients which are ordered with the HEALPix (ring ordered) indexing convention without round tripping back and forth to the 2D layout. However it does introduce some duplication into the API and docstrings. We could instead just make these functions private for internal use. Alternatively, if we wanted to make them more generally useful, we could also make their signatures more closely match the corresponding healpy functions, for example using lmax rather than L (with lmax = L - 1), using maps instead of f and alms instead of flm for the first argument names, and not requiring the nside argument for healpy_alm2map.
  • At the moment the implementation of healpy_map2alm fixes the iter argument to healpy.map2alm to iter=0, with iterative refinement then implemented in a Python loop in the healpy_forward function. This has the advantage of not requiring us to manually implement the derivative (transposition) rule for healpy_map2alm with iter > 0; however, it means we get slightly poorer performance in the forward pass due to using a Python rather than a C loop.
  • Currently as we are still calling out to the healpy HEALPix Python API using jnp.array and np.array to convert to and from JAX / NumPy arrays, the healpy_alm2map and healpy_map2alm primitives (and so healpy_inverse and healpy_forward) are not compatible with the jax.jit transform (which is also the case in the current implementation). We could instead define a lowering which calls out to the underlying C++ implementations, which would give us JIT support and possibly a small performance boost by avoiding some Python interpreter overhead.
  • The jax.interpreters.ad.deflinear function we use here to define the derivative behaviour / custom transposition rule is not documented. It is used internally in various places and was suggested as a way of defining derivatives for linear operators in Best way to define gradient for linear black box function jax-ml/jax#1530, but as it's not part of the documented API it could be subject to breaking changes. We already rely to an extent on some non-documented JAX internals through the custom CUDA primitive added in Cufft primitive #204, so this probably doesn't add much additional maintenance burden. If we did want to avoid using this undocumented feature, we could just use jax.custom_vjp similar to the current implementation but using the corrected implementation for the derivative rule of healpy_forward, at the expense of not getting forward mode / higher order derivative support.
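To make the jax.custom_vjp alternative in the last point concrete, here is a minimal sketch for a toy linear operator. It is not the s2fft implementation: TOY_MATRIX and toy_forward are hypothetical, and the toy operator is written with jax.numpy so the example runs without healpy, whereas the real transforms call out to non-traceable code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical dense matrix standing in for a linear transform.
TOY_MATRIX = jnp.arange(12.0).reshape(4, 3)


@jax.custom_vjp
def toy_forward(x):
    """Toy linear operator with a hand-written reverse mode rule."""
    return TOY_MATRIX @ x


def toy_forward_fwd(x):
    # The operator is linear, so the backward pass needs no residuals.
    return TOY_MATRIX @ x, None


def toy_forward_bwd(_residuals, cotangent):
    # Reverse mode rule: apply the transpose (adjoint) of the operator.
    return (TOY_MATRIX.T @ cotangent,)


toy_forward.defvjp(toy_forward_fwd, toy_forward_bwd)

x = jnp.array([1.0, -2.0, 0.5])
cotangent = jnp.array([1.0, 0.0, -1.0, 2.0])

# Reverse mode works through the custom rule.
_, vjp_fn = jax.vjp(toy_forward, x)
(cotangent_out,) = vjp_fn(cotangent)

# Gradient of a scalar loss that uses the operator.
loss_grad = jax.grad(lambda v: jnp.sum(toy_forward(v) ** 2))(x)
```

Only a reverse mode rule is registered here, which is the limitation noted above: this route gives up forward mode and higher order derivative support.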


codecov bot commented Nov 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.08%. Comparing base (909e6f1) to head (fe82eaa).
Report is 9 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #244      +/-   ##
==========================================
+ Coverage   96.07%   96.08%   +0.01%     
==========================================
  Files          31       31              
  Lines        3567     3577      +10     
==========================================
+ Hits         3427     3437      +10     
  Misses        140      140              


@CosmoMatt (Collaborator) left a comment

Interesting JAX functionality for linear black box functions, which is good to know. As you say @matt-graham, it could be subject to breaking changes at a later date. At least for now though this looks good to me!

@jasonmcewen (Contributor) commented

Thanks @matt-graham !

In response to your further thoughts (in corresponding order).

  1. Indeed, good to expose healpy_map2alm and healpy_alm2map to avoid the harmonic coefficient conversion. I think fine to keep to our conventions, e.g. specifying L rather than lmax, so we are consistent throughout s2fft.
  2. This sounds good to get greater accuracy. The speed that we trade off for this should be fairly minimal since we're typically just considering a handful of iterations, i.e. typically iter = 3 and we're unlikely to ever have iter > 6 or so.
  3. I think this is fine as is. I would expect these overheads to be minimal compared to the actual spherical harmonic transforms, especially for high L. Do you agree?
  4. This functionality sounds great! If I understand correctly, this gives us access to higher order derivatives also. Hopefully they will document the related functionality at some point.
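The iterative refinement discussed in point 2 follows a residual-correction pattern that can be sketched in plain NumPy as below. This is an illustration under stated assumptions, not the healpy_forward code: the matrices, toy_synthesis, toy_analysis and toy_forward are all hypothetical, with toy_analysis playing the role of an approximate (iter=0) analysis step.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_coeffs = 12, 4

# Hypothetical synthesis matrix: orthonormal columns plus a small perturbation,
# so that the transpose of the orthonormal part is only an approximate inverse.
q, _ = np.linalg.qr(rng.standard_normal((n_samples, n_coeffs)))
synthesis_matrix = q + 0.05 * rng.standard_normal((n_samples, n_coeffs))
approx_analysis_matrix = q.T


def toy_synthesis(coeffs):
    """Coefficients to samples (stand-in for an inverse transform)."""
    return synthesis_matrix @ coeffs


def toy_analysis(samples):
    """Samples to coefficients, approximate (stand-in for iter=0 analysis)."""
    return approx_analysis_matrix @ samples


def toy_forward(samples, n_iter):
    """Approximate analysis followed by n_iter residual corrections."""
    coeffs = toy_analysis(samples)
    for _ in range(n_iter):  # Python loop, a handful of iterations
        residual = samples - toy_synthesis(coeffs)
        coeffs = coeffs + toy_analysis(residual)
    return coeffs


true_coeffs = rng.standard_normal(n_coeffs)
samples = toy_synthesis(true_coeffs)

error_no_refinement = np.linalg.norm(toy_forward(samples, 0) - true_coeffs)
error_three_iterations = np.linalg.norm(toy_forward(samples, 3) - true_coeffs)
```

Each pass costs one synthesis and one analysis call, so with around three iterations the loop overhead itself is a small part of the total work, and every step is a composition of the same two linear operators.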

@jasonmcewen jasonmcewen self-requested a review November 26, 2024 14:35
@matt-graham (Collaborator, Author) commented Nov 26, 2024

Interesting JAX functionality for linear black box functions, which is good to know. As you say @matt-graham, it could be subject to breaking changes at a later date. At least for now though this looks good to me!

  4. This functionality sounds great! If I understand correctly, this gives us access to higher order derivatives also. Hopefully they will document the related functionality at some point.

If jax-ml/jax#24726 is merged, this would provide a public API to define both custom VJPs and JVPs, so it would be a better alternative to the current approach here, though not specific to the linear case.

  3. I think this is fine as is. I would expect these overheads to be minimal compared to the actual spherical harmonic transforms, especially for high L. Do you agree?

Yes, agreed, I think the overhead will be minimal here so not worth worrying about.

@matt-graham matt-graham merged commit dc9b2bc into main Nov 26, 2024
8 checks passed
@matt-graham matt-graham deleted the mmg/healpix-gradient-fix branch November 26, 2024 16:10
Linked issue that merging this pull request may close: Derivatives for healpix_forward are incorrrect