diff --git a/dinosaur/vertical_interpolation.py b/dinosaur/vertical_interpolation.py index d4d7f3b..3a30ded 100644 --- a/dinosaur/vertical_interpolation.py +++ b/dinosaur/vertical_interpolation.py @@ -314,3 +314,88 @@ def interp_hybrid_to_sigma( ) regrid = lambda x: _vertical_interp_3d(desired_pressure, source_pressure, x) return pytree_utils.tree_map_over_nonscalars(regrid, fields) + + +def _pressure_cell_bounds(pressure_at_cell_centers, surface_pressure): + # pressure_at_cell_centers must be non-decreasing. + # both inputs arguments should have the same units. + zero = jnp.array([0.0]) + inner_boundaries = ( + pressure_at_cell_centers[:-1] + pressure_at_cell_centers[1:] + ) / 2 + surface = jnp.array([surface_pressure]) + return jnp.concatenate([zero, inner_boundaries, surface]) + + +def _pressure_overlap( + source_centers: typing.Array, + target_centers: typing.Array, + surface_pressure: typing.Array, +) -> jnp.ndarray: + """Calculate the interval overlap between pressure grid cells.""" + source_bounds = _pressure_cell_bounds(source_centers, surface_pressure) + target_bounds = _pressure_cell_bounds(target_centers, surface_pressure) + # based on https://gist.github.com/shoyer/c0f1ddf409667650a076c058f9a17276 + # (see also horizontal_interpolation.py) + upper = jnp.minimum( + target_bounds[1:, jnp.newaxis], source_bounds[jnp.newaxis, 1:] + ) + lower = jnp.maximum( + target_bounds[:-1, jnp.newaxis], source_bounds[jnp.newaxis, :-1] + ) + return jnp.maximum(upper - lower, 0) + + +def conservative_pressure_weights( + source_centers: typing.Array, + target_centers: typing.Array, + surface_pressure: typing.Array, +) -> jnp.ndarray: + """Create a weight matrix for conservative regridding on pressure levels. + + Args: + source_centers: 1D strictly increasing pressure levels for the source grid. + All values must be between 0 and surface_pressure. + target_centers: 1D strictly increasing pressure levels for the target grid. + All values must be between 0 and surface_pressure. + surface_pressure: surface pressure. + + Returns: + NumPy array with shape (target, source). Rows sum to 1. + """ + weights = _pressure_overlap(source_centers, target_centers, surface_pressure) + weights /= jnp.sum(weights, axis=1, keepdims=True) + assert weights.shape == (target_centers.size, source_centers.size) + return weights + + +@functools.partial(jax.jit, static_argnums=(1, 2)) +def regrid_hybrid_to_sigma( + fields: typing.Pytree, + hybrid_coords: HybridCoordinates, + sigma_coords: sigma_coordinates.SigmaCoordinates, + surface_pressure: typing.Array, +) -> typing.Pytree: + """Conservatively regrid 3D fields from hybrid to sigma levels.""" + desired_pressure = ( + sigma_coords.centers[:, np.newaxis, np.newaxis] * surface_pressure + ) + source_pressure = ( + hybrid_coords.a_centers[:, np.newaxis, np.newaxis] + + hybrid_coords.b_centers[:, np.newaxis, np.newaxis] * surface_pressure + ) + + @jax.jit + @functools.partial( + jnp.vectorize, signature='(a,x,y),(b,x,y),(x,y),(b,x,y)->(a,x,y)' + ) + @functools.partial(jax.vmap, in_axes=(-1, -1, -1, -1), out_axes=-1) + @functools.partial(jax.vmap, in_axes=(-1, -1, -1, -1), out_axes=-1) + def _regrid_3d(x, xp, x_max, fp): + weights = conservative_pressure_weights(xp, x, x_max) + return jnp.einsum('ab,b->a', weights, fp, precision='float32') + + regrid = functools.partial( + _regrid_3d, desired_pressure, source_pressure, surface_pressure + ) + return pytree_utils.tree_map_over_nonscalars(regrid, fields) diff --git a/dinosaur/vertical_interpolation_test.py b/dinosaur/vertical_interpolation_test.py index 76bc375..9bf1baf 100644 --- a/dinosaur/vertical_interpolation_test.py +++ b/dinosaur/vertical_interpolation_test.py @@ -16,10 +16,8 @@ from absl.testing import absltest from absl.testing import parameterized - from dinosaur import sigma_coordinates from dinosaur import vertical_interpolation - import jax.numpy as jnp import numpy as np @@ -56,12 +54,16 @@ class PressureLevelsTest(parameterized.TestCase): def test_get_surface_pressure(self): levels = np.array([100, 200, 300, 400, 500]) orography = np.array([[0, 5, 10, 15]]) - geopotential = np.moveaxis([[ - [400, 250, 150, 50, -50], - [1000, 900, 140, 40, 20], - [500, 400, 300, 200, 100], - [600, 500, 400, 300, 200], - ]], -1, 0) # reorder from [x, y, level] to [level, x, y] + geopotential = np.moveaxis( + [[ + [400, 250, 150, 50, -50], + [1000, 900, 140, 40, 20], + [500, 400, 300, 200, 100], + [600, 500, 400, 300, 200], + ]], + -1, + 0, + ) # reorder from [x, y, level] to [level, x, y] expected = np.array([[[450, 390, 500, 550]]]) actual = vertical_interpolation.get_surface_pressure( vertical_interpolation.PressureCoordinates(levels), @@ -74,11 +76,15 @@ def test_get_surface_pressure(self): def test_sigma_to_pressure_roundtrip(self): sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(10) pressure_coords = vertical_interpolation.PressureCoordinates( - np.array([100, 200, 300, 400])) + np.array([100, 200, 300, 400]) + ) surface_pressure = np.array([[[250, 350, 450]]]) original = np.moveaxis([[[1, 2, 3, 4]] * 3], -1, 0) on_sigma_levels = vertical_interpolation.interp_pressure_to_sigma( - original, pressure_coords, sigma_coords, surface_pressure, + original, + pressure_coords, + sigma_coords, + surface_pressure, ) self.assertEqual(on_sigma_levels.shape, (10, 1, 3)) roundtripped = vertical_interpolation.interp_sigma_to_pressure( @@ -90,21 +96,44 @@ def test_sigma_to_pressure_roundtrip(self): expected[3:, :, 1] = np.nan np.testing.assert_allclose(roundtripped, expected, atol=1e-6) - def test_hybrid_to_sigma(self): + def test_interp_hybrid_to_sigma(self): sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(5) hybrid_coords = vertical_interpolation.HybridCoordinates( - a_centers=np.array([100, 100, 0]), b_centers=np.array([0, 0.5, 0.9]), + # at pressures [100, 600, 900] + a_centers=np.array([100, 100, 0]), + b_centers=np.array([0, 0.5, 0.9]), ) surface_pressure = np.array([[[1000]]]) - # at pressures [100, 600, 900] original = np.array([1.0, 2.0, 3.0])[:, np.newaxis, np.newaxis] - # at pressures [100, 300, 500, 700, 900] - expected = np.array([1.0, 1.4, 1.8, 2 + 1/3, 3.0]) + # linear interpolation to pressures [100, 300, 500, 700, 900] + expected = np.array([1.0, 1.4, 1.8, 2 + 1 / 3, 3.0]) actual = vertical_interpolation.interp_hybrid_to_sigma( original, hybrid_coords, sigma_coords, surface_pressure ).ravel() np.testing.assert_allclose(actual, expected, atol=1e-6) + def test_regrid_hybrid_to_sigma(self): + sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(5) + hybrid_coords = vertical_interpolation.HybridCoordinates( + # at pressures [10, 50, 100, 300, 600, 800, 900] + a_centers=np.array([10, 50, 100, 300, 500, 500, 500]), + b_centers=np.array([0, 0, 0, 0, 0.1, 0.3, 0.4]), + ) + surface_pressure = np.array([[[1000]]]) + original = np.arange(1.0, 8.0)[:, np.newaxis, np.newaxis] + # area weighted averages for cells centered at [100, 300, 500, 700, 900] + expected = np.array([ + 1 * (30 / 200) + 2.0 * (45 / 200) + 3.0 * (125 / 200), + 4.0, + 4.0 * (50 / 200) + 5.0 * (150 / 200), + 5.0 * (100 / 200) + 6.0 * (100 / 200), + 6.0 * (50 / 200) + 7.0 * (150 / 200), + ]) + actual = vertical_interpolation.regrid_hybrid_to_sigma( + original, hybrid_coords, sigma_coords, surface_pressure + ).ravel() + np.testing.assert_allclose(actual, expected, atol=1e-6) + if __name__ == '__main__': absltest.main()