Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conservative vertical regridding. #43

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions dinosaur/vertical_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,88 @@ def interp_hybrid_to_sigma(
)
regrid = lambda x: _vertical_interp_3d(desired_pressure, source_pressure, x)
return pytree_utils.tree_map_over_nonscalars(regrid, fields)


def _pressure_cell_bounds(pressure_at_cell_centers, surface_pressure):
# pressure_at_cell_centers must be non-decreasing.
# both inputs arguments should have the same units.
zero = jnp.array([0.0])
inner_boundaries = (
pressure_at_cell_centers[:-1] + pressure_at_cell_centers[1:]
) / 2
surface = jnp.array([surface_pressure])
return jnp.concatenate([zero, inner_boundaries, surface])


def _pressure_overlap(
source_centers: typing.Array,
target_centers: typing.Array,
surface_pressure: typing.Array,
) -> jnp.ndarray:
"""Calculate the interval overlap between pressure grid cells."""
source_bounds = _pressure_cell_bounds(source_centers, surface_pressure)
target_bounds = _pressure_cell_bounds(target_centers, surface_pressure)
# based on https://gist.github.com/shoyer/c0f1ddf409667650a076c058f9a17276
# (see also horizontal_interpolation.py)
upper = jnp.minimum(
target_bounds[1:, jnp.newaxis], source_bounds[jnp.newaxis, 1:]
)
lower = jnp.maximum(
target_bounds[:-1, jnp.newaxis], source_bounds[jnp.newaxis, :-1]
)
return jnp.maximum(upper - lower, 0)


def conservative_pressure_weights(
source_centers: typing.Array,
target_centers: typing.Array,
surface_pressure: typing.Array,
) -> jnp.ndarray:
"""Create a weight matrix for conservative regridding on pressure levels.

Args:
source_centers: 1D strictly increasing pressure levels for the source grid.
All values must be between 0 and surface_pressure.
target_centers: 1D strictly increasing pressure levels for the target grid.
All values must be between 0 and surface_pressure.
surface_pressure: surface pressure.

Returns:
NumPy array with shape (target, source). Rows sum to 1.
"""
weights = _pressure_overlap(source_centers, target_centers, surface_pressure)
weights /= jnp.sum(weights, axis=1, keepdims=True)
assert weights.shape == (target_centers.size, source_centers.size)
return weights


@functools.partial(jax.jit, static_argnums=(1, 2))
def regrid_hybrid_to_sigma(
fields: typing.Pytree,
hybrid_coords: HybridCoordinates,
sigma_coords: sigma_coordinates.SigmaCoordinates,
surface_pressure: typing.Array,
) -> typing.Pytree:
"""Conservatively regrid 3D fields from hybrid to sigma levels."""
desired_pressure = (
sigma_coords.centers[:, np.newaxis, np.newaxis] * surface_pressure
)
source_pressure = (
hybrid_coords.a_centers[:, np.newaxis, np.newaxis]
+ hybrid_coords.b_centers[:, np.newaxis, np.newaxis] * surface_pressure
)

@jax.jit
@functools.partial(
jnp.vectorize, signature='(a,x,y),(b,x,y),(x,y),(b,x,y)->(a,x,y)'
)
@functools.partial(jax.vmap, in_axes=(-1, -1, -1, -1), out_axes=-1)
@functools.partial(jax.vmap, in_axes=(-1, -1, -1, -1), out_axes=-1)
def _regrid_3d(x, xp, x_max, fp):
weights = conservative_pressure_weights(xp, x, x_max)
return jnp.einsum('ab,b->a', weights, fp, precision='float32')

regrid = functools.partial(
_regrid_3d, desired_pressure, source_pressure, surface_pressure
)
return pytree_utils.tree_map_over_nonscalars(regrid, fields)
59 changes: 44 additions & 15 deletions dinosaur/vertical_interpolation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@

from absl.testing import absltest
from absl.testing import parameterized

from dinosaur import sigma_coordinates
from dinosaur import vertical_interpolation

