From 6b0ee97a578e62b1616385fd95e13a1d5714634b Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Mon, 18 Nov 2024 18:44:47 +0000 Subject: [PATCH 1/3] Avoid unrolled for loops in reindexing functions --- s2fft/sampling/reindex.py | 51 +++++++++++++++------------------------ 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/s2fft/sampling/reindex.py b/s2fft/sampling/reindex.py index 97d3bb8..f6baf9c 100644 --- a/s2fft/sampling/reindex.py +++ b/s2fft/sampling/reindex.py @@ -1,6 +1,7 @@ from functools import partial import jax.numpy as jnp +import numpy as np from jax import jit @@ -36,12 +37,11 @@ def flm_1d_to_2d_fast(flm_1d: jnp.ndarray, L: int) -> jnp.ndarray: """ flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) - els = jnp.arange(L) - offset = els**2 + els - for el in range(L): - m_array = jnp.arange(-el, el + 1) - flm_2d = flm_2d.at[el, L - 1 + m_array].set(flm_1d[offset[el] + m_array]) - return flm_2d + row_indices, col_indices = np.arange(L)[:, None], np.arange(2 * L - 1)[None, :] + el_indices, m_indices = np.where( + (row_indices <= col_indices)[::-1, :] & (row_indices <= col_indices)[::-1, ::-1] + ) + return flm_2d.at[el_indices, m_indices].set(flm_1d) @partial(jit, static_argnums=(1)) @@ -75,13 +75,11 @@ def flm_2d_to_1d_fast(flm_2d: jnp.ndarray, L: int) -> jnp.ndarray: jnp.ndarray: 1D indexed harmonic coefficients. """ - flm_1d = jnp.zeros(L**2, dtype=jnp.complex128) - els = jnp.arange(L) - offset = els**2 + els - for el in range(L): - m_array = jnp.arange(-el, el + 1) - flm_1d = flm_1d.at[offset[el] + m_array].set(flm_2d[el, L - 1 + m_array]) - return flm_1d + row_indices, col_indices = np.arange(L)[:, None], np.arange(2 * L - 1)[None, :] + el_indices, m_indices = np.where( + (row_indices <= col_indices)[::-1, :] & (row_indices <= col_indices)[::-1, ::-1] + ) + return flm_2d[el_indices, m_indices] @partial(jit, static_argnums=(1)) @@ -127,17 +125,12 @@ def flm_hp_to_2d_fast(flm_hp: jnp.ndarray, L: int) -> jnp.ndarray: jnp.ndarray: 2D indexed harmonic coefficients. """ - flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) - - for el in range(L): - flm_2d = flm_2d.at[el, L - 1].set(flm_hp[el]) - m_array = jnp.arange(1, el + 1) - hp_idx = m_array * (2 * L - 1 - m_array) // 2 + el - flm_2d = flm_2d.at[el, L - 1 + m_array].set(flm_hp[hp_idx]) - flm_2d = flm_2d.at[el, L - 1 - m_array].set( - (-1) ** m_array * jnp.conj(flm_hp[hp_idx]) - ) - + flm_2d = jnp.zeros((L, 2 * L - 1), dtype=flm_hp.dtype) + m_indices, el_indices = np.triu_indices(n=L + 1, m=L) + flm_2d = flm_2d.at[el_indices, L - 1 + m_indices].set(flm_hp) + flm_2d = flm_2d.at[el_indices, L - 1 - m_indices].set( + (-1) ** m_indices * flm_hp.conj() + ) return flm_2d @@ -185,11 +178,5 @@ def flm_2d_to_hp_fast(flm_2d: jnp.ndarray, L: int) -> jnp.ndarray: jnp.ndarray: HEALPix indexed harmonic coefficients. """ - flm_hp = jnp.zeros(int(L * (L + 1) / 2), dtype=jnp.complex128) - - for el in range(L): - m_array = jnp.arange(el + 1) - hp_idx = m_array * (2 * L - 1 - m_array) // 2 + el - flm_hp = flm_hp.at[hp_idx].set(flm_2d[el, L - 1 + m_array]) - - return flm_hp + m_indices, el_indices = np.triu_indices(n=L + 1, m=L) + return flm_2d[el_indices, L - 1 + m_indices] From 89f793b341d3ac540ceac72271f020a61c87668f Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Tue, 19 Nov 2024 10:26:17 +0000 Subject: [PATCH 2/3] Propagate type and fix typo in docstring --- s2fft/sampling/reindex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/s2fft/sampling/reindex.py b/s2fft/sampling/reindex.py index f6baf9c..54ea04d 100644 --- a/s2fft/sampling/reindex.py +++ b/s2fft/sampling/reindex.py @@ -8,7 +8,7 @@ @partial(jit, static_argnums=(1)) def flm_1d_to_2d_fast(flm_1d: jnp.ndarray, L: int) -> jnp.ndarray: r""" - Convert from 1D indexed harmnonic coefficients to 2D indexed coefficients (JAX). + Convert from 1D indexed harmonic coefficients to 2D indexed coefficients (JAX). Note: Storage conventions for harmonic coefficients :math:`flm_{(\ell,m)}`, for @@ -36,7 +36,7 @@ def flm_1d_to_2d_fast(flm_1d: jnp.ndarray, L: int) -> jnp.ndarray: jnp.ndarray: 2D indexed harmonic coefficients. """ - flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) + flm_2d = jnp.zeros((L, 2 * L - 1), dtype=flm_1d.dtype) row_indices, col_indices = np.arange(L)[:, None], np.arange(2 * L - 1)[None, :] el_indices, m_indices = np.where( (row_indices <= col_indices)[::-1, :] & (row_indices <= col_indices)[::-1, ::-1] From d2c8f54b46b6b6e6e19353a5506ef93195e5b291 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Tue, 19 Nov 2024 12:06:51 +0000 Subject: [PATCH 3/3] Fix bug in flm_hp_to_2d_fast implementation --- s2fft/sampling/reindex.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/s2fft/sampling/reindex.py b/s2fft/sampling/reindex.py index 54ea04d..15e880d 100644 --- a/s2fft/sampling/reindex.py +++ b/s2fft/sampling/reindex.py @@ -126,10 +126,11 @@ def flm_hp_to_2d_fast(flm_hp: jnp.ndarray, L: int) -> jnp.ndarray: """ flm_2d = jnp.zeros((L, 2 * L - 1), dtype=flm_hp.dtype) - m_indices, el_indices = np.triu_indices(n=L + 1, m=L) - flm_2d = flm_2d.at[el_indices, L - 1 + m_indices].set(flm_hp) + m_indices, el_indices = np.triu_indices(n=L, k=1, m=L) + np.array([[1], [0]]) + flm_2d = flm_2d.at[:L, L - 1].set(flm_hp[:L]) + flm_2d = flm_2d.at[el_indices, L - 1 + m_indices].set(flm_hp[L:]) flm_2d = flm_2d.at[el_indices, L - 1 - m_indices].set( - (-1) ** m_indices * flm_hp.conj() + (-1) ** m_indices * flm_hp[L:].conj() ) return flm_2d