Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate SphericalHarmonics implementations #56

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions dinosaur/coordinate_systems_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ class CoordinateSystemTest(parameterized.TestCase):
dict(
horizontal=spherical_harmonic.Grid.T21(),
vertical=sigma_coordinates.SigmaCoordinates.equidistant(6)),
dict(
horizontal=spherical_harmonic.Grid.T21(
spherical_harmonics_impl=spherical_harmonic.ComplexSphericalHarmonics
),
vertical=sigma_coordinates.SigmaCoordinates.equidistant(8)),
dict(
horizontal=spherical_harmonic.Grid.T21(),
vertical=layer_coordinates.LayerCoordinates(5)),
Expand Down
50 changes: 0 additions & 50 deletions dinosaur/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,56 +124,6 @@ def real_basis_derivative_with_zero_imag(
return j * jnp.where((i + 1) % 2, u_down, -u_up)


def complex_basis(wavenumbers: int, nodes: int) -> np.ndarray:
"""Returns the complex-valued Fourier Basis.

Args:
wavenumbers: number of wavenumbers.
nodes: number of equally spaced nodes in the range [0, 2π). Must satisfy
wavenumbers >= nodes.

Returns:
The nodes x wavenumbers matrix F, such that

F[j, k] = exp(2πi * jk / nodes) / √2π

i.e., the columns of F are the complex Fourier basis functions evenly
spaced points.

The normalization of the basis functions is chosen such that they have unit
L²([0, 2π]) norm.
"""
if wavenumbers > nodes // 2 + 1:
raise ValueError(
'`wavenumbers` must be no greater than `nodes // 2 + 1`;'
f'got wavenumbers = {wavenumbers}, nodes = {nodes}.'
)
basis = scipy.linalg.dft(nodes).conj()[:, :wavenumbers] / np.sqrt(np.pi)
basis[:, 0] /= np.sqrt(2)
return basis


def complex_basis_derivative(
u: jnp.ndarray | jax.Array, axis: int = -1
) -> jax.Array:
"""Calculate the derivative of a signal using a complex basis.

Args:
u: signal to differentiate, in the real Fourier basis.
axis: the axis along which the transform will be applied.

Returns:
The derivative of `u` along `axis`. In particular, if
`u_x = complex_basis_derivative(u)`:

u_x[..., k] = i * k * u[..., k]
"""
if axis >= 0:
raise ValueError('axis must be negative')
k = jnp.arange(u.shape[axis]).reshape((-1,) + (1,) * (-1 - axis))
return 1j * k * u


def quadrature_nodes(nodes: int) -> tuple[np.ndarray, np.ndarray]:
"""Returns nodes and weights for the trapezoidal rule.

Expand Down
45 changes: 0 additions & 45 deletions dinosaur/fourier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for fourier."""
import itertools

from absl.testing import absltest
Expand Down Expand Up @@ -68,48 +66,5 @@ def testNormalized(self, wavenumbers):
np.testing.assert_allclose((f.T * w).dot(f), eye, atol=1e-12)


class ComplexFourierTest(parameterized.TestCase):

@parameterized.parameters(
dict(wavenumbers=4, nodes=7),
dict(wavenumbers=11, nodes=21),
dict(wavenumbers=32, nodes=63),
)
def testBasis(self, wavenumbers, nodes):
f = fourier.complex_basis(wavenumbers, nodes)
for j, k in itertools.product(range(nodes), range(wavenumbers)):
normalization = np.sqrt(np.pi)
if k == 0:
normalization *= np.sqrt(2)
expected = np.exp(2 * np.pi * 1j * j * k / nodes) / normalization
np.testing.assert_allclose(f[j, k], expected, atol=1e-12)

@parameterized.parameters(
dict(wavenumbers=4, seed=0),
dict(wavenumbers=11, seed=0),
dict(wavenumbers=32, seed=0),
)
def testDerivatives(self, wavenumbers, seed):
f = np.random.RandomState(seed).normal(size=[wavenumbers])
f_x = fourier.complex_basis_derivative(f)
for k in range(wavenumbers):
np.testing.assert_allclose(f_x[k], 1j * k * f[k])

@parameterized.parameters(
dict(wavenumbers=4),
dict(wavenumbers=16),
dict(wavenumbers=256),
)
def testNormalized(self, wavenumbers):
"""Tests that the basis functions are normalized on [0, 2π]."""
nodes = 2 * wavenumbers - 1
f = fourier.complex_basis(wavenumbers, nodes)
_, w = fourier.quadrature_nodes(nodes)
expected = 2 * np.eye(wavenumbers)
expected[0, 0] = 1
norms = (f.T.conj() * w).dot(f)
np.testing.assert_allclose(norms, expected, atol=1e-12)


if __name__ == '__main__':
absltest.main()
74 changes: 65 additions & 9 deletions dinosaur/primitive_equations_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

"""Integration tests for primitive_equations."""
import dataclasses
import functools

from absl.testing import absltest
import chex

from dinosaur import coordinate_systems
from dinosaur import primitive_equations
from dinosaur import primitive_equations_states
Expand All @@ -27,7 +27,6 @@
from dinosaur import spherical_harmonic
from dinosaur import time_integration
from dinosaur import xarray_utils

import jax
from jax import config
import jax.numpy as jnp
Expand All @@ -37,17 +36,16 @@
def make_coords(
max_wavenumber: int,
num_layers: int,
mesh: jax.sharding.Mesh | None = None,
spherical_harmonics_impl: ... = (
spherical_harmonic.RealSphericalHarmonicsWithZeroImag),
spmd_mesh: jax.sharding.Mesh | None,
spherical_harmonics_impl,
) -> coordinate_systems.CoordinateSystem:
return coordinate_systems.CoordinateSystem(
spherical_harmonic.Grid.with_wavenumbers(
longitude_wavenumbers=max_wavenumber + 1,
spherical_harmonics_impl=spherical_harmonics_impl,
),
sigma_coordinates.SigmaCoordinates.equidistant(num_layers),
spmd_mesh=mesh,
spmd_mesh=spmd_mesh,
)


Expand Down Expand Up @@ -135,7 +133,12 @@ class IntegrationTest(absltest.TestCase):

def test_distributed_simulation_consistency(self):
physics_specs = primitive_equations.PrimitiveEquationsSpecs.from_si()
coords = make_coords(max_wavenumber=31, num_layers=8)
coords = make_coords(
max_wavenumber=31,
num_layers=8,
spmd_mesh=None,
spherical_harmonics_impl=spherical_harmonic.FastSphericalHarmonics,
)
init_state = make_initial_state(coords, physics_specs)
sim_fn = make_dycore_sim_fn(coords, physics_specs, num_hours=1)
non_distributed_state = sim_fn(init_state)
Expand All @@ -144,7 +147,12 @@ def test_distributed_simulation_consistency(self):
with self.subTest('vertical sharding'):
devices = np.array(jax.devices()[:2]).reshape((2, 1, 1))
mesh = jax.sharding.Mesh(devices, axis_names=['z', 'x', 'y'])
distributed_coords = dataclasses.replace(coords, spmd_mesh=mesh)
distributed_coords = make_coords(
max_wavenumber=31,
num_layers=8,
spmd_mesh=mesh,
spherical_harmonics_impl=spherical_harmonic.FastSphericalHarmonics,
)
distributed_init_state = pad_state(init_state, distributed_coords)
distributed_sim_fn = make_dycore_sim_fn(
distributed_coords, physics_specs, num_hours=1
Expand All @@ -160,7 +168,12 @@ def test_distributed_simulation_consistency(self):
with self.subTest('horizontal sharding'):
devices = np.array(jax.devices()[:4]).reshape((1, 2, 2))
mesh = jax.sharding.Mesh(devices, axis_names=['z', 'x', 'y'])
distributed_coords = dataclasses.replace(coords, spmd_mesh=mesh)
distributed_coords = make_coords(
max_wavenumber=31,
num_layers=8,
spmd_mesh=mesh,
spherical_harmonics_impl=spherical_harmonic.FastSphericalHarmonics,
)
distributed_init_state = pad_state(init_state, distributed_coords)
distributed_sim_fn = make_dycore_sim_fn(
distributed_coords, physics_specs, num_hours=1
Expand All @@ -173,6 +186,49 @@ def test_distributed_simulation_consistency(self):
non_distributed_nodal, distributed_nodal, rtol=1e-6, range_tol=1e-6
)

def test_real_vs_fast_spherical_harmonics(self):
physics_specs = primitive_equations.PrimitiveEquationsSpecs.from_si()

real_coords = make_coords(
max_wavenumber=31,
num_layers=8,
spmd_mesh=None,
spherical_harmonics_impl=spherical_harmonic.RealSphericalHarmonics,
)
fast_coords = make_coords(
max_wavenumber=31,
num_layers=8,
spmd_mesh=None,
spherical_harmonics_impl=functools.partial(
spherical_harmonic.FastSphericalHarmonics,
transform_precision='float32',
),
)

real_init_state = make_initial_state(real_coords, physics_specs)
fast_init_state = make_initial_state(fast_coords, physics_specs)

real_init_nodal = real_coords.horizontal.to_nodal(real_init_state)
fast_init_nodal = fast_coords.horizontal.to_nodal(fast_init_state)

with self.subTest('initial conditions'):
assert_states_close(
real_init_nodal, fast_init_nodal, rtol=1e-6, range_tol=1e-6
)

sim_fn = make_dycore_sim_fn(real_coords, physics_specs, num_hours=1)
real_out_state = sim_fn(real_init_state)
real_out_nodal = real_coords.horizontal.to_nodal(real_out_state)

sim_fn = make_dycore_sim_fn(fast_coords, physics_specs, num_hours=1)
fast_out_state = sim_fn(fast_init_state)
fast_out_nodal = fast_coords.horizontal.to_nodal(fast_out_state)

with self.subTest('evolved state'):
assert_states_close(
real_out_nodal, fast_out_nodal, rtol=1e-5, range_tol=1e-5
)


if __name__ == '__main__':
chex.set_n_cpu_devices(8)
Expand Down
Loading
Loading