From 11981a1ea62378c540b4323c26c1e8c4cc278f15 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 7 Nov 2024 14:34:36 -0800 Subject: [PATCH] Consolidate SphericalHarmonics implementations - Delete ComplexSphericalHarmonics - Replace the old RealSphericalHarmonics with RealSphericalHarmonicsWithZeroImag, which is what we actually use in practice (it's faster and supports parallelism) PiperOrigin-RevId: 694250875 --- dinosaur/coordinate_systems_test.py | 5 - dinosaur/fourier.py | 50 ------ dinosaur/fourier_test.py | 45 ----- .../primitive_equations_integration_test.py | 74 +++++++- dinosaur/spherical_harmonic.py | 165 +++++++----------- dinosaur/spherical_harmonic_test.py | 95 ++++------ dinosaur/xarray_utils.py | 2 +- 7 files changed, 160 insertions(+), 276 deletions(-) diff --git a/dinosaur/coordinate_systems_test.py b/dinosaur/coordinate_systems_test.py index e425b8f..e6a5df1 100644 --- a/dinosaur/coordinate_systems_test.py +++ b/dinosaur/coordinate_systems_test.py @@ -35,11 +35,6 @@ class CoordinateSystemTest(parameterized.TestCase): dict( horizontal=spherical_harmonic.Grid.T21(), vertical=sigma_coordinates.SigmaCoordinates.equidistant(6)), - dict( - horizontal=spherical_harmonic.Grid.T21( - spherical_harmonics_impl=spherical_harmonic.ComplexSphericalHarmonics - ), - vertical=sigma_coordinates.SigmaCoordinates.equidistant(8)), dict( horizontal=spherical_harmonic.Grid.T21(), vertical=layer_coordinates.LayerCoordinates(5)), diff --git a/dinosaur/fourier.py b/dinosaur/fourier.py index eb62427..a142ad6 100644 --- a/dinosaur/fourier.py +++ b/dinosaur/fourier.py @@ -124,56 +124,6 @@ def real_basis_derivative_with_zero_imag( return j * jnp.where((i + 1) % 2, u_down, -u_up) -def complex_basis(wavenumbers: int, nodes: int) -> np.ndarray: - """Returns the complex-valued Fourier Basis. - - Args: - wavenumbers: number of wavenumbers. - nodes: number of equally spaced nodes in the range [0, 2π). Must satisfy - wavenumbers >= nodes. - - Returns: - The nodes x wavenumbers matrix F, such that - - F[j, k] = exp(2πi * jk / nodes) / √2π - - i.e., the columns of F are the complex Fourier basis functions evenly - spaced points. - - The normalization of the basis functions is chosen such that they have unit - L²([0, 2π]) norm. - """ - if wavenumbers > nodes // 2 + 1: - raise ValueError( - '`wavenumbers` must be no greater than `nodes // 2 + 1`;' - f'got wavenumbers = {wavenumbers}, nodes = {nodes}.' - ) - basis = scipy.linalg.dft(nodes).conj()[:, :wavenumbers] / np.sqrt(np.pi) - basis[:, 0] /= np.sqrt(2) - return basis - - -def complex_basis_derivative( - u: jnp.ndarray | jax.Array, axis: int = -1 -) -> jax.Array: - """Calculate the derivative of a signal using a complex basis. - - Args: - u: signal to differentiate, in the real Fourier basis. - axis: the axis along which the transform will be applied. - - Returns: - The derivative of `u` along `axis`. In particular, if - `u_x = complex_basis_derivative(u)`: - - u_x[..., k] = i * k * u[..., k] - """ - if axis >= 0: - raise ValueError('axis must be negative') - k = jnp.arange(u.shape[axis]).reshape((-1,) + (1,) * (-1 - axis)) - return 1j * k * u - - def quadrature_nodes(nodes: int) -> tuple[np.ndarray, np.ndarray]: """Returns nodes and weights for the trapezoidal rule. diff --git a/dinosaur/fourier_test.py b/dinosaur/fourier_test.py index dace2ac..74019ec 100644 --- a/dinosaur/fourier_test.py +++ b/dinosaur/fourier_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Tests for fourier.""" import itertools from absl.testing import absltest @@ -68,48 +66,5 @@ def testNormalized(self, wavenumbers): np.testing.assert_allclose((f.T * w).dot(f), eye, atol=1e-12) -class ComplexFourierTest(parameterized.TestCase): - - @parameterized.parameters( - dict(wavenumbers=4, nodes=7), - dict(wavenumbers=11, nodes=21), - dict(wavenumbers=32, nodes=63), - ) - def testBasis(self, wavenumbers, nodes): - f = fourier.complex_basis(wavenumbers, nodes) - for j, k in itertools.product(range(nodes), range(wavenumbers)): - normalization = np.sqrt(np.pi) - if k == 0: - normalization *= np.sqrt(2) - expected = np.exp(2 * np.pi * 1j * j * k / nodes) / normalization - np.testing.assert_allclose(f[j, k], expected, atol=1e-12) - - @parameterized.parameters( - dict(wavenumbers=4, seed=0), - dict(wavenumbers=11, seed=0), - dict(wavenumbers=32, seed=0), - ) - def testDerivatives(self, wavenumbers, seed): - f = np.random.RandomState(seed).normal(size=[wavenumbers]) - f_x = fourier.complex_basis_derivative(f) - for k in range(wavenumbers): - np.testing.assert_allclose(f_x[k], 1j * k * f[k]) - - @parameterized.parameters( - dict(wavenumbers=4), - dict(wavenumbers=16), - dict(wavenumbers=256), - ) - def testNormalized(self, wavenumbers): - """Tests that the basis functions are normalized on [0, 2π].""" - nodes = 2 * wavenumbers - 1 - f = fourier.complex_basis(wavenumbers, nodes) - _, w = fourier.quadrature_nodes(nodes) - expected = 2 * np.eye(wavenumbers) - expected[0, 0] = 1 - norms = (f.T.conj() * w).dot(f) - np.testing.assert_allclose(norms, expected, atol=1e-12) - - if __name__ == '__main__': absltest.main() diff --git a/dinosaur/primitive_equations_integration_test.py b/dinosaur/primitive_equations_integration_test.py index 7302b71..a7e32d2 100644 --- a/dinosaur/primitive_equations_integration_test.py +++ b/dinosaur/primitive_equations_integration_test.py @@ -14,10 +14,10 @@ """Integration tests for primitive_equations.""" import dataclasses +import functools from absl.testing import absltest import chex - from dinosaur import coordinate_systems from dinosaur import primitive_equations from dinosaur import primitive_equations_states @@ -27,7 +27,6 @@ from dinosaur import spherical_harmonic from dinosaur import time_integration from dinosaur import xarray_utils - import jax from jax import config import jax.numpy as jnp @@ -37,9 +36,8 @@ def make_coords( max_wavenumber: int, num_layers: int, - mesh: jax.sharding.Mesh | None = None, - spherical_harmonics_impl: ... = ( - spherical_harmonic.RealSphericalHarmonicsWithZeroImag), + spmd_mesh: jax.sharding.Mesh | None, + spherical_harmonics_impl, ) -> coordinate_systems.CoordinateSystem: return coordinate_systems.CoordinateSystem( spherical_harmonic.Grid.with_wavenumbers( @@ -47,7 +45,7 @@ def make_coords( spherical_harmonics_impl=spherical_harmonics_impl, ), sigma_coordinates.SigmaCoordinates.equidistant(num_layers), - spmd_mesh=mesh, + spmd_mesh=spmd_mesh, ) @@ -135,7 +133,12 @@ class IntegrationTest(absltest.TestCase): def test_distributed_simulation_consistency(self): physics_specs = primitive_equations.PrimitiveEquationsSpecs.from_si() - coords = make_coords(max_wavenumber=31, num_layers=8) + coords = make_coords( + max_wavenumber=31, + num_layers=8, + spmd_mesh=None, + spherical_harmonics_impl=spherical_harmonic.FastSphericalHarmonics, + ) init_state = make_initial_state(coords, physics_specs) sim_fn = make_dycore_sim_fn(coords, physics_specs, num_hours=1) non_distributed_state = sim_fn(init_state) @@ -144,7 +147,12 @@ def test_distributed_simulation_consistency(self): with self.subTest('vertical sharding'): devices = np.array(jax.devices()[:2]).reshape((2, 1, 1)) mesh = jax.sharding.Mesh(devices, axis_names=['z', 'x', 'y']) - distributed_coords = dataclasses.replace(coords, spmd_mesh=mesh) + distributed_coords = make_coords( + max_wavenumber=31, + num_layers=8, + spmd_mesh=mesh, + spherical_harmonics_impl=spherical_harmonic.FastSphericalHarmonics, + ) distributed_init_state = pad_state(init_state, distributed_coords) distributed_sim_fn = make_dycore_sim_fn( distributed_coords, physics_specs, num_hours=1 @@ -160,7 +168,12 @@ def test_distributed_simulation_consistency(self): with self.subTest('horizontal sharding'): devices = np.array(jax.devices()[:4]).reshape((1, 2, 2)) mesh = jax.sharding.Mesh(devices, axis_names=['z', 'x', 'y']) - distributed_coords = dataclasses.replace(coords, spmd_mesh=mesh) + distributed_coords = make_coords( + max_wavenumber=31, + num_layers=8, + spmd_mesh=mesh, + spherical_harmonics_impl=spherical_harmonic.FastSphericalHarmonics, + ) distributed_init_state = pad_state(init_state, distributed_coords) distributed_sim_fn = make_dycore_sim_fn( distributed_coords, physics_specs, num_hours=1 @@ -173,6 +186,49 @@ def test_distributed_simulation_consistency(self): non_distributed_nodal, distributed_nodal, rtol=1e-6, range_tol=1e-6 ) + def test_real_vs_fast_spherical_harmonics(self): + physics_specs = primitive_equations.PrimitiveEquationsSpecs.from_si() + + real_coords = make_coords( + max_wavenumber=31, + num_layers=8, + spmd_mesh=None, + spherical_harmonics_impl=spherical_harmonic.RealSphericalHarmonics, + ) + fast_coords = make_coords( + max_wavenumber=31, + num_layers=8, + spmd_mesh=None, + spherical_harmonics_impl=functools.partial( + spherical_harmonic.FastSphericalHarmonics, + transform_precision='float32', + ), + ) + + real_init_state = make_initial_state(real_coords, physics_specs) + fast_init_state = make_initial_state(fast_coords, physics_specs) + + real_init_nodal = real_coords.horizontal.to_nodal(real_init_state) + fast_init_nodal = fast_coords.horizontal.to_nodal(fast_init_state) + + with self.subTest('initial conditions'): + assert_states_close( + real_init_nodal, fast_init_nodal, rtol=1e-6, range_tol=1e-6 + ) + + sim_fn = make_dycore_sim_fn(real_coords, physics_specs, num_hours=1) + real_out_state = sim_fn(real_init_state) + real_out_nodal = real_coords.horizontal.to_nodal(real_out_state) + + sim_fn = make_dycore_sim_fn(fast_coords, physics_specs, num_hours=1) + fast_out_state = sim_fn(fast_init_state) + fast_out_nodal = fast_coords.horizontal.to_nodal(fast_out_state) + + with self.subTest('evolved state'): + assert_states_close( + real_out_nodal, fast_out_nodal, rtol=1e-5, range_tol=1e-5 + ) + if __name__ == '__main__': chex.set_n_cpu_devices(8) diff --git a/dinosaur/spherical_harmonic.py b/dinosaur/spherical_harmonic.py index d5392f8..d61c7d7 100644 --- a/dinosaur/spherical_harmonic.py +++ b/dinosaur/spherical_harmonic.py @@ -81,11 +81,10 @@ class SphericalHarmonics: Attributes: longitude_wavenumbers: the maximum (exclusive) wavenumber in the longitudinal direction. Indexes along longitudinal wavenumber are - typically denoted by `m`. Must satisfy `longitude_wavenumbers <= - total_wavenumbers`. + typically denoted by `m`. total_wavenumbers: the maximum (exclusive) sum of the latitudinal and longitudinal wavenumbers. Indices along total wavenumber are typically - denoted by `l`. Must satisfy `longitude_wavenumbers <= total_wavenumbers`. + denoted by `l`. longitude_nodes: the number of nodes in the longitudinal direction. The selected nodes will be the equally spaced points in [0, 2π). latitude_nodes: the number of nodes in the latitudinal direction. The @@ -161,7 +160,21 @@ def longitudinal_derivative(self, x: Array) -> Array: class RealSphericalHarmonics(SphericalHarmonics): - """Real-valued spherical harmonics transforms.""" + """Pedagogical implementation of spherical harmonics transforms. + + This transform represents spherical harmonic (modal) coefficients as a two + dimensional grid of longtitudinal wavenumber (m) and total wavenumber (l) + values: + m = [0, +1, -1, +2, -2, ..., +M, -M] + l = [0, 1, 2, ..., L] + where `M = longitude_wavenumbers - 1` and `L = total_wavenumbers`. + + Entries with `abs(m) > l` are structural zeros, + + For better performance when using computing forward and inverse transforms, + but no guaranteed stable representation, use FastSphericalHarmonics, which + also supports parallelism. + """ @functools.cached_property def nodal_axes(self) -> tuple[np.ndarray, np.ndarray]: @@ -324,9 +337,9 @@ def _fourier_derivative_for_real_basis_with_zero_imag( if mesh is None: return fourier.real_basis_derivative_with_zero_imag(x, axis=-2) - # RealSphericalHarmonicsWithZeroImage always pads longitudinal frequencies by - # a multiple of two times the number of X shards, so we can safely - # differentiate without any distributed communication. + # FastHarmonicsWithZeroImage always pads longitudinal frequencies by a + # multiple of two times the number of X shards, so we can safely differentiate + # without any distributed communication. def differentiate(u): axis = -2 @@ -352,6 +365,9 @@ def _transform_einsum( precision: str, ) -> jax.Array: """einsum for calculating Fourier and Legendre transforms.""" + if mesh is None: + return jnp.einsum(subscripts, lhs, rhs, precision=precision) + out_ndim = len( jax.eval_shape(functools.partial(jnp.einsum, subscripts), lhs, rhs).shape ) @@ -370,7 +386,10 @@ def _transform_einsum( in_spec = P(z, None, 'x', 'y') if rhs.ndim == 4 else P(z, 'x', 'y') out_spec = P(z, None, 'x', 'y') if out_ndim == 4 else P(z, 'x', 'y') else: - raise ValueError(f'only 0 or 1 dimensions supported for ...: {subscripts}') + raise ValueError( + 'only 0 or 1 dimensions supported for ... when using a mesh:' + f' {subscripts}' + ) return jax_numpy_utils.sharded_einsum( subscripts, @@ -385,12 +404,16 @@ def _transform_einsum( @dataclasses.dataclass(frozen=True) -class RealSphericalHarmonicsWithZeroImag(SphericalHarmonics): - """Real-valued spherical harmonics with an extra imaginary part for m=-0. +class FastSphericalHarmonics(SphericalHarmonics): + """Fast implementation of spherical harmonic transformation. + + No stability guarantees are made about the shapes of arrays in the modal + representation. - This can be more efficient because the array of Legendre transform - coefficients is the same for positive and negative coefficients, so this - halves the size of the `p` array on the MXU. + Currently uses an extra imaginary term for m=-0. This can be more efficient + because the array of Legendre transform coefficients is the same for positive + and negative coefficients, so this halves the size of the `p` array on the + MXU. This version of spherical harmonics also supports model parallelism, if `spmd_mesh` is provided. The additional optional arguments allow for low-level @@ -498,6 +521,25 @@ def mask(self) -> np.ndarray: @functools.cached_property def basis(self) -> _SphericalHarmonicBasis: + # The product of the arrays `f` and `p` gives the real normalized spherical + # harmonic basis evaluated on a grid of longitudes λ and latitudes θ: + # + # f[i, 2m ] p[2m, j, l] = cₗₘ cos(m λᵢ) Pᵐₗ(sin θⱼ) + # f[i, 2m + 1] p[2m + 1, j, l] = cₗₘ sin(m λᵢ) Pᵐₗ(sin θⱼ) + # + # where the constants cₗₘ are chosen such that each function has unit L² + # norm on the unit sphere. The longitudes λᵢ are `longitude_nodes` equally + # spaced points in [0, 2π). The latitude nodes θⱼ are chosen such that + # (sin θⱼ) are the Gauss-Legendre quadrature points if + # `latitude_spacing = 'gauss'`, or θⱼ are `latitude_nodes` equally spaced + # points if `latitude_spacing = 'equiangular'` (or + # `'equiangular_with_poles'` for equally spaced points including points at + # the poles). + # + # The shapes of the returned arrays are + # + # f.shape == (longitude_nodes, 2*longitude_wavenumbers) + # p.shape == (2*longitude_wavenumbers, latitude_nodes, total_wavenumbers) nodal_pad_x, nodal_pad_y = self.nodal_padding modal_pad_x, modal_pad_y = self.modal_padding @@ -535,6 +577,8 @@ def inverse_transform(self, x): 'mjl,...sml->...smj', p, x, mesh, *einsum_args ) if self.stacked_fourier_transforms: + # note: explicit matrix multiplication seems to be faster than using an + # explicit FFT at the resolutions we use. x = jax.named_call(_transform_einsum, name='inv_fourier')( 'ism,...smj->...ij', f, x, mesh, *einsum_args ) @@ -573,94 +617,8 @@ def longitudinal_derivative(self, x: Array) -> Array: @dataclasses.dataclass(frozen=True) -class ComplexSphericalHarmonics(SphericalHarmonics): - """Complex valued spherical harmonics transforms. - - This works fine, but in practice is considerably slower (at least on TPUs) - than real-values spherical harmonics transformations, probably because XLA's - code generation for complex numbers is not well optimized. - """ - - @functools.cached_property - def nodal_axes(self) -> tuple[np.ndarray, np.ndarray]: - longitude, _ = fourier.quadrature_nodes(self.longitude_nodes) - sin_latitude, _ = get_latitude_nodes( - self.latitude_nodes, self.latitude_spacing - ) - return longitude, sin_latitude - - @functools.cached_property - def nodal_shape(self) -> tuple[int, int]: - return (self.longitude_nodes, self.latitude_nodes) - - @functools.cached_property - def nodal_padding(self) -> tuple[int, int]: - return (0, 0) - - @functools.cached_property - def modal_axes(self) -> tuple[np.ndarray, np.ndarray]: - lon_wavenumbers = np.arange(self.longitude_wavenumbers) - tot_wavenumbers = np.arange(self.total_wavenumbers) - return lon_wavenumbers, tot_wavenumbers - - @functools.cached_property - def modal_shape(self) -> tuple[int, int]: - return (self.longitude_wavenumbers, self.total_wavenumbers) - - @functools.cached_property - def modal_padding(self) -> tuple[int, int]: - return (0, 0) - - @functools.cached_property - def modal_dtype(self) -> np.dtype: - return np.dtype(np.complex64) - - @functools.cached_property - def mask(self) -> np.ndarray: - m, l = np.meshgrid(*self.modal_axes, indexing='ij') - return m <= l - - @functools.cached_property - def basis(self) -> _SphericalHarmonicBasis: - f = fourier.complex_basis( - wavenumbers=self.longitude_wavenumbers, - nodes=self.longitude_nodes, - ) - _, wf = fourier.quadrature_nodes(self.longitude_nodes) - x, wp = get_latitude_nodes(self.latitude_nodes, self.latitude_spacing) - w = wf * wp - p = associated_legendre.evaluate( - n_m=self.longitude_wavenumbers, n_l=self.total_wavenumbers, x=x - ) - return _SphericalHarmonicBasis(f=f, p=p, w=w) - - def inverse_transform(self, x): - p = self.basis.p - f = self.basis.f - px = jax.named_call(einsum, name='inv_legendre')('mjl,...ml->...mj', p, x) - fpx_from_real = jax.named_call(einsum, name='inv_fourier_from_real')( - 'im,...mj->...ij', jnp.real(f), jnp.real(px) - ) - fpx_from_imag = jax.named_call(einsum, name='inv_fourier_from_imag')( - 'im,...mj->...ij', -jnp.imag(f), jnp.imag(px) - ) - return fpx_from_real + fpx_from_imag - - def transform(self, x): - w = self.basis.w - f = self.basis.f - p = self.basis.p - wx = w * x - fwx = jax.named_call(einsum, name='fwd_fourier')( - 'im,...ij->...mj', jnp.conj(f), wx - ) - pfwx = jax.named_call(einsum, name='fwd_legendre')( - 'mjl,...mj->...ml', p, fwx - ) - return pfwx - - def longitudinal_derivative(self, x: Array) -> Array: - return fourier.complex_basis_derivative(x, axis=-2) +class RealSphericalHarmonicsWithZeroImag(FastSphericalHarmonics): + """Deprecated alias for `FastSphericalHarmonics`.""" def _vertical_pad( @@ -697,6 +655,7 @@ def _with_vertical_padding( Returns: Function that can be applied to non-padded arrays. """ + def g(x): x, padding = _vertical_pad(x, mesh) return _vertical_crop(f(x), padding) @@ -772,9 +731,7 @@ def __post_init__(self): "mesh is missing one or more of the required axis names 'x' and " f"'y': {self.spmd_mesh}" ) - assert isinstance( - self.spherical_harmonics, RealSphericalHarmonicsWithZeroImag - ) + assert isinstance(self.spherical_harmonics, FastSphericalHarmonics) @classmethod def with_wavenumbers( diff --git a/dinosaur/spherical_harmonic_test.py b/dinosaur/spherical_harmonic_test.py index 4c27e89..f605e1d 100644 --- a/dinosaur/spherical_harmonic_test.py +++ b/dinosaur/spherical_harmonic_test.py @@ -17,9 +17,7 @@ import functools from absl.testing import absltest from absl.testing import parameterized - from dinosaur import spherical_harmonic - import jax from jax import config import jax.numpy as jnp @@ -48,39 +46,32 @@ def random_modal_state(grid, seed=0): class SphericalHarmonicTest(parameterized.TestCase): - @parameterized.product( - params=[ - dict( - longitude_nodes=64, - latitude_nodes=32, - longitude_wavenumbers=32, - total_wavenumbers=32, - latitude_spacing='gauss', - ), - dict( - longitude_nodes=117, - latitude_nodes=13, - longitude_wavenumbers=45, - total_wavenumbers=123, - latitude_spacing='equiangular', - ), - dict( - longitude_nodes=117, - latitude_nodes=13, - longitude_wavenumbers=45, - total_wavenumbers=123, - latitude_spacing='equiangular_with_poles', - ), - ], - impl=[ - # RealSphericalHarmonicsWithZeroImag uses a different convention - spherical_harmonic.RealSphericalHarmonics, - spherical_harmonic.ComplexSphericalHarmonics, - ], + @parameterized.parameters( + dict( + longitude_nodes=64, + latitude_nodes=32, + longitude_wavenumbers=32, + total_wavenumbers=32, + latitude_spacing='gauss', + ), + dict( + longitude_nodes=117, + latitude_nodes=13, + longitude_wavenumbers=45, + total_wavenumbers=123, + latitude_spacing='equiangular', + ), + dict( + longitude_nodes=117, + latitude_nodes=13, + longitude_wavenumbers=45, + total_wavenumbers=123, + latitude_spacing='equiangular_with_poles', + ), ) - def testBasisShapes(self, params, impl): + def testBasisShapes(self, **params): """Tests that the arrays provided by `basis` have the expected shape.""" - spherical_harmonics = impl(**params) + spherical_harmonics = spherical_harmonic.RealSphericalHarmonics(**params) basis = spherical_harmonics.basis longitude_nodes = params['longitude_nodes'] latitude_nodes = params['latitude_nodes'] @@ -99,8 +90,7 @@ class GridTest(parameterized.TestCase): latitude_spacing=('gauss', 'equiangular'), impl=[ spherical_harmonic.RealSphericalHarmonics, - spherical_harmonic.RealSphericalHarmonicsWithZeroImag, - spherical_harmonic.ComplexSphericalHarmonics, + spherical_harmonic.FastSphericalHarmonics, ], ) def testGridShape(self, wavenumbers, latitude_spacing, impl): @@ -179,7 +169,7 @@ def testConstructors(self): latitude_spacing='equiangular', jit=True, seed=0, - spherical_harmonics_impl=spherical_harmonic.RealSphericalHarmonicsWithZeroImag, + spherical_harmonics_impl=spherical_harmonic.FastSphericalHarmonics, ), dict( longitude_wavenumbers=64, @@ -188,19 +178,11 @@ def testConstructors(self): jit=True, seed=0, spherical_harmonics_impl=functools.partial( - spherical_harmonic.RealSphericalHarmonicsWithZeroImag, + spherical_harmonic.FastSphericalHarmonics, base_shape_multiple=8, reverse_einsum_arg_order=True, ), ), - dict( - longitude_wavenumbers=64, - total_wavenumbers=64, - latitude_spacing='equiangular_with_poles', - jit=True, - seed=0, - spherical_harmonics_impl=spherical_harmonic.ComplexSphericalHarmonics, - ), ) def testRoundTrip( self, @@ -242,8 +224,7 @@ def testRoundTrip( seed=(0,), impl=[ spherical_harmonic.RealSphericalHarmonics, - spherical_harmonic.RealSphericalHarmonicsWithZeroImag, - spherical_harmonic.ComplexSphericalHarmonics, + spherical_harmonic.FastSphericalHarmonics, ], ) def testLaplacianRoundTrip(self, wavenumbers, latitude_spacing, seed, impl): @@ -267,8 +248,7 @@ def testLaplacianRoundTrip(self, wavenumbers, latitude_spacing, seed, impl): test_function=(_function_0, _function_1), impl=[ spherical_harmonic.RealSphericalHarmonics, - spherical_harmonic.RealSphericalHarmonicsWithZeroImag, - spherical_harmonic.ComplexSphericalHarmonics, + spherical_harmonic.FastSphericalHarmonics, ], ) def testDerivatives( @@ -405,7 +385,7 @@ def testLaplacian(self, grid, seed): dict( grid=spherical_harmonic.Grid.with_wavenumbers( 128, - spherical_harmonics_impl=spherical_harmonic.RealSphericalHarmonicsWithZeroImag, + spherical_harmonics_impl=spherical_harmonic.FastSphericalHarmonics, ), atol=1e-11, seed=0, @@ -414,21 +394,13 @@ def testLaplacian(self, grid, seed): grid=spherical_harmonic.Grid.with_wavenumbers( 128, spherical_harmonics_impl=functools.partial( - spherical_harmonic.RealSphericalHarmonicsWithZeroImag, + spherical_harmonic.FastSphericalHarmonics, transform_precision='float32', - ) + ), ), atol=1e-11, seed=0, ), - dict( - grid=spherical_harmonic.Grid.with_wavenumbers( - 128, - spherical_harmonics_impl=spherical_harmonic.ComplexSphericalHarmonics, - ), - atol=1e-10, - seed=0, - ), dict( grid=spherical_harmonic.Grid( longitude_wavenumbers=64, @@ -541,8 +513,7 @@ def testIntegrationSurfaceArea(self, wavenumbers, latitude_spacing, radius): ], impl=[ spherical_harmonic.RealSphericalHarmonics, - spherical_harmonic.RealSphericalHarmonicsWithZeroImag, - spherical_harmonic.ComplexSphericalHarmonics, + spherical_harmonic.FastSphericalHarmonics, ], ) def testIntegrationSphericalHarmonics(self, params, impl): diff --git a/dinosaur/xarray_utils.py b/dinosaur/xarray_utils.py index f25cfef..1761b19 100644 --- a/dinosaur/xarray_utils.py +++ b/dinosaur/xarray_utils.py @@ -150,7 +150,7 @@ 'RealSphericalHarmonicsWithZeroImag': ( spherical_harmonic.RealSphericalHarmonicsWithZeroImag ), - 'ComplexSphericalHarmonics': spherical_harmonic.ComplexSphericalHarmonics, + 'FastSphericalHarmonics': spherical_harmonic.FastSphericalHarmonics, }