Derivatives for healpix_forward are incorrect #243

Open
matt-graham opened this issue Nov 15, 2024 · 0 comments · May be fixed by #244
Labels: bug (Something isn't working)


matt-graham commented Nov 15, 2024

I believe the derivatives defined for the s2fft.transforms.c_backend_spherical.healpy_forward function with custom_vjp are incorrect. As a minimal reproducing example:

```python
import s2fft
from s2fft.transforms import c_backend_spherical as c_sph
import numpy as np
import jax
from jax.test_util import check_grads

jax.config.update("jax_enable_x64", True)

L = 32
nside = L // 2
method = "jax"
sampling = "healpix"
reality = True
rng = np.random.default_rng(23457801234570)
flm = s2fft.utils.signal_generator.generate_flm(rng, L, reality=reality)
f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method, reality=reality)
```

Running check_grads on c_sph.healpy_inverse to check the gradients against numerical finite differences completes without any error:

```python
check_grads(lambda x: c_sph.healpy_inverse(x, L, nside), (flm,), modes=("rev",), order=1)
```

Running the same on c_sph.healpy_forward, however, gives an AssertionError:

```python
check_grads(lambda x: c_sph.healpy_forward(x, L, nside, iter=0), (f,), modes=("rev",), order=1)
```

outputting

```
AssertionError: 
Not equal to tolerance rtol=1e-05, atol=1e-05
VJP cotangent projection
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 1.84452973
Max relative difference: 1.69768919
 x: array(-2.931024)
 y: array(-1.086494)
```

Somewhat confusingly, running check_grads instead on specific scalar functions constructed using c_sph.healpy_forward does pass without error:

```python
check_grads(lambda x: abs(c_sph.healpy_forward(x, L, nside, iter=0)).sum(), (f,), modes=("rev",), order=1)
```

Notice that in all cases iter is fixed to zero, so this is not due to the use of iterative refinement steps. While HEALPix sampling does not exhibit a sampling theorem, and so round-tripping through the forward and backward ('inverse') transforms introduces an error (which iterative refinement in the forward transform can reduce), this only affects how accurately the linear operator corresponding to the forward transform inverts the linear operator corresponding to the backward transform.

Distinct from this inverse property is the fact that the linear operator represented by the forward transform is a (scaled and conjugated) transposition of the linear operator represented by the backward transform (and vice versa), and this transposition property holds exactly (modulo floating point error) for the HEALPix forward and backward transforms. Importantly, it is this transposition property that is required to implement the derivative rules, so the incorrectness of the derivatives in the current implementation cannot be explained by the error in the inverse relationship.
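The distinction can be illustrated with a toy numpy sketch (hypothetical stand-in matrices, not the real transforms): a backward matrix defined by $B = F^H D$ satisfies the transposition (adjoint) relation exactly while being no inverse of $F$ at all:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 4, 10  # toy stand-ins for the alm and map dimensions

# Toy forward matrix F and positive diagonal weight matrix D
F = rng.normal(size=(m, n)) + 1j * rng.normal(size=(m, n))
D = np.diag(rng.uniform(1.0, 2.0, size=m))
B = F.conj().T @ D  # backward matrix defined by the transposition property

x = rng.normal(size=n) + 1j * rng.normal(size=n)
y = rng.normal(size=m) + 1j * rng.normal(size=m)

# Adjoint relation holds exactly: x^H (B y) == (F x)^H (D y)
print(np.allclose(np.vdot(x, B @ y), np.vdot(F @ x, D @ y)))  # True

# ...while B need not invert F at all: F (B y) != y in general
print(np.allclose(F @ (B @ y), y))  # False
```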

Let the matrix represented by the HEALPix forward spherical transform (map2alm) be $F \in \mathbb{C}^{m \times n}$ and the matrix represented by the HEALPix inverse (backward) spherical transform (alm2map) be $B \in \mathbb{C}^{n \times m}$, with harmonic bandlimit $\ell$ and HEALPix resolution parameter $r$ in both cases, so that $n = 12r^2$ and $m = \ell(\ell + 1) / 2$. That is

$$\texttt{map2alm}(x) = F x \quad\text{and}\quad \texttt{alm2map}(y) = B y.$$

We can construct the matrices numerically by mapping the standard basis vectors through the map2alm and alm2map functions:

```python
import healpy

F = np.stack(
    list(
        map(
            lambda e: healpy.map2alm(e, lmax=L - 1),
            np.identity(12 * nside**2, dtype=float),
        )
    ),
    1,
)

# As the alm2map argument is complex we need to use both real and imaginary basis vectors
B = np.stack(
    list(
        map(
            lambda e: healpy.alm2map(e, nside=nside, lmax=L - 1),
            np.identity(L * (L + 1) // 2, dtype=complex),
        )
    ),
    1,
) - 1j * np.stack(
    list(
        map(
            lambda e: healpy.alm2map(e, nside=nside, lmax=L - 1),
            1j * np.identity(L * (L + 1) // 2, dtype=complex),
        )
    ),
    1,
)
```

We then have the relationships

$$ B = F^H D \quad\text{and}\quad F = D^{-1} B^H$$

where $X^H = \textrm{conj}(X)^T$ (that is, the conjugate / Hermitian transpose) and $D$ is an $m \times m$ diagonal matrix with the first $\ell$ entries along the diagonal equal to $3r^2 / \pi$ and the remaining $m - \ell = \ell(\ell - 1) / 2$ diagonal entries equal to $6r^2/\pi$.

Numerical verification:

```python
D = np.diag([1] * L + [2] * (L * (L - 1) // 2)) * (3 * nside**2) / np.pi
np.allclose(B, F.T.conj() @ D)
D_inv = np.diag(1 / D.diagonal())
np.allclose(D_inv @ B.T.conj(), F)
```

As map2alm and alm2map are both linear, their Jacobian vector product (JVP) functions are just the original maps:

$$ \mathsf{jvp}(\texttt{map2alm})(x)(v) = \partial(x \mapsto F x) v = F v = \texttt{map2alm}(v) $$ $$ \mathsf{jvp}(\texttt{alm2map})(y)(v) = \partial(y \mapsto B y) v = B v = \texttt{alm2map}(v) $$
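This can be checked with a quick finite-difference sketch on a toy matrix (plain numpy, with a hypothetical matrix standing in for the transforms):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(3, 5))  # a toy linear map, like map2alm / alm2map
x = rng.normal(size=5)
v = rng.normal(size=5)

f = lambda z: A @ z

# Central finite-difference directional derivative of f at x along v...
eps = 1e-6
fd_jvp = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)

# ...equals f applied to v, since the Jacobian of a linear map is the map itself
print(np.allclose(fd_jvp, f(v)))  # True
```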

For the vector Jacobian product (VJP) functions we have

$$ \begin{aligned} \mathsf{vjp}(\texttt{map2alm})(x)(v) &= v^T \partial(x \mapsto F x) \\ &= v^T F \\ &= (F^T v)^T \\ &= \textrm{conj}(\textrm{conj}(F)^T \textrm{conj}(v))^T \\ &= (F^H D D^{-1} \textrm{conj}(v))^H \\ &= (B D^{-1} \textrm{conj}(v))^H \\ &= \textrm{conj}(\texttt{alm2map}(D^{-1} \textrm{conj}(v))) \end{aligned} $$

$$ \begin{aligned} \mathsf{vjp}(\texttt{alm2map})(y)(v) &= v^T \partial(y \mapsto B y) \\ &= v^T B \\ &= (B^T v)^T \\ &= \textrm{conj}(\textrm{conj}(B)^T \textrm{conj}(v))^T \\ &= (D D^{-1} B^H \textrm{conj}(v))^H \\ &= (D F \textrm{conj}(v))^H \\ &= D \, \textrm{conj}(\texttt{map2alm}(\textrm{conj}(v))) \end{aligned} $$
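These closed forms can be checked numerically on random toy matrices constructed to satisfy $B = F^H D$ (a plain numpy sketch, independent of healpy; arrays stand in for the row vectors $v^T F$ and $v^T B$):

```python
import numpy as np

rng = np.random.default_rng(2)
m, n = 4, 10  # toy stand-ins for len(alm) and the map size

# Random toy matrices satisfying the transposition property B = F^H D
F = rng.normal(size=(m, n)) + 1j * rng.normal(size=(m, n))
d = rng.uniform(1.0, 2.0, size=m)
D, D_inv = np.diag(d), np.diag(1 / d)
B = F.conj().T @ D

v_f = rng.normal(size=m) + 1j * rng.normal(size=m)  # cotangent for map2alm
v_b = rng.normal(size=n) + 1j * rng.normal(size=n)  # cotangent for alm2map

# vjp(map2alm): v^T F == conj(B D^{-1} conj(v))
print(np.allclose(v_f @ F, np.conj(B @ (D_inv @ np.conj(v_f)))))  # True

# vjp(alm2map): v^T B == D conj(F conj(v))  (D applied after the forward map)
print(np.allclose(v_b @ B, np.conj(D @ (F @ np.conj(v_b)))))  # True
```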

The current (correct) VJP definition for healpy_inverse is

```python
def _healpy_inverse_bwd(res, f):
    """Private function which implements the backward pass for inverse jax_healpy."""
    _, L, nside = res
    f_new = f * (12 * nside**2) / (4 * jnp.pi)
    flm_out = jnp.array(
        np.conj(healpy.map2alm(np.conj(np.array(f_new)), lmax=L - 1, iter=0))
    )
    # iter MUST be zero otherwise gradient propagation is incorrect (JDM).
    flm_out = reindex.flm_hp_to_2d_fast(flm_out, L)
    m_conj = (-1) ** (jnp.arange(1, L) % 2)
    flm_out = flm_out.at[..., L:].add(
        jnp.flip(m_conj * jnp.conj(flm_out[..., : L - 1]), axis=-1)
    )
    flm_out = flm_out.at[..., : L - 1].set(0)
    return flm_out, None, None
```

It's a little difficult to relate this to the above derivation, as it includes both the VJP for healpy.alm2map and the VJP for s2fft.sampling.reindex.flm_2d_to_hp_fast, with healpy_inverse defined as the composition of these:

```python
flm = reindex.flm_2d_to_hp_fast(flm, L)
f = jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))
return f
```

Broadly, though, we can see that the VJP rule corresponds to something which performs $\textrm{conj}(\texttt{map2alm}(\textrm{conj}(v)))$ plus an elementwise scaling operation corresponding to the application of $D$.

The current (I think incorrect) VJP definition for healpy_forward is

```python
def _healpy_forward_bwd(res, flm):
    """Private function which implements the backward pass for forward jax_healpy."""
    _, L, nside, _ = res
    flm_new = reindex.flm_2d_to_hp_fast(flm, L)
    f = jnp.array(
        np.conj(healpy.alm2map(np.conj(np.array(flm_new)), lmax=L - 1, nside=nside))
    )
    return f * (4 * jnp.pi) / (12 * nside**2), None, None, None
```

with healpy_forward here defined as a composition of healpy.map2alm and s2fft.sampling.reindex.flm_hp_to_2d_fast:

```python
flm = jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=iter))
return reindex.flm_hp_to_2d_fast(flm, L)
```

Here we see there is no elementwise scaling before the application of alm2map corresponding to the multiplication by $D^{-1}$ in $\textrm{conj}(\texttt{alm2map}(D^{-1} \textrm{conj}(v)))$. Only a uniform scaling by $4\pi/(12 r^2)$ is applied afterwards, which matches $D^{-1}$ on the first $\ell$ diagonal entries but is off by a factor of two on the remaining ones.
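To see why the missing scaling matters, here is a toy numpy sketch (hypothetical stand-in matrices mimicking the structure of $D$, not the real transforms): rescaling the cotangent by $D^{-1}$ reproduces $v^T F$ exactly, whereas a uniform scaling in the style of $4\pi/(12 r^2)$ does not:

```python
import numpy as np

rng = np.random.default_rng(3)
m, n = 4, 10  # toy stand-ins for len(alm) and the map size
F = rng.normal(size=(m, n)) + 1j * rng.normal(size=(m, n))

# Mimic the HEALPix weights: first entries 3 r^2 / pi, rest 6 r^2 / pi (here r = 1)
d = np.array([1.0, 1.0, 2.0, 2.0]) * 3 / np.pi
D_inv = np.diag(1 / d)
B = F.conj().T @ np.diag(d)  # transposition property B = F^H D

v = rng.normal(size=m) + 1j * rng.normal(size=m)  # cotangent vector

true_vjp = v @ F  # what the VJP of x -> F x must return
with_D_inv = np.conj(B @ (D_inv @ np.conj(v)))     # conj(alm2map(D^{-1} conj(v)))
uniform_only = np.conj(B @ np.conj(v)) * np.pi / 3  # uniform 4*pi/(12*r^2) scaling

print(np.allclose(true_vjp, with_D_inv))    # True
print(np.allclose(true_vjp, uniform_only))  # False
```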

@matt-graham matt-graham added the bug Something isn't working label Nov 15, 2024
@matt-graham matt-graham self-assigned this Nov 15, 2024