-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Correct healpix_forward
derivatives and add support for forward and higher order autodiff
#244
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #244 +/- ##
==========================================
+ Coverage 96.07% 96.08% +0.01%
==========================================
Files 31 31
Lines 3567 3577 +10
==========================================
+ Hits 3427 3437 +10
Misses 140 140 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting JAX functionality for linear black functions which is good to know. As you say @matt-graham could be subject breaking changes at a later date. At least for now though this looks good to me!
Thanks @matt-graham ! In reponse to your further thoughts (in corresponding order).
|
If jax-ml/jax#24726 is merged this would provide a public API to define both custom VJP and JVPs so would be a better alternative to current approach here, though not specific to linear case.
Yes, agreed, I think the overhead will be minimal here so not worth worrying about. |
Resolves #243
Corrects derivative rule for
healpy_forward
function ins2fft.transforms.c_backend_spherical
and refactors bothhealpy_forward
andhealpy_inverse
in terms of two new custom JAX primitiveshealpy_map2alm
andhealpy_alm2map
that wrap the corresponding healpymap2alm
andalm2map
functions. Thejax.interpreters.ad.deflinear
function is used to mark the primitives as linear operators with custom transposition rules (defined in terms of the primitives themselves). This gives us support for both forward and reverse mode automatic differentiation and also higher order derivatives.The corresponding tests for gradient correctness in
tests/test_spherical_custom_grads.py
are also updated to be stricter, removing the kludge to relaxrtol
tolerance fortest_healpix_c_backend_forward_custom_gradients
toiter
dependent values, testing both forward and reverse mode derivatives and up to second order derivatives, and checking the derivatives of the forward / inverse transforms themselves rather than a scalar function using them (withjax.test_util.check_grads
able to check gradients of vector valued functions).There are a few possible further changes / thoughts:
healpy_map2alm
andhealpy_alm2map
as new public functions in the API. This may be useful in some applications where users want to use harmonic coefficients which are ordered with HEALPix (ring ordered) indexing convention without round tripping back and forth to 2D layout. However it does introduce some duplication in to API and docstrings. We could instead just make these functions private for internal use. Alternatively if we wanted to make them more generally useful we could also make their signatures more closely match the corresponding healpy functions - for example usinglmax
rather thanL
(withlmax = L - 1
, usingmaps
instead off
andalms
intead offlm
for first argument names, not requiringnside
argument forhealpy_alm2map
.healpy_map2alm
fixes theiter
argument tohealpy.map2alm
toiter=0
, with iterative refinement then implemented in a Python loop in thehealpy_forward
function. This has the advantage of not requiring us to manually implement the derivative (transposition) rule forhealpy_map2alm
withiter > 0
, however means we get slightly poorer performance in the forward pass due to using Python rather than C loop.jnp.array
andnp.array
to convert to and from JAX / NumPy arrays, thehealpy_alm2map
andhealpy_map2alm
primitives (and sohealpy_inverse
andhealpy_forward
) are not compatible with thejax.jit
transform (which is also the case in the current implementation). We could instead define a lowering which calls out to the underlying C++ implementations, which would give us JIT support and possibly a small performance boost by avoiding some Python interpreter overhead.jax.interpreters.ad.deflinear
function we use here to define the derivative behaviour / custom transposition rule is not documented. It is used internally in various places and was suggested as a way defining derivatives for linear operators in Best way to define gradient for linear black box function jax-ml/jax#1530 but as it's not part of documented API, it could be subject to breaking changes. We already to an extent rely on some non-documented JAX internals through the custom CUDA primitive added in Cufft primitive #204, so this probably doesn't add much additional maintenance burden, but if we did want to avoid using this undocumented feature we could just usejax.custom_vjp
similar to the current implementation but using the corrected implementation for the derivative rule ofhealpy_forward
, at the expense of not getting forward mode / higher order derivative support.