Skip to content

Commit

Permalink
Gaussian Perturbation from HENS paper (#152)
Browse files Browse the repository at this point in the history
* new branch based on oss main. otherwise identical to hens_1

* added method to perturbation ini file

* add typing

---------

Co-authored-by: MarsuPila <22983240+MarsuPila@users.noreply.github.com>
Co-authored-by: Dallas Foster <dallasf@nvidia.com>
  • Loading branch information
3 people authored Nov 13, 2024
1 parent 45f1239 commit 90e37ba
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 1 deletion.
2 changes: 1 addition & 1 deletion earth2studio/perturbation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .base import Perturbation
from .brown import Brown # noqa
from .bv import BredVector # noqa
from .gaussian import Gaussian # noqa
from .gaussian import CorrelatedSphericalGaussian, Gaussian # noqa
from .lagged import LaggedEnsemble # noqa
from .spherical import SphericalGaussian # noqa
from .zero import Zero # noqa
256 changes: 256 additions & 0 deletions earth2studio/perturbation/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import numpy as np
import torch
from torch_harmonics import InverseRealSHT
from typing_extensions import Self

from earth2studio.utils import handshake_dim
from earth2studio.utils.type import CoordSystem


Expand Down Expand Up @@ -58,3 +64,253 @@ def __call__(
"""
noise_amplitude = self.noise_amplitude.to(x.device)
return x + noise_amplitude * torch.randn_like(x), coords


class CorrelatedSphericalGaussian:
    """Perturb data with spatially correlated Gaussian noise on the sphere.

    Noise fields are drawn from ``CorrelatedSphericalField`` and added to the
    input after scaling by ``noise_amplitude``.

    Warning
    -------
    Noise is generated on an equirectangular grid of size [N, 2*N] with N even.
    Inputs with an odd number of latitudes, [N+1, 2*N], are handled by copying
    the last generated latitude row into the extra row.

    Parameters
    ----------
    noise_amplitude : float | Tensor | None
        Noise amplitude. Required even though the signature defaults to None;
        passing None raises a ValueError. If a tensor, it must be
        broadcastable with the input data.
    sigma : float, optional
        Forwarded to ``CorrelatedSphericalField`` as the target standard
        deviation of the noise field in grid-point space, by default 1.0
    length_scale : float, optional
        Spatial correlation length forwarded to ``CorrelatedSphericalField``.
        It is divided there by an Earth radius of 6.371e6, i.e. it is expected
        in metres, by default 5.0e5
    time_scale : float, optional
        Time scale of the AR(1) process forwarded to
        ``CorrelatedSphericalField``, by default 48.0

    Raises
    ------
    ValueError
        If ``noise_amplitude`` is None.
    """

    def __init__(
        self,
        noise_amplitude: float | torch.Tensor | None = None,
        sigma: float = 1.0,
        length_scale: float = 5.0e5,
        time_scale: float = 48.0,
    ) -> None:
        if noise_amplitude is None:
            raise ValueError("pass noise amplitude")
        self.sigma = sigma
        self.length_scale = length_scale
        self.time_scale = time_scale
        # Scalars are wrapped into a one-element tensor so that the amplitude
        # can be moved to the device of the input in __call__.
        self.noise_amplitude = (
            noise_amplitude
            if isinstance(noise_amplitude, torch.Tensor)
            else torch.Tensor([noise_amplitude])
        )

    @torch.inference_mode()
    def __call__(
        self,
        xx: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Apply perturbation method

        Parameters
        ----------
        xx : torch.Tensor
            Input tensor intended to apply perturbation on. Must have at least
            three dimensions, the last two being lat and lon.
        coords : CoordSystem
            Ordered dict representing coordinate system that describes the tensor, must
            contain "lat" and "lon" coordinates as the last two dimensions

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]:
            Output tensor and respective coordinate system dictionary

        Raises
        ------
        ValueError
            If the lat/lon sizes are not of the form N:2N or N+1:2N.
        """
        shape = xx.shape

        # Check the required dimensions are present
        handshake_dim(coords, required_dim="lat", required_index=-2)
        handshake_dim(coords, required_dim="lon", required_index=-1)

        # Noise only supports an even lat count; an odd count is rounded down.
        nlat = 2 * (shape[-2] // 2)

        # Check the ratio
        if nlat != shape[-1] / 2:
            raise ValueError("Lat/lon aspect ration must be N:2N or N+1:2N")

        sampler = CorrelatedSphericalField(
            nlat=nlat,
            length_scale=self.length_scale,
            time_scale=self.time_scale,
            sigma=self.sigma,
            N=shape[-3],
        ).to(xx.device)

        # The sampler emits one group of shape[-3] fields per entry of the
        # leading dimension of its input. Flattening every dimension ahead of
        # the last three makes the number of generated fields match the input
        # for any number/size of leading dimensions (previously only inputs
        # whose intermediate leading dimensions had size 1 could be reshaped).
        sample_noise = sampler(xx.reshape(-1, *shape[-3:]), None)
        sample_noise = sample_noise.reshape(*shape[:-2], nlat, 2 * nlat)

        # Hack for odd lat coords: duplicate the last generated row.
        if shape[-2] % 2 == 1:
            noise = torch.zeros_like(xx)
            noise[..., :-1, :] = sample_noise
            noise[..., -1:, :] = noise[..., -2:-1, :]
        else:
            noise = sample_noise

        noise_amplitude = self.noise_amplitude.to(xx.device)
        return xx + noise_amplitude * noise, coords


class CorrelatedSphericalField(torch.nn.Module):
    """Sampler for spatially and temporally correlated noise on the sphere.

    This class was taken from https://github.com/ankurmahesh/earth2mip-fork/blob/HENS/earth2mip/ensemble_utils.py#L392-L531.
    Reference publication: A.Mahesh et al. Huge Ensembles Part I: Design of Ensemble Weather Forecasts using Spherical Fourier Neural Operators https://arxiv.org/abs/2408.03100.

    It mimics the implementation of the SPPT: Stochastic Perturbed
    Parameterized Tendency in this paper:
    https://www.ecmwf.int/sites/default/files/elibrary/2009/11577-stochastic-parametrization-and-model-uncertainty.pdf

    Spectral coefficients are kept in the ``coeff`` buffer and advanced by an
    AR(1) process every time a field is generated in ``forward``.

    Parameters
    ----------
    nlat : int
        Number of latitudinal modes;
        longitudinal modes are 2*nlat.
    length_scale : float
        Spatial length scale. It is divided by an Earth radius of 6.371e6,
        so it must be given in metres.
    time_scale : float
        Time scale for the AR(1) process that governs the evolution of the
        coefficients, in the same units as the internal step ``dt = 6.0``.
    sigma : float
        Desired standard deviation of the field in grid point space.
    N : int
        Number of fields generated together; it is the leading dimension of
        the sampled spectral noise.
    grid : str, default is "equiangular"
        Grid type passed to ``InverseRealSHT``.
    dtype : torch.dtype, default is torch.float32
        Numerical type for the inverse transform and the noise sampler.
    """

    def __init__(
        self,
        nlat: int,
        length_scale: float,
        time_scale: float,
        sigma: float,
        N: int,
        grid: str = "equiangular",
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.sigma = sigma
        dt = 6.0
        # AR(1) auto-correlation between two consecutive states.
        self.phi = np.exp(-dt / time_scale)

        # Number of latitudinal modes.
        self.nlat = nlat
        self.N = N

        # Inverse SHT
        self.isht = InverseRealSHT(
            self.nlat, 2 * self.nlat, grid=grid, norm="backward"
        ).to(dtype=dtype)

        r_earth = 6.371e6
        # kT is defined on slide 7 of the referenced material.
        self.kT = (length_scale / r_earth) ** 2 / 2
        F0 = self.calculateF0(self.sigma, self.phi, self.nlat, self.kT)

        # j * (j + 1) for every degree j, repeated along the order axis.
        prods = (
            torch.tensor([j * (j + 1) for j in range(0, self.nlat)])
            .view(self.nlat, 1)
            .repeat(1, self.nlat + 1)
        )

        # Per-mode standard deviation; tril zeroes the upper triangle.
        sigma_n = torch.tril(torch.exp(-self.kT * prods / 2) * F0)
        self.register_buffer("sigma_n", sigma_n)

        # Save mean and var of the standard Gaussian.
        # Need these to re-initialize distribution on a new device.
        mean = torch.tensor([0.0]).to(dtype=dtype)
        var = torch.tensor([1.0]).to(dtype=dtype)
        self.register_buffer("mean", mean)
        self.register_buffer("var", var)

        # Standard normal noise sampler.
        self._reset_sampler()
        xi = self._sample_xi()

        # Set spectral coefficients to this value at initial time
        # for stability in the AR(1) process. See link in description
        coeff: torch.Tensor = ((1 - self.phi**2) ** (-0.5)) * self.sigma_n * xi
        coeff = coeff.unsqueeze(0)
        self.register_buffer("coeff", coeff)

    def _reset_sampler(self) -> None:
        """(Re)build the standard normal sampler from the current buffers."""
        self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)

    def _sample_xi(self) -> torch.Tensor:
        """Draw complex standard normal noise of shape (N, nlat, nlat + 1)."""
        draws = self.gaussian_noise.sample(
            torch.Size((self.N, self.nlat, self.nlat + 1, 2))
        )
        # ``mean``/``var`` have shape (1,), so ``sample`` appends a trailing
        # axis of size 1. Only that axis is removed: a bare ``squeeze()`` would
        # also drop the N axis when N == 1 and change the rank of ``coeff``.
        return torch.view_as_complex(draws.squeeze(-1))

    def calculateF0(
        self, sigma: float, phi: float, nlat: int, kT: float
    ) -> torch.Tensor:
        """Compute the scaling that yields a grid-point standard deviation sigma.

        Parameters
        ----------
        sigma : float
            Desired standard deviation (it is squared to obtain the variance).
        phi : float
            AR(1) coefficient, np.exp(-dt / time_scale).
        nlat : int
            Number of latitudinal modes; wavenumbers 1 .. nlat - 1 are summed.
        kT : float
            Spectral decay parameter.

        Returns
        -------
        torch.Tensor
            Zero-dimensional tensor holding the scaling factor F0.
        """
        numerator = sigma**2 * (1 - (phi**2))
        wavenumbers = torch.arange(1, nlat)
        weights = (2 * wavenumbers + 1) * torch.exp(
            -kT * wavenumbers * (wavenumbers + 1)
        )
        denominator = 2 * weights.sum()

        return (numerator / denominator) ** 0.5

    def forward(
        self, xx: torch.Tensor, time: np.datetime64 | None = None
    ) -> torch.Tensor:
        """Generate correlated fields and advance the AR(1) coefficients.

        Parameters
        ----------
        xx : torch.Tensor
            Only the size of its leading dimension is used; one group of N
            fields is generated per entry.
        time : np.datetime64 | None, optional
            Unused, by default None

        Returns
        -------
        torch.Tensor
            Noise of shape (xx.shape[0], 1, 1, N, nlat, 2 * nlat).
        """
        noises = []
        # iterate over samples in batch
        for _ in range(xx.shape[0]):
            noise = self.isht(self.coeff) * 4 * np.pi  # type: ignore
            noises.append(noise.reshape(1, 1, 1, self.N, self.nlat, self.nlat * 2))

            # Fresh innovation for the AR(1) update, so that the next
            # generated field differs from (but is correlated with) this one.
            xi = self._sample_xi()
            self.coeff = (self.phi * self.coeff) + (self.sigma_n * xi)  # type: ignore

        return torch.cat(noises)

    # Override cuda and to methods so sampler gets initialized with mean
    # and variance on the correct device.
    def cuda(self, *args: Any, **kwargs: Any) -> Self:
        """Move the module to GPU and rebuild the noise sampler there."""
        super().cuda(*args, **kwargs)
        self._reset_sampler()

        return self

    def to(self, *args: Any, **kwargs: Any) -> Self:
        """Apply ``torch.nn.Module.to`` and rebuild the noise sampler."""
        super().to(*args, **kwargs)
        self._reset_sampler()

        return self

0 comments on commit 90e37ba

Please sign in to comment.