Skip to content

Commit

Permalink
Conservative vertical regridding.
Browse files Browse the repository at this point in the history
Add a routine for conservative regridding from hybrid to sigma coordinates.

The implementation uses matrix-multiplication, which should be efficient (if memory hungry) even on accelerators.

PiperOrigin-RevId: 651069591
  • Loading branch information
shoyer authored and Dinosaur authors committed Jul 10, 2024
1 parent 1fcfb9e commit a7630f3
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 15 deletions.
85 changes: 85 additions & 0 deletions dinosaur/vertical_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,88 @@ def interp_hybrid_to_sigma(
)
regrid = lambda x: _vertical_interp_3d(desired_pressure, source_pressure, x)
return pytree_utils.tree_map_over_nonscalars(regrid, fields)


def _pressure_cell_bounds(pressure_at_cell_centers, surface_pressure):
# pressure_at_cell_centers must be non-decreasing.
# both inputs arguments should have the same units.
zero = jnp.array([0.0])
inner_boundaries = (
pressure_at_cell_centers[:-1] + pressure_at_cell_centers[1:]
) / 2
surface = jnp.array([surface_pressure])
return jnp.concatenate([zero, inner_boundaries, surface])


def _pressure_overlap(
source_centers: typing.Array,
target_centers: typing.Array,
surface_pressure: typing.Array,
) -> jnp.ndarray:
"""Calculate the interval overlap between pressure grid cells."""
source_bounds = _pressure_cell_bounds(source_centers, surface_pressure)
target_bounds = _pressure_cell_bounds(target_centers, surface_pressure)
# based on https://gist.github.com/shoyer/c0f1ddf409667650a076c058f9a17276
# (see also horizontal_interpolation.py)
upper = jnp.minimum(
target_bounds[1:, jnp.newaxis], source_bounds[jnp.newaxis, 1:]
)
lower = jnp.maximum(
target_bounds[:-1, jnp.newaxis], source_bounds[jnp.newaxis, :-1]
)
return jnp.maximum(upper - lower, 0)


def conservative_pressure_weights(
source_centers: typing.Array,
target_centers: typing.Array,
surface_pressure: typing.Array,
) -> jnp.ndarray:
"""Create a weight matrix for conservative regridding on pressure levels.
Args:
source_centers: 1D strictly increasing pressure levels for the source grid.
All values must be between 0 and surface_pressure.
target_centers: 1D strictly increasing pressure levels for the target grid.
All values must be between 0 and surface_pressure.
surface_pressure: surface pressure.
Returns:
NumPy array with shape (target, source). Rows sum to 1.
"""
weights = _pressure_overlap(source_centers, target_centers, surface_pressure)
weights /= jnp.sum(weights, axis=1, keepdims=True)
assert weights.shape == (target_centers.size, source_centers.size)
return weights


@functools.partial(jax.jit, static_argnums=(1, 2))
def regrid_hybrid_to_sigma(
fields: typing.Pytree,
hybrid_coords: HybridCoordinates,
sigma_coords: sigma_coordinates.SigmaCoordinates,
surface_pressure: typing.Array,
) -> typing.Pytree:
"""Conservatively regrid 3D fields from hybrid to sigma levels."""
desired_pressure = (
sigma_coords.centers[:, np.newaxis, np.newaxis] * surface_pressure
)
source_pressure = (
hybrid_coords.a_centers[:, np.newaxis, np.newaxis]
+ hybrid_coords.b_centers[:, np.newaxis, np.newaxis] * surface_pressure
)

@jax.jit
@functools.partial(
jnp.vectorize, signature='(a,x,y),(b,x,y),(x,y),(b,x,y)->(a,x,y)'
)
@functools.partial(jax.vmap, in_axes=(-1, -1, -1, -1), out_axes=-1)
@functools.partial(jax.vmap, in_axes=(-1, -1, -1, -1), out_axes=-1)
def _regrid_3d(x, xp, x_max, fp):
weights = conservative_pressure_weights(xp, x, x_max)
return jnp.einsum('ab,b->a', weights, fp, precision='float32')

regrid = functools.partial(
_regrid_3d, desired_pressure, source_pressure, surface_pressure
)
return pytree_utils.tree_map_over_nonscalars(regrid, fields)
59 changes: 44 additions & 15 deletions dinosaur/vertical_interpolation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@

from absl.testing import absltest
from absl.testing import parameterized

