Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor Improvements 0 #151

Merged
merged 11 commits into from
Nov 13, 2024
2 changes: 1 addition & 1 deletion earth2studio/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@
from .rx import CosineSolarZenith, LandSeaMask, SurfaceGeoPotential
from .utils import datasource_to_file, fetch_data, prep_data_array
from .wb2 import WB2ERA5, WB2Climatology, WB2ERA5_32x64, WB2ERA5_121x240
from .xr import DataArrayFile, DataSetFile
from .xr import DataArrayDirectory, DataArrayFile, DataSetFile
9 changes: 7 additions & 2 deletions earth2studio/data/arco.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class ARCO:
Cache data source on local memory, by default True
verbose : bool, optional
Print download progress, by default True
async_timeout : int, optional
Time in sec after which download will be cancelled if not finished successfully,
by default 600

Warning
-------
Expand All @@ -68,7 +71,9 @@ class ARCO:
ARCO_LAT = np.linspace(90, -90, 721)
ARCO_LON = np.linspace(0, 359.75, 1440)

def __init__(self, cache: bool = True, verbose: bool = True):
def __init__(
self, cache: bool = True, verbose: bool = True, async_timeout: int = 600
):
self._cache = cache
self._verbose = verbose

Expand All @@ -90,7 +95,7 @@ def __init__(self, cache: bool = True, verbose: bool = True):
"gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", fs
)
self.zarr_group = zarr.open(fs_map, mode="r")
self.async_timeout = 600
self.async_timeout = async_timeout
self.async_process_limit = 4

