Support HEALPix iterations to improve accuracy of transforms #139

Open
jasonmcewen opened this issue Feb 15, 2023 · 1 comment · May be fixed by #241

Comments

@jasonmcewen
Contributor

No description provided.

@matt-graham
Collaborator

Picking up this issue. My understanding is that we want to support the equivalent of what is controlled by the iter keyword argument to the healpy.sphtfunc.map2alm function, which (from digging about a bit in the HEALPix source code) is eventually passed as the num_iter argument to map2alm_iter in the C++ source.

My understanding is that this implements an instance of iterative refinement. In Python pseudocode using HEALPix nomenclature, the implementation is something like

alm = map2alm(map)
for i in range(n_iter):
    map_recon = alm2map(alm)
    map_error = map - map_recon
    alm += map2alm(map_error)

or using something more similar to the nomenclature we use in s2fft

flm = forward(f)
for i in range(n_iter):
    f_recov = inverse(flm)
    f_error = f - f_recov
    flm += forward(f_error)

Here n_iter = 0 corresponds to the base case of no refinement (a single forward pass), and n_iter = 1 to a single iteration of refinement.
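Why this converges can be sketched in a few lines, assuming both transforms are linear and that f is exactly band-limited (as in the test further down, where f is generated with s2fft.inverse). The notation here is introduced only for illustration: \mathcal{F} is the forward transform (map2alm / s2fft.forward), \mathcal{A} the inverse transform (alm2map / s2fft.inverse), and e^{(k)} the coefficient error after k refinement steps.

```latex
% Assumptions: \mathcal{F}, \mathcal{A} linear; f = \mathcal{A} f_{\ell m}^{\mathrm{true}}
f_{\ell m}^{(0)} = \mathcal{F} f, \qquad
f_{\ell m}^{(k+1)} = f_{\ell m}^{(k)} + \mathcal{F}\bigl(f - \mathcal{A} f_{\ell m}^{(k)}\bigr)

e^{(k)} \equiv f_{\ell m}^{(k)} - f_{\ell m}^{\mathrm{true}}
\;\Rightarrow\;
e^{(k+1)} = e^{(k)} - \mathcal{F}\mathcal{A}\, e^{(k)} = (\mathrm{I} - \mathcal{F}\mathcal{A})\, e^{(k)}

e^{(0)} = -(\mathrm{I} - \mathcal{F}\mathcal{A})\, f_{\ell m}^{\mathrm{true}}
\;\Rightarrow\;
e^{(n_{\mathrm{iter}})} = -(\mathrm{I} - \mathcal{F}\mathcal{A})^{n_{\mathrm{iter}}+1} f_{\ell m}^{\mathrm{true}}
```

So each refinement step multiplies the error by \mathrm{I} - \mathcal{F}\mathcal{A}, and the iteration converges whenever the spectral radius of that operator is below one; with n_iter = 0 the error is just the single-pass quadrature error. This ignores floating-point round-off, which in practice sets a floor on the attainable error.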

Translating this into actual code, the following prototype implementation

import s2fft

def forward_iter(f, n_iter=0, **kwargs):
    flm = s2fft.forward(f, **kwargs, iter=0)
    for i in range(n_iter):
        f_recov = s2fft.inverse(flm, **kwargs)
        f_error = f - f_recov
        flm += s2fft.forward(f_error, **kwargs, iter=0)
    return flm

and running it with the JAX implementations as follows

import jax
jax.config.update("jax_enable_x64", True)
import numpy as np

L = 128
nside = 64
method = "jax"
sampling = "healpix"
rng = np.random.default_rng(23457801234570)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)

# Compute forward transform with iterative refinement
flm_iter = forward_iter(f, n_iter=10, L=L, method=method, sampling=sampling, nside=nside)

# Compute error in recovered f
f_recov = s2fft.inverse(flm_iter, L, nside=nside, sampling=sampling, method=method)
print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")

gives the output

Mean absolute error = 7.556473213859434e-12

while for n_iter = 0 the corresponding output is

Mean absolute error = 0.016750706531930927

So this seems to work as intended.
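
A possible generalisation of forward_iter, sketched under some assumptions: the transforms are passed in as callables (standing in for s2fft.forward and s2fft.inverse with L, nside, sampling and method already bound), the pixel-space residual norm is recorded at each pass, and the loop optionally stops early once that residual drops below a tolerance. The name refine, the tol argument and the toy matrix operators are hypothetical. They are not part of s2fft or of the prototype above and are used only so the sketch runs without s2fft.

```python
import numpy as np


def refine(forward, inverse, f, n_iter=0, tol=None):
    """Forward transform with iterative refinement (hypothetical helper).

    `forward` and `inverse` are callables standing in for s2fft.forward and
    s2fft.inverse with all other arguments (L, nside, sampling, ...) bound.
    Returns the refined coefficients and the residual norm ||f - inverse(flm)||
    measured before each correction step.
    """
    flm = forward(f)  # n_iter = 0: a single forward pass
    residuals = []
    for _ in range(n_iter):
        f_error = f - inverse(flm)  # pixel-space residual of current estimate
        residuals.append(float(np.linalg.norm(f_error)))
        if tol is not None and residuals[-1] < tol:
            break  # hypothetical early stop: residual already small enough
        flm = flm + forward(f_error)  # correct the coefficients using the residual
    return flm, residuals


# Toy linear stand-ins for the transforms (assumption, not s2fft):
# the "inverse" is a tall matrix with nearly orthonormal columns and the
# "forward" is its transpose, i.e. an approximate but inexact inverse.
rng = np.random.default_rng(0)
n_pix, n_coef = 200, 50
q, _ = np.linalg.qr(rng.standard_normal((n_pix, n_coef)))
A = q + 0.05 * rng.standard_normal((n_pix, n_coef)) / np.sqrt(n_pix)


def toy_inverse(flm):
    return A @ flm


def toy_forward(f):
    return A.T @ f


flm_true = rng.standard_normal(n_coef)
f = toy_inverse(flm_true)  # exactly "band-limited" test signal

flm_0, _ = refine(toy_forward, toy_inverse, f, n_iter=0)
flm_10, residuals = refine(toy_forward, toy_inverse, f, n_iter=10)
print(f"n_iter=0:  max abs coefficient error = {np.max(np.abs(flm_0 - flm_true)):.3e}")
print(f"n_iter=10: max abs coefficient error = {np.max(np.abs(flm_10 - flm_true)):.3e}")
```

With these toy operators the single-pass error comes entirely from A.T not being an exact inverse, and the refinement steps shrink it geometrically, qualitatively like the s2fft results above; the actual numbers depend on the toy matrix and say nothing about s2fft itself. To try it with s2fft one could bind the arguments with functools.partial, e.g. functools.partial(s2fft.inverse, L=L, nside=nside, sampling=sampling, method=method), much as forward_iter does via **kwargs.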