from dinosaur import sigma_coordinates
from dinosaur import vertical_interpolation

import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -56,12 +54,16 @@ class PressureLevelsTest(parameterized.TestCase):
def test_get_surface_pressure(self):
levels = np.array([100, 200, 300, 400, 500])
orography = np.array([[0, 5, 10, 15]])
geopotential = np.moveaxis([[
[400, 250, 150, 50, -50],
[1000, 900, 140, 40, 20],
[500, 400, 300, 200, 100],
[600, 500, 400, 300, 200],
]], -1, 0) # reorder from [x, y, level] to [level, x, y]
geopotential = np.moveaxis(
[[
[400, 250, 150, 50, -50],
[1000, 900, 140, 40, 20],
[500, 400, 300, 200, 100],
[600, 500, 400, 300, 200],
]],
-1,
0,
) # reorder from [x, y, level] to [level, x, y]
expected = np.array([[[450, 390, 500, 550]]])
actual = vertical_interpolation.get_surface_pressure(
vertical_interpolation.PressureCoordinates(levels),
Expand All @@ -74,11 +76,15 @@ def test_get_surface_pressure(self):
def test_sigma_to_pressure_roundtrip(self):
sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(10)
pressure_coords = vertical_interpolation.PressureCoordinates(
np.array([100, 200, 300, 400]))
np.array([100, 200, 300, 400])
)
surface_pressure = np.array([[[250, 350, 450]]])
original = np.moveaxis([[[1, 2, 3, 4]] * 3], -1, 0)
on_sigma_levels = vertical_interpolation.interp_pressure_to_sigma(
original, pressure_coords, sigma_coords, surface_pressure,
original,
pressure_coords,
sigma_coords,
surface_pressure,
)
self.assertEqual(on_sigma_levels.shape, (10, 1, 3))
roundtripped = vertical_interpolation.interp_sigma_to_pressure(
Expand All @@ -90,21 +96,44 @@ def test_sigma_to_pressure_roundtrip(self):
expected[3:, :, 1] = np.nan
np.testing.assert_allclose(roundtripped, expected, atol=1e-6)

def test_hybrid_to_sigma(self):
def test_interp_hybrid_to_sigma(self):
sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(5)
hybrid_coords = vertical_interpolation.HybridCoordinates(
a_centers=np.array([100, 100, 0]), b_centers=np.array([0, 0.5, 0.9]),
# at pressures [100, 600, 900]
a_centers=np.array([100, 100, 0]),
b_centers=np.array([0, 0.5, 0.9]),
)
surface_pressure = np.array([[[1000]]])
# at pressures [100, 600, 900]
original = np.array([1.0, 2.0, 3.0])[:, np.newaxis, np.newaxis]
# at pressures [100, 300, 500, 700, 900]
expected = np.array([1.0, 1.4, 1.8, 2 + 1/3, 3.0])
# linear interpolation to pressures [100, 300, 500, 700, 900]
expected = np.array([1.0, 1.4, 1.8, 2 + 1 / 3, 3.0])
actual = vertical_interpolation.interp_hybrid_to_sigma(
original, hybrid_coords, sigma_coords, surface_pressure
).ravel()
np.testing.assert_allclose(actual, expected, atol=1e-6)

def test_regrid_hybrid_to_sigma(self):
sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(5)
hybrid_coords = vertical_interpolation.HybridCoordinates(
# at pressures [10, 50, 100, 300, 600, 800, 900]
a_centers=np.array([10, 50, 100, 300, 500, 500, 500]),
b_centers=np.array([0, 0, 0, 0, 0.1, 0.3, 0.4]),
)
surface_pressure = np.array([[[1000]]])
original = np.arange(1.0, 8.0)[:, np.newaxis, np.newaxis]
# area weighted averages for cells centered at [100, 300, 500, 700, 900]
expected = np.array([
1 * (30 / 200) + 2.0 * (45 / 200) + 3.0 * (125 / 200),
4.0,
4.0 * (50 / 200) + 5.0 * (150 / 200),
5.0 * (100 / 200) + 6.0 * (100 / 200),
6.0 * (50 / 200) + 7.0 * (150 / 200),
])
actual = vertical_interpolation.regrid_hybrid_to_sigma(
original, hybrid_coords, sigma_coords, surface_pressure
).ravel()
np.testing.assert_allclose(actual, expected, atol=1e-6)


if __name__ == '__main__':
absltest.main()

0 comments on commit a7630f3

Please sign in to comment.