def __call__(
Expand Down
6 changes: 6 additions & 0 deletions earth2studio/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def datasource_to_file(
lead_time: LeadTimeArray = np.array([np.timedelta64(0, "h")]),
backend: Literal["netcdf", "zarr"] = "netcdf",
chunks: dict[str, int] = {"variable": 1},
dtype: np.dtype | None = None,
) -> None:
"""Utility function that can be used for building a local data store needed
for an inference request. This file can then be used with the
Expand All @@ -200,6 +201,8 @@ def datasource_to_file(
Storage backend to save output file as, by default "netcdf"
chunks : dict[str, int], optional
Chunk sizes along each dimension, by default {"variable": 1}
dtype : np.dtype, optional
Data type for storing data
"""
if isinstance(time, datetime):
time = [time]
Expand All @@ -221,6 +224,9 @@ def datasource_to_file(
da = da.assign_coords(time=time)
da = da.chunk(chunks=chunks)

if dtype is not None:
da = da.astype(dtype=dtype)

match backend:
case "netcdf":
da.to_netcdf(file_name)
Expand Down
75 changes: 75 additions & 0 deletions earth2studio/data/xr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from datetime import datetime
from typing import Any

import xarray as xr
from numpy import ndarray
from pandas import to_datetime

from earth2studio.utils.type import TimeArray, VariableArray

Expand All @@ -35,6 +38,7 @@ class DataArrayFile:
def __init__(self, file_path: str, **xr_args: Any):
self.file_path = file_path
self.da = xr.open_dataarray(self.file_path, **xr_args)
# self.da = xr.open_dataarray(self.file_path, **xr_args)

def __call__(
self,
Expand Down Expand Up @@ -94,3 +98,74 @@ def __call__(
Loaded data array
"""
return self.da.sel(time=time, variable=variable)


class DataArrayDirectory:
"""A local xarray dataarray directory data source. This file should be compatable with
xarray. For example, a netCDF file. the structure of the directory should be like
path/to/monthly/files
|___2020
| |___2020_01.nc
| |___2020_02.nc
| |___ ...
|
|___2021
|___2021_01.nc
|___...

Parameters
----------
file_path : str
Path to xarray data array compatible file.
xr_args : Any
Keyword arguments to send to the xarray opening method.
"""

def __init__(self, dir_path: str, **xr_args: Any):
self.dir_path = dir_path
self.das: dict[str, dict[str, xr.DataArray]] = {}
for yr in os.listdir(self.dir_path):
yr_dir = os.path.join(self.dir_path, yr)
if os.path.isdir(yr_dir):
self.das[yr] = {}
for fl in os.listdir(yr_dir):
pth = os.path.join(yr_dir, fl)
if os.path.isfile(pth):
try:
arr = xr.open_dataarray(pth, **xr_args)
except: # noqa
continue
mon = fl.split(".")[0].split("_")[-1]
self.das[yr][mon] = arr

def __call__(
self,
time: datetime | list[datetime] | TimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Function to get data.

Parameters
----------
time : datetime | list[datetime] | TimeArray
Timestamps to return data for.
variable : str | list[str] | VariableArray
Strings or list of strings that refer to variables to return.

Returns
-------
xr.DataArray
Loaded data array
"""
if not (isinstance(time, list) or isinstance(time, ndarray)):
time = [time]
if not (isinstance(variable, list) or isinstance(variable, ndarray)):
variable = [variable]

arrs = []
for tt in time:
yr = str(to_datetime(tt).year)
mon = str(to_datetime(tt).month).zfill(2)
arrs.append(self.das[yr][mon].sel(time=tt, variable=variable))

return xr.concat(arrs, dim="time")
4 changes: 2 additions & 2 deletions earth2studio/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class XarrayBackend:

"""

def __init__(self, coords: CoordSystem, **xr_kwargs: Any) -> None:
def __init__(self, coords: CoordSystem = OrderedDict({}), **xr_kwargs: Any) -> None:
self.root = xr.Dataset(data_vars={}, coords=coords, **xr_kwargs)
self.coords: CoordSystem = OrderedDict({})
self.coords = coords

def __contains__(self, item: str) -> bool:
"""Checks if item in xarray Dataset.
Expand Down
1 change: 1 addition & 0 deletions earth2studio/lexicon/arco.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ARCOLexicon(metaclass=LexiconType):
"u100m": "100m_u_component_of_wind::",
"v100m": "100m_v_component_of_wind::",
"t2m": "2m_temperature::",
"d2m": "2m_dewpoint_temperature::",
"sp": "surface_pressure::",
"msl": "mean_sea_level_pressure::",
"tcwv": "total_column_water_vapour::",
Expand Down
1 change: 1 addition & 0 deletions earth2studio/lexicon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __getitem__(cls, val: str) -> tuple[str, Callable]:
"u100m": "u-component of wind at 100 m",
"v100m": "v-component of wind at 100 m",
"t2m": "temperature at 2m",
"d2m": "dewpoint temperature at 2m",
"r2m": "relative humidity at 2 m",
"q2m": "specific humidity at 2 m",
"sp": "surface pressure",
Expand Down
1 change: 1 addition & 0 deletions earth2studio/lexicon/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class CDSLexicon(metaclass=LexiconType):
"u100m": "reanalysis-era5-single-levels::100m_u_component_of_wind::",
"v100m": "reanalysis-era5-single-levels::100m_v_component_of_wind::",
"t2m": "reanalysis-era5-single-levels::2m_temperature::",
"d2m": "reanalysis-era5-single-levels::2m_dewpoint_temperature::",
"sp": "reanalysis-era5-single-levels::surface_pressure::",
"msl": "reanalysis-era5-single-levels::mean_sea_level_pressure::",
"tcwv": "reanalysis-era5-single-levels::total_column_water_vapour::",
Expand Down
2 changes: 2 additions & 0 deletions earth2studio/lexicon/gefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class GEFSLexicon(metaclass=LexiconType):
"u100m": "pgrb2b::UGRD::100 m above ground",
"v100m": "pgrb2b::VGRD::100 m above ground",
"t2m": "pgrb2a::TMP::2 m above ground",
"d2m": "pgrb2b::DPT::2 m above ground",
"r2m": "pgrb2a::RH::2 m above ground",
"q2m": "pgrb2b::SPFH::2 m above ground",
"t100m": "pgrb2b::TMP::100 m above ground",
Expand Down Expand Up @@ -261,6 +262,7 @@ class GEFSLexiconSel(metaclass=LexiconType):
"u10m": "pgrb2s::UGRD::10 m above ground",
"v10m": "pgrb2s::VGRD::10 m above ground",
"t2m": "pgrb2s::TMP::2 m above ground",
"d2m": "pgrb2s::DPT::2 m above ground",
"r2m": "pgrb2s::RH::2 m above ground",
"sp": "pgrb2s::PRES::surface",
"msl": "pgrb2s::PRMSL::mean sea level", # Pressure Reduced to MSL
Expand Down
2 changes: 1 addition & 1 deletion earth2studio/lexicon/gfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ class GFSLexicon(metaclass=LexiconType):
"u100m": "UGRD::100 m above ground",
"v100m": "VGRD::100 m above ground",
"t2m": "TMP::2 m above ground",
"d2m": "DPT::2 m above ground",
"sp": "PRES::surface",
"msl": "PRMSL::mean sea level",
"tcwv": "PWAT::entire atmosphere (considered as a single layer)",
"2d": "DPT::2 m above ground",
"u1": "UGRD::1 mb",
"u2": "UGRD::2 mb",
"u3": "UGRD::3 mb",
Expand Down
1 change: 1 addition & 0 deletions earth2studio/lexicon/ifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def build_vocab() -> dict[str, str]:
"u100m": "100u::sfc::",
"v100m": "100v::sfc::",
"t2m": "2t::sfc::",
"d2m": "2d::sfc::",
"sp": "sp::sfc::",
"msl": "msl::sfc::",
"tcwv": "tcwv::sfc::",
Expand Down
1 change: 1 addition & 0 deletions earth2studio/lexicon/wb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class WB2Lexicon(metaclass=LexiconType):
"u10m": "10m_u_component_of_wind::",
"v10m": "10m_v_component_of_wind::",
"t2m": "2m_temperature::",
"d2m": "2m_dewpoint_temperature::",
"sp": "surface_pressure::",
"msl": "mean_sea_level_pressure::",
"tcwv": "total_column_water_vapour::",
Expand Down
2 changes: 2 additions & 0 deletions earth2studio/models/px/sfno.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def __init__(
self.register_buffer("center", center)
self.register_buffer("scale", scale)
self.variables = variables
if "2d" in self.variables:
self.variables[self.variables == "2d"] = "d2m"

def __str__(self) -> str:
return "sfno_73ch_small"
Expand Down
15 changes: 7 additions & 8 deletions earth2studio/perturbation/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,18 @@ class CorrelatedSphericalField(torch.nn.Module):

Parameters
----------
nlat : int
Number of latitudinal modes;
longitudinal modes are 2*nlat.
length_scale : int
Length scale in km

time_scale : int
Time scale for the AR(1) process, that governs
the evolution of the coefficients

sigma: desired standard deviation of the field in
grid point space

nlat : int
Number of latitudinal modes;
longitudinal modes are 2*nlat.
sigma: float
desired standard deviation of the field in grid point space
N: int
Number of latent dimensions
grid : string, default is "equiangular"
Grid type. Currently supports "equiangular" and
"legendre-gauss".
Expand Down
2 changes: 1 addition & 1 deletion earth2studio/statistics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
from .crps import crps
from .moments import mean, std, variance # noqa
from .rank import rank_histogram # noqa
from .rmse import rmse, spread_skill_ratio # noqa
from .rmse import rmse, skill_spread, spread_skill_ratio # noqa
from .weights import lat_weight # noqa
71 changes: 71 additions & 0 deletions earth2studio/statistics/rmse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch

from earth2studio.utils.coords import handshake_dim
Expand Down Expand Up @@ -260,3 +261,73 @@ def __call__(
skill, output_coords = self.reduced_rmse(em, output_coords, y, y_coords)
spread, output_coords = self.reduced_mean(*self.ensemble_var(x, x_coords))
return skill / torch.sqrt(spread), output_coords


class skill_spread(spread_skill_ratio):
def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
"""Output coordinate system of the computed statistic, corresponding to the given input coordinates

Parameters
----------
input_coords : CoordSystem
Input coordinate system to transform into output_coords

Returns
-------
CoordSystem
Coordinate system dictionary
"""
output_coords = input_coords.copy()
for dimension in self.reduction_dimensions:
handshake_dim(input_coords, dimension)
output_coords.pop(dimension)

output_coords.update({"metric": np.array(["mse", "variance"])})
output_coords.move_to_end("metric", last=False)

return output_coords

def __call__(
self,
x: torch.Tensor,
x_coords: CoordSystem,
y: torch.Tensor,
y_coords: CoordSystem,
) -> tuple[torch.Tensor, CoordSystem]:
"""
Apply metric to data `x` and `y`, checking that their coordinates
are broadcastable. While reducing over `reduction_dims`.

If batch_update was passed True upon metric initialization then this method
returns the running sample MSE and variance over all seen batches.

Parameters
----------
x : torch.Tensor
The ensemble forecast input tensor. This is the tensor over which the
ensemble mean and spread are calculated with respect to.
x_coords : CoordSystem
Ordered dict representing coordinate system that describes the `x` tensor.
`reduction_dimensions` must be in coords.
y : torch.Tensor
The observation input tensor.
y_coords : CoordSystem
Ordered dict representing coordinate system that describes the `y` tensor.
`reduction_dimensions` must be in coords.

Returns
-------
tuple[torch.Tensor, CoordSystem]
Returns a tensor containing MSE and variance with appropriate reduced coordinates.
"""

em, output_coords = self.ensemble_mean(x, x_coords)
skill, output_coords = self.reduced_rmse(em, output_coords, y, y_coords)
var, output_coords = self.reduced_mean(*self.ensemble_var(x, x_coords))

mse = torch.square(skill)

output_coords.update({"metric": np.array(["mse", "variance"])})
output_coords.move_to_end("metric", last=False)

return torch.stack((mse, var)), output_coords
Loading