From ceca90e3c6318b2ac752ea26c7fa432c81c80676 Mon Sep 17 00:00:00 2001 From: MarsuPila <22983240+MarsuPila@users.noreply.github.com> Date: Tue, 1 Oct 2024 15:55:09 +0200 Subject: [PATCH 1/2] expanding data source functionality --- earth2studio/data/__init__.py | 2 +- earth2studio/data/arco.py | 6 ++- earth2studio/data/utils.py | 4 ++ earth2studio/data/xr.py | 71 +++++++++++++++++++++++++++++ earth2studio/models/auto/package.py | 2 +- 5 files changed, 81 insertions(+), 4 deletions(-) diff --git a/earth2studio/data/__init__.py b/earth2studio/data/__init__.py index f1d53815..bf8b9a51 100644 --- a/earth2studio/data/__init__.py +++ b/earth2studio/data/__init__.py @@ -26,4 +26,4 @@ from .rx import CosineSolarZenith, LandSeaMask, SurfaceGeoPotential from .utils import datasource_to_file, fetch_data, prep_data_array from .wb2 import WB2ERA5, WB2Climatology, WB2ERA5_32x64, WB2ERA5_121x240 -from .xr import DataArrayFile, DataSetFile +from .xr import DataArrayFile, DataSetFile, DataArrayDirectory diff --git a/earth2studio/data/arco.py b/earth2studio/data/arco.py index 161c1744..7b623681 100644 --- a/earth2studio/data/arco.py +++ b/earth2studio/data/arco.py @@ -68,7 +68,9 @@ class ARCO: ARCO_LAT = np.linspace(90, -90, 721) ARCO_LON = np.linspace(0, 359.75, 1440) - def __init__(self, cache: bool = True, verbose: bool = True): + def __init__(self, cache: bool = True, + verbose: bool = True, + async_timeout: int = 600): self._cache = cache self._verbose = verbose @@ -90,7 +92,7 @@ def __init__(self, cache: bool = True, verbose: bool = True): "gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", fs ) self.zarr_group = zarr.open(fs_map, mode="r") - self.async_timeout = 600 + self.async_timeout = async_timeout self.async_process_limit = 4 def __call__( diff --git a/earth2studio/data/utils.py b/earth2studio/data/utils.py index ef00a5b2..0bfce71f 100644 --- a/earth2studio/data/utils.py +++ b/earth2studio/data/utils.py @@ -177,6 +177,7 @@ def datasource_to_file( lead_time: LeadTimeArray = np.array([np.timedelta64(0, "h")]), backend: Literal["netcdf", "zarr"] = "netcdf", chunks: dict[str, int] = {"variable": 1}, + dtype: np.dtype | None = None, ) -> None: """Utility function that can be used for building a local data store needed for an inference request. This file can then be used with the @@ -221,6 +222,9 @@ def datasource_to_file( da = da.assign_coords(time=time) da = da.chunk(chunks=chunks) + if dtype is not None: + da = da.astype(dtype=dtype) + match backend: case "netcdf": da.to_netcdf(file_name) diff --git a/earth2studio/data/xr.py b/earth2studio/data/xr.py index 8546c29c..8021e62e 100644 --- a/earth2studio/data/xr.py +++ b/earth2studio/data/xr.py @@ -14,8 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from datetime import datetime from typing import Any +from numpy import ndarray +from pandas import to_datetime +from datetime import datetime import xarray as xr @@ -94,3 +98,70 @@ def __call__( Loaded data array """ return self.da.sel(time=time, variable=variable) + + +class DataArrayDirectory: + """A local xarray dataarray directory data source. This file should be compatable with + xarray. For example, a netCDF file. the structure of the directory should be like + path/to/monthly/files + |___2020 + | |___2020_01.nc + | |___2020_02.nc + | |___ ... + | + |___2021 + |___2021_01.nc + |___... + + Parameters + ---------- + file_path : str + Path to xarray data array compatible file. + """ + + def __init__(self, dir_path: str, **xr_args: Any): + self.dir_path = dir_path + self.das = {} + for yr in os.listdir(self.dir_path): + yr_dir = os.path.join(self.dir_path, yr) + if os.path.isdir(yr_dir): + self.das[yr] = {} + for fl in os.listdir(yr_dir): + pth = os.path.join(yr_dir, fl) + if os.path.isfile(pth): + try: + arr = xr.open_dataarray(pth, **xr_args) + except: + continue + mon = fl.split('.')[0].split('_')[-1] + self.das[yr][mon] = arr + + def __call__( + self, + time: datetime | list[datetime] | TimeArray, + variable: str | list[str] | VariableArray, + ) -> xr.DataArray: + """Function to get data. + + Parameters + ---------- + time : datetime | list[datetime] | TimeArray + Timestamps to return data for. + variable : str | list[str] | VariableArray + Strings or list of strings that refer to variables to return. + + Returns + ------- + xr.DataArray + Loaded data array + """ + if not (isinstance(time, list) or isinstance(time, ndarray)): + time = [time] + + arrs = [] + for tt in time: + yr = str(to_datetime(tt).year) + mon = str(to_datetime(tt).month).zfill(2) + arrs.append(self.das[yr][mon].sel(time=tt, variable=variable)) + + return xr.concat(arrs, dim='time') diff --git a/earth2studio/models/auto/package.py b/earth2studio/models/auto/package.py index 4cc604ee..63dcd5f6 100644 --- a/earth2studio/models/auto/package.py +++ b/earth2studio/models/auto/package.py @@ -236,7 +236,7 @@ def default_timeout(cls) -> int: int Time out in seconds """ - default_timeout = 300 + default_timeout = 3000 try: timeout = os.environ.get("EARTH2STUDIO_PACKAGE_TIMEOUT", default_timeout) default_timeout = int(timeout) From 5e62a3bcedddec719dfd6147810794e800469979 Mon Sep 17 00:00:00 2001 From: MarsuPila <22983240+MarsuPila@users.noreply.github.com> Date: Tue, 1 Oct 2024 16:11:44 +0200 Subject: [PATCH 2/2] updated changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4db988e9..e98b5036 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.4.0a0] - 2024-11-xx ### Added +- support for local directory of IC files ### Changed