import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -56,12 +54,16 @@ class PressureLevelsTest(parameterized.TestCase):
def test_get_surface_pressure(self):
levels = np.array([100, 200, 300, 400, 500])
orography = np.array([[0, 5, 10, 15]])
geopotential = np.moveaxis([[
[400, 250, 150, 50, -50],
[1000, 900, 140, 40, 20],
[500, 400, 300, 200, 100],
[600, 500, 400, 300, 200],
]], -1, 0) # reorder from [x, y, level] to [level, x, y]
geopotential = np.moveaxis(
[[
[400, 250, 150, 50, -50],
[1000, 900, 140, 40, 20],
[500, 400, 300, 200, 100],
[600, 500, 400, 300, 200],
]],
-1,
0,
) # reorder from [x, y, level] to [level, x, y]
expected = np.array([[[450, 390, 500, 550]]])
actual = vertical_interpolation.get_surface_pressure(
vertical_interpolation.PressureCoordinates(levels),
Expand All @@ -74,11 +76,15 @@ def test_get_surface_pressure(self):
def test_sigma_to_pressure_roundtrip(self):
sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(10)
pressure_coords = vertical_interpolation.PressureCoordinates(
np.array([100, 200, 300, 400]))
np.array([100, 200, 300, 400])
)
surface_pressure = np.array([[[250, 350, 450]]])
original = np.moveaxis([[[1, 2, 3, 4]] * 3], -1, 0)
on_sigma_levels = vertical_interpolation.interp_pressure_to_sigma(
original, pressure_coords, sigma_coords, surface_pressure,
original,
pressure_coords,
sigma_coords,
surface_pressure,
)
self.assertEqual(on_sigma_levels.shape, (10, 1, 3))
roundtripped = vertical_interpolation.interp_sigma_to_pressure(
Expand All @@ -90,21 +96,44 @@ def test_sigma_to_pressure_roundtrip(self):
expected[3:, :, 1] = np.nan
np.testing.assert_allclose(roundtripped, expected, atol=1e-6)

def test_hybrid_to_sigma(self):
def test_interp_hybrid_to_sigma(self):
sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(5)
hybrid_coords = vertical_interpolation.HybridCoordinates(
a_centers=np.array([100, 100, 0]), b_centers=np.array([0, 0.5, 0.9]),
# at pressures [100, 600, 900]
a_centers=np.array([100, 100, 0]),
b_centers=np.array([0, 0.5, 0.9]),
)
surface_pressure = np.array([[[1000]]])
# at pressures [100, 600, 900]
original = np.array([1.0, 2.0, 3.0])[:, np.newaxis, np.newaxis]
# at pressures [100, 300, 500, 700, 900]
expected = np.array([1.0, 1.4, 1.8, 2 + 1/3, 3.0])
# linear interpolation to pressures [100, 300, 500, 700, 900]
expected = np.array([1.0, 1.4, 1.8, 2 + 1 / 3, 3.0])
actual = vertical_interpolation.interp_hybrid_to_sigma(
original, hybrid_coords, sigma_coords, surface_pressure
).ravel()
np.testing.assert_allclose(actual, expected, atol=1e-6)

def test_regrid_hybrid_to_sigma(self):
sigma_coords = sigma_coordinates.SigmaCoordinates.equidistant(5)
hybrid_coords = vertical_interpolation.HybridCoordinates(
# at pressures [10, 50, 100, 300, 600, 800, 900]
a_centers=np.array([10, 50, 100, 300, 500, 500, 500]),
b_centers=np.array([0, 0, 0, 0, 0.1, 0.3, 0.4]),
)
surface_pressure = np.array([[[1000]]])
original = np.arange(1.0, 8.0)[:, np.newaxis, np.newaxis]
# area weighted averages for cells centered at [100, 300, 500, 700, 900]
expected = np.array([
1 * (30 / 200) + 2.0 * (45 / 200) + 3.0 * (125 / 200),
4.0,
4.0 * (50 / 200) + 5.0 * (150 / 200),
5.0 * (100 / 200) + 6.0 * (100 / 200),
6.0 * (50 / 200) + 7.0 * (150 / 200),
])
actual = vertical_interpolation.regrid_hybrid_to_sigma(
original, hybrid_coords, sigma_coords, surface_pressure
).ravel()
np.testing.assert_allclose(actual, expected, atol=1e-6)


if __name__ == '__main__':
absltest.main()
Loading