Support HEALPix iterations to improve accuracy of transforms #139
Picking up this issue: my understanding is that we want to support the equivalent of the iteration option of HEALPix's `map2alm`. My understanding is that this implements an instance of iterative refinement. In HEALPix nomenclature, Python pseudocode for the implementation would be something like

```python
alm = map2alm(map)
for i in range(n_iter):
    map_recon = alm2map(alm)
    map_error = map - map_recon
    alm += map2alm(map_error)
```

or, using something closer to the nomenclature we use in this package,

```python
flm = forward(f)
for i in range(n_iter):
    f_recov = inverse(flm)
    f_error = f - f_recov
    flm += forward(f_error)
```

Translating this to actual code gives the following prototype implementation

```python
import s2fft


def forward_iter(f, n_iter=0, **kwargs):
    # Initial forward transform without any refinement
    flm = s2fft.forward(f, **kwargs, iter=0)
    for i in range(n_iter):
        # Residual between the input map and the map synthesised from the current flm
        f_recov = s2fft.inverse(flm, **kwargs)
        f_error = f - f_recov
        # Correct flm using the forward transform of the residual
        flm += s2fft.forward(f_error, **kwargs, iter=0)
    return flm
```

and running it with the JAX implementations as follows

```python
import jax

jax.config.update("jax_enable_x64", True)

import numpy as np
import s2fft

L = 128
nside = 64
method = "jax"
sampling = "healpix"
rng = np.random.default_rng(23457801234570)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)
# Compute forward transform with iterative refinement
flm_iter = forward_iter(f, n_iter=10, L=L, method=method, sampling=sampling, nside=nside)
# Compute error in recovered f
f_recov = s2fft.inverse(flm_iter, L, nside=nside, sampling=sampling, method=method)
print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")
```

gives output
while for
So this seems to work as intended.
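To see why the refinement loop described above converges, here is a minimal sketch that runs the same loop on a toy linear transform pair instead of s2fft. It assumes the forward transform `F` is linear and only an approximate inverse of the synthesis (inverse) transform `I`, and that the input map is exactly band-limited, `f = I flm_true`. Under those assumptions each pass updates the coefficient error as `e_{k+1} = (Id - F I) e_k`, so the error shrinks geometrically whenever the spectral radius of `Id - F I` is below 1. `refine_forward` and `make_toy_transforms` are hypothetical helpers, and the toy operator is a random matrix with orthonormal columns rescaled by factors in [0.9, 1.1]. It is not a spherical harmonic transform.

```python
import numpy as np


def refine_forward(f, forward, inverse, n_iter=0):
    """Iteratively refined forward transform (hypothetical helper).

    Same loop as the prototype, but with the transform pair passed in
    so it can be exercised on a toy problem without s2fft.
    """
    flm = forward(f)
    for _ in range(n_iter):
        f_error = f - inverse(flm)    # residual in pixel space
        flm = flm + forward(f_error)  # correct the coefficients
    return flm


def make_toy_transforms(n_pix=256, n_coef=32, seed=0):
    """Toy linear transform pair (hypothetical stand-in, not an SHT).

    inverse: coefficients -> pixels through a matrix y whose orthonormal
    columns are rescaled by factors in [0.9, 1.1].
    forward: the transpose y.T, which is only an approximate inverse of y.
    """
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((n_pix, n_coef)))
    y = q * np.linspace(0.9, 1.1, n_coef)  # rescale column j by the j-th factor
    return (lambda f: y.T @ f), (lambda flm: y @ flm)


forward, inverse = make_toy_transforms()
rng = np.random.default_rng(1)
flm_true = rng.standard_normal(32)
f = inverse(flm_true)  # band-limited signal, exactly representable

errors = []
for n_iter in (0, 1, 3, 10):
    flm = refine_forward(f, forward, inverse, n_iter=n_iter)
    errors.append(np.max(np.abs(flm - flm_true)))
    print(f"n_iter={n_iter:2d}  max abs coefficient error = {errors[-1]:.3e}")
```

In this toy setup `forward(inverse(x))` equals `diag(s**2) @ x` with every scale `s` in [0.9, 1.1], so each coefficient error is multiplied by `|1 - s**2| <= 0.21` per pass and the printed error falls by at least that factor with every extra iteration. A quadrature-based HEALPix forward transform is not exactly this operator, but the same recurrence is a reasonable way to think about why a handful of iterations improves accuracy.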