From e75451510399c7c197dadf1f6953d7a293ec2d3e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 29 Oct 2024 10:42:21 +0100 Subject: [PATCH 1/7] Fix geospatial benchmarks without `--benchmark` flag (#1574) --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 684dc7be0c..5f4f781fc4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -928,7 +928,7 @@ def performance_report( tmp_path, ): if not test_run_benchmark: - yield + yield contextlib.nullcontext else: if not pytestconfig.getoption("--performance-report"): yield contextlib.nullcontext From 51aa2e4fed7c2591b17f501ae866f751afed25d0 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 29 Oct 2024 11:02:46 +0100 Subject: [PATCH 2/7] Fix geospatial benchmarks without --benchmark (#1575) --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5f4f781fc4..930670bed8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -888,7 +888,7 @@ def memray_profile( tmp_path, ): if not test_run_benchmark: - yield + yield contextlib.nullcontext else: memray_option = pytestconfig.getoption("--memray") From 9260c11ec4839964ac52bda83bfee2a72cbc761e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 29 Oct 2024 15:52:40 +0100 Subject: [PATCH 3/7] Use standard dataset in regrid (#1576) --- tests/geospatial/test_regrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/geospatial/test_regrid.py b/tests/geospatial/test_regrid.py index e4d5e1e131..dcb79cf14b 100644 --- a/tests/geospatial/test_regrid.py +++ b/tests/geospatial/test_regrid.py @@ -27,7 +27,7 @@ def test_xesmf( ) as client: # noqa: F841 # Load dataset ds = xr.open_zarr( - "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr", + "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", ) if scale == "small": From b6980245146ba13f9c5302fd8935f70754c6f923 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:57:25 +0100 Subject: [PATCH 4/7] Change rasterio opening call to new recommended method (#1580) --- ci/environment.yml | 1 + tests/geospatial/test_zonal_average.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ci/environment.yml b/ci/environment.yml index 979b5117ff..b7c87939cf 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -47,6 +47,7 @@ dependencies: - bokeh ==3.5.1 - gilknocker ==0.4.1 - openssl >1.1.0g + - rasterio >=1.4.0 - rioxarray ==0.17.0 - h5netcdf ==1.3.0 - xesmf ==0.8.7 diff --git a/tests/geospatial/test_zonal_average.py b/tests/geospatial/test_zonal_average.py index 3eec34d2d7..b74c93b574 100644 --- a/tests/geospatial/test_zonal_average.py +++ b/tests/geospatial/test_zonal_average.py @@ -37,8 +37,9 @@ def test_nwm( subset = ds.zwattablrt.sel(time=time_range) counties = rioxarray.open_rasterio( - s3.open("s3://nwm-250m-us-counties/Counties_on_250m_grid.tif"), + "s3://nwm-250m-us-counties/Counties_on_250m_grid.tif", chunks="auto", + opener=s3.open, ).squeeze() # Remove any small floating point error in coordinate locations From 2a06ffda7ab2e57a23a485e699fa865b2ef477ce Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 4 Nov 2024 14:43:01 +0100 Subject: [PATCH 5/7] Separate geospatial benchmark execution from workload definition (#1579) --- .../test_atmospheric_circulation.py | 58 +----- tests/geospatial/test_climatology.py | 105 ++--------- tests/geospatial/test_cloud_optimize.py | 83 +-------- tests/geospatial/test_rechunking.py | 34 +--- tests/geospatial/test_regrid.py | 67 ------- tests/geospatial/test_regridding.py | 33 ++++ tests/geospatial/test_satellite_filtering.py | 105 +---------- tests/geospatial/test_submission.py | 0 tests/geospatial/test_zonal_average.py | 39 +--- tests/geospatial/utils.py | 13 ++ tests/geospatial/workloads/__init__.py | 0 .../workloads/atmospheric_circulation.py | 62 +++++++ tests/geospatial/workloads/climatology.py | 173 ++++++++++++++++++ tests/geospatial/workloads/cloud_optimize.py | 87 +++++++++ tests/geospatial/workloads/rechunking.py | 35 ++++ tests/geospatial/workloads/regridding.py | 53 ++++++ .../workloads/satellite_filtering.py | 110 +++++++++++ tests/geospatial/workloads/zonal_average.py | 44 +++++ 18 files changed, 646 insertions(+), 455 deletions(-) delete mode 100644 tests/geospatial/test_regrid.py create mode 100644 tests/geospatial/test_regridding.py create mode 100644 tests/geospatial/test_submission.py create mode 100644 tests/geospatial/utils.py create mode 100644 tests/geospatial/workloads/__init__.py create mode 100644 tests/geospatial/workloads/atmospheric_circulation.py create mode 100644 tests/geospatial/workloads/climatology.py create mode 100644 tests/geospatial/workloads/cloud_optimize.py create mode 100644 tests/geospatial/workloads/rechunking.py create mode 100644 tests/geospatial/workloads/regridding.py create mode 100644 tests/geospatial/workloads/satellite_filtering.py create mode 100644 tests/geospatial/workloads/zonal_average.py diff --git a/tests/geospatial/test_atmospheric_circulation.py b/tests/geospatial/test_atmospheric_circulation.py index d217b4659e..a5c8dd05d9 100644 --- a/tests/geospatial/test_atmospheric_circulation.py +++ b/tests/geospatial/test_atmospheric_circulation.py @@ -1,6 +1,7 @@ -import xarray as xr from coiled.credentials.google import CoiledShippedCredentials +from tests.geospatial.workloads.atmospheric_circulation import atmospheric_circulation + def test_atmospheric_circulation( gcs_url, @@ -19,54 +20,9 @@ def test_atmospheric_circulation( with client_factory( **scale_kwargs[scale], **cluster_kwargs ) as client: # noqa: F841 - ds = xr.open_zarr( - "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", - chunks={}, - ) - if scale == "small": - # 852.56 GiB (small) - time_range = slice("2020-01-01", "2020-02-01") - elif scale == "medium": - # 28.54 TiB (medium) - time_range = slice("2020-01-01", "2023-01-01") - else: - # 608.42 TiB (large) - time_range = slice(None) - ds = ds.sel(time=time_range) - - ds = ds[ - [ - "u_component_of_wind", - "v_component_of_wind", - "temperature", - "vertical_velocity", - ] - ].rename( - { - "u_component_of_wind": "U", - "v_component_of_wind": "V", - "temperature": "T", - "vertical_velocity": "W", - } + result = atmospheric_circulation( + scale=scale, + storage_url=gcs_url, + storage_options={"token": CoiledShippedCredentials()}, ) - - zonal_means = ds.mean("longitude") - anomaly = ds - zonal_means - - anomaly["uv"] = anomaly.U * anomaly.V - anomaly["vt"] = anomaly.V * anomaly.T - anomaly["uw"] = anomaly.U * anomaly.W - - temdiags = zonal_means.merge(anomaly[["uv", "vt", "uw"]].mean("longitude")) - - # This is incredibly slow, takes a while for flox to construct the graph - daily = temdiags.resample(time="D").mean() - - # # Users often rework things via a rechunk to make this a blockwise problem - # daily = ( - # temdiags.chunk(time=24) - # .resample(time="D") - # .mean() - # ) - - daily.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()}) + result.compute() diff --git a/tests/geospatial/test_climatology.py b/tests/geospatial/test_climatology.py index 792dbe05e7..c3144e0389 100644 --- a/tests/geospatial/test_climatology.py +++ b/tests/geospatial/test_climatology.py @@ -13,6 +13,8 @@ import xarray as xr from coiled.credentials.google import CoiledShippedCredentials +from tests.geospatial.workloads.climatology import highlevel_api, rechunk_map_blocks + def compute_hourly_climatology( ds: xr.Dataset, @@ -90,49 +92,12 @@ def test_rechunk_map_blocks( with client_factory( **scale_kwargs[scale], **cluster_kwargs ) as client: # noqa: F841 - # Load dataset - ds = xr.open_zarr( - "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr", - ) - - if scale == "small": - # 101.83 GiB (small) - time_range = slice("2020-01-01", "2022-12-31") - variables = ["sea_surface_temperature"] - elif scale == "medium": - # 2.12 TiB (medium) - time_range = slice("1959-01-01", "2022-12-31") - variables = ["sea_surface_temperature"] - else: - # 4.24 TiB (large) - # This currently doesn't complete successfully. - time_range = slice("1959-01-01", "2022-12-31") - variables = ["sea_surface_temperature", "snow_depth"] - ds = ds[variables].sel(time=time_range) - original_chunks = ds.chunks - - ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims]) - pencil_chunks = {"time": -1, "longitude": "auto", "latitude": "auto"} - - working = ds.chunk(pencil_chunks) - hours = xr.DataArray(range(0, 24, 6), dims=["hour"]) - daysofyear = xr.DataArray(range(1, 367), dims=["dayofyear"]) - template = ( - working.isel(time=0) - .drop_vars("time") - .expand_dims(hour=hours, dayofyear=daysofyear) - .assign_coords(hour=hours, dayofyear=daysofyear) + result = rechunk_map_blocks( + scale=scale, + storage_url=gcs_url, + storage_options={"token": CoiledShippedCredentials()}, ) - working = working.map_blocks(compute_hourly_climatology, template=template) - - pancake_chunks = { - "hour": 1, - "dayofyear": 1, - "latitude": original_chunks["latitude"], - "longitude": original_chunks["longitude"], - } - result = working.chunk(pancake_chunks) - result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()}) + result.compute() def test_highlevel_api( @@ -153,55 +118,9 @@ def test_highlevel_api( with client_factory( **scale_kwargs[scale], **cluster_kwargs ) as client: # noqa: F841 - # Load dataset - ds = xr.open_zarr( - "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr", + result = highlevel_api( + scale=scale, + storage_url=gcs_url, + storage_options={"token": CoiledShippedCredentials()}, ) - - if scale == "small": - # 101.83 GiB (small) - time_range = slice("2020-01-01", "2022-12-31") - variables = ["sea_surface_temperature"] - elif scale == "medium": - # 2.12 TiB (medium) - time_range = slice("1959-01-01", "2022-12-31") - variables = ["sea_surface_temperature"] - else: - # 4.24 TiB (large) - # This currently doesn't complete successfully. - time_range = slice("1959-01-01", "2022-12-31") - variables = ["sea_surface_temperature", "snow_depth"] - ds = ds[variables].sel(time=time_range) - original_chunks = ds.chunks - - # Drop all static variables - ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims]) - - # Split time dimension into three dimensions - ds["dayofyear"] = ds.time.dt.dayofyear - ds["hour"] = ds.time.dt.hour - ds["year"] = ds.time.dt.year - ds = ds.set_index(time=["year", "dayofyear", "hour"]).unstack() - - # Fill empty values for non-leap years - ds = ds.ffill(dim="dayofyear", limit=1) - - # Calculate climatology - window_size = 61 - window_weights = create_window_weights(window_size) - half_window_size = window_size // 2 - ds = ds.pad(pad_width={"dayofyear": half_window_size}, mode="wrap") - # FIXME: https://github.com/pydata/xarray/issues/9550 - ds = ds.chunk(latitude=128, longitude=128) - ds = ds.rolling(dayofyear=window_size, center=True).construct("window") - ds = ds.weighted(window_weights).mean(dim=("window", "year")) - ds = ds.isel(dayofyear=slice(half_window_size, -half_window_size)) - - pancake_chunks = { - "hour": 1, - "dayofyear": 1, - "latitude": original_chunks["latitude"], - "longitude": original_chunks["longitude"], - } - result = ds.chunk(pancake_chunks) - result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()}) + result.compute() diff --git a/tests/geospatial/test_cloud_optimize.py b/tests/geospatial/test_cloud_optimize.py index f3a15f5e74..0cb7bbb9bd 100644 --- a/tests/geospatial/test_cloud_optimize.py +++ b/tests/geospatial/test_cloud_optimize.py @@ -1,4 +1,4 @@ -import xarray as xr +from tests.geospatial.workloads.cloud_optimize import cloud_optimize def test_cloud_optimize( @@ -19,82 +19,5 @@ def test_cloud_optimize( with client_factory( **scale_kwargs[scale], **cluster_kwargs ) as client: # noqa: F841 - # Define models and variables of interest - models = [ - "ACCESS-CM2", - "ACCESS-ESM1-5", - "CMCC-ESM2", - "CNRM-CM6-1", - "CNRM-ESM2-1", - "CanESM5", - "EC-Earth3", - "EC-Earth3-Veg-LR", - "FGOALS-g3", - "GFDL-ESM4", - "GISS-E2-1-G", - "INM-CM4-8", - "INM-CM5-0", - "KACE-1-0-G", - "MIROC-ES2L", - "MPI-ESM1-2-HR", - "MPI-ESM1-2-LR", - "MRI-ESM2-0", - "NorESM2-LM", - "NorESM2-MM", - "TaiESM1", - "UKESM1-0-LL", - ] - variables = [ - "hurs", - "huss", - "pr", - "rlds", - "rsds", - "sfcWind", - "tas", - "tasmax", - "tasmin", - ] - - if scale == "small": - # 130 files (152.83 GiB). One model and one variable. - models = models[:1] - variables = variables[:1] - elif scale == "medium": - # 390 files. Two models and two variables. - # Currently fails after hitting 20 minute idle timeout - # sending large graph to the scheduler. - models = models[:2] - variables = variables[:2] - else: - # 11635 files. All models and variables. - pass - - # Get netCDF data files -- see https://registry.opendata.aws/nex-gddp-cmip6 - # for dataset details. - file_list = [] - for model in models: - for variable in variables: - data_dir = f"s3://nex-gddp-cmip6/NEX-GDDP-CMIP6/{model}/historical/r1i1p1f1/{variable}/*.nc" - file_list += [f"s3://{path}" for path in s3.glob(data_dir)] - files = [s3.open(f) for f in file_list] - print(f"Processing {len(files)} NetCDF files") - - # Load input NetCDF data files - # TODO: Reduce explicit settings once https://github.com/pydata/xarray/issues/8778 is completed. - ds = xr.open_mfdataset( - files, - engine="h5netcdf", - combine="nested", - concat_dim="time", - data_vars="minimal", - coords="minimal", - compat="override", - parallel=True, - ) - - # Rechunk from "pancake" to "pencil" format - ds = ds.chunk({"time": -1, "lon": "auto", "lat": "auto"}) - - # Write out to a Zar dataset - ds.to_zarr(s3_url) + result = cloud_optimize(scale, s3fs=s3, storage_url=s3_url) + result.compute() diff --git a/tests/geospatial/test_rechunking.py b/tests/geospatial/test_rechunking.py index 5044ff6c05..c6804672f4 100644 --- a/tests/geospatial/test_rechunking.py +++ b/tests/geospatial/test_rechunking.py @@ -1,6 +1,7 @@ -import xarray as xr from coiled.credentials.google import CoiledShippedCredentials +from tests.geospatial.workloads.rechunking import era5_rechunking + def test_era5_rechunking( gcs_url, @@ -19,28 +20,9 @@ def test_era5_rechunking( with client_factory( **scale_kwargs[scale], **cluster_kwargs ) as client: # noqa: F841 - # Load dataset - ds = xr.open_zarr( - "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", - ).drop_encoding() - - if scale == "small": - # 101.83 GiB (small) - time_range = slice("2020-01-01", "2023-01-01") - variables = ["sea_surface_temperature"] - elif scale == "medium": - # 2.12 TiB (medium) - time_range = slice(None) - variables = ["sea_surface_temperature"] - else: - # 4.24 TiB (large) - # This currently doesn't complete successfully. - time_range = slice(None) - variables = ["sea_surface_temperature", "snow_depth"] - subset = ds[variables].sel(time=time_range) - - # Rechunk - result = subset.chunk({"time": -1, "longitude": "auto", "latitude": "auto"}) - - # Write result to cloud storage - result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()}) + result = era5_rechunking( + scale=scale, + storage_url=gcs_url, + storage_options={"token": CoiledShippedCredentials()}, + ) + result.compute() diff --git a/tests/geospatial/test_regrid.py b/tests/geospatial/test_regrid.py deleted file mode 100644 index dcb79cf14b..0000000000 --- a/tests/geospatial/test_regrid.py +++ /dev/null @@ -1,67 +0,0 @@ -import numpy as np -import pytest -import xarray as xr -import xesmf as xe -from coiled.credentials.google import CoiledShippedCredentials - - -@pytest.mark.parametrize("output_resolution", [1.5, 0.1]) -def test_xesmf( - gcs_url, - scale, - client_factory, - output_resolution, - cluster_kwargs={ - "workspace": "dask-benchmarks-gcp", - "region": "us-central1", - "wait_for_workers": True, - }, - scale_kwargs={ - "small": {"n_workers": 10}, - "medium": {"n_workers": 100}, - "large": {"n_workers": 100}, - }, -): - with client_factory( - **scale_kwargs[scale], **cluster_kwargs - ) as client: # noqa: F841 - # Load dataset - ds = xr.open_zarr( - "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", - ) - - if scale == "small": - # 101.83 GiB (small) - time_range = slice("2020-01-01", "2022-12-31") - variables = ["sea_surface_temperature"] - elif scale == "medium": - # 2.12 TiB (medium) - time_range = slice("1959-01-01", "2022-12-31") - variables = ["sea_surface_temperature"] - else: - # 4.24 TiB (large) - # This currently doesn't complete successfully. - time_range = slice("1959-01-01", "2022-12-31") - variables = ["sea_surface_temperature", "snow_depth"] - ds = ds[variables].sel(time=time_range) - - # 240x121 - out_grid = xr.Dataset( - { - "latitude": ( - ["latitude"], - np.arange(90, -90 - output_resolution, -output_resolution), - {"units": "degrees_north"}, - ), - "longitude": ( - ["longitude"], - np.arange(0, 360, output_resolution), - {"units": "degrees_east"}, - ), - } - ) - regridder = xe.Regridder(ds, out_grid, "bilinear", periodic=True) - regridded = regridder(ds, keep_attrs=True) - - result = regridded.chunk(time="auto") - result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()}) diff --git a/tests/geospatial/test_regridding.py b/tests/geospatial/test_regridding.py new file mode 100644 index 0000000000..18e7dff068 --- /dev/null +++ b/tests/geospatial/test_regridding.py @@ -0,0 +1,33 @@ +import pytest +from coiled.credentials.google import CoiledShippedCredentials + +from tests.geospatial.workloads.regridding import xesmf + + +@pytest.mark.parametrize("output_resolution", [1.5, 0.1]) +def test_xesmf( + gcs_url, + scale, + client_factory, + output_resolution, + cluster_kwargs={ + "workspace": "dask-benchmarks-gcp", + "region": "us-central1", + "wait_for_workers": True, + }, + scale_kwargs={ + "small": {"n_workers": 10}, + "medium": {"n_workers": 100}, + "large": {"n_workers": 100}, + }, +): + with client_factory( + **scale_kwargs[scale], **cluster_kwargs + ) as client: # noqa: F841 + result = xesmf( + scale=scale, + output_resolution=output_resolution, + storage_url=gcs_url, + storage_options={"token": CoiledShippedCredentials()}, + ) + result.compute() diff --git a/tests/geospatial/test_satellite_filtering.py b/tests/geospatial/test_satellite_filtering.py index b6049e95ee..4b65b40c9c 100644 --- a/tests/geospatial/test_satellite_filtering.py +++ b/tests/geospatial/test_satellite_filtering.py @@ -1,57 +1,6 @@ -import datetime import os -import fsspec -import geojson -import odc.stac -import planetary_computer -import pystac_client -import xarray as xr - - -def harmonize_to_old(data: xr.Dataset) -> xr.Dataset: - """ - Harmonize new Sentinel-2 data to the old baseline. - - Parameters - ---------- - data: - A Dataset with various bands as data variables and three dimensions: time, y, x - - Returns - ------- - harmonized: xarray.Dataset - A Dataset with all values harmonized to the old - processing baseline. - """ - cutoff = datetime.datetime(2022, 1, 25) - offset = 1000 - bands = [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12", - ] - - to_process = list(set(bands) & set(list(data.data_vars))) - old = data.sel(time=slice(cutoff))[to_process] - - new = data.sel(time=slice(cutoff, None)).drop_vars(to_process) - - new_harmonized = data.sel(time=slice(cutoff, None))[to_process].clip(offset) - new_harmonized -= offset - - new = xr.merge([new, new_harmonized]) - return xr.concat([old, new], dim="time") +from tests.geospatial.workloads.satellite_filtering import satellite_filtering def test_satellite_filtering( @@ -75,53 +24,5 @@ def test_satellite_filtering( }, **cluster_kwargs, ) as client: # noqa: F841 - catalog = pystac_client.Client.open( - "https://planetarycomputer.microsoft.com/api/stac/v1", - modifier=planetary_computer.sign_inplace, - ) - - # GeoJSON for region of interest is from https://github.com/isellsoap/deutschlandGeoJSON/tree/main/1_deutschland - with fsspec.open( - "https://raw.githubusercontent.com/isellsoap/deutschlandGeoJSON/main/1_deutschland/3_mittel.geo.json" - ) as f: - gj = geojson.load(f) - - # Flatten MultiPolygon to single Polygon - coordinates = [] - for x in gj.features[0]["geometry"]["coordinates"]: - coordinates.extend(x) - area_of_interest = { - "type": "Polygon", - "coordinates": coordinates, - } - - # Get stack items - if scale == "small": - time_of_interest = "2024-01-01/2024-09-01" - else: - time_of_interest = "2015-01-01/2024-09-01" - - search = catalog.search( - collections=["sentinel-2-l2a"], - intersects=area_of_interest, - datetime=time_of_interest, - ) - items = search.item_collection() - - # Construct Xarray Dataset from stack items - ds = odc.stac.load( - items, - chunks={}, - patch_url=planetary_computer.sign, - resolution=40, - crs="EPSG:3857", - groupby="solar_day", - ) - # See https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a#Baseline-Change - ds = harmonize_to_old(ds) - - # Compute humidity index - humidity = (ds.B08 - ds.B11) / (ds.B08 + ds.B11) - - result = humidity.groupby("time.month").mean() - result.to_zarr(az_url) + result = satellite_filtering(scale=scale, storage_url=az_url) + result.compute() diff --git a/tests/geospatial/test_submission.py b/tests/geospatial/test_submission.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/geospatial/test_zonal_average.py b/tests/geospatial/test_zonal_average.py index b74c93b574..ffd23c44c4 100644 --- a/tests/geospatial/test_zonal_average.py +++ b/tests/geospatial/test_zonal_average.py @@ -2,10 +2,7 @@ This example was adapted from https://github.com/dcherian/dask-demo/blob/main/nwm-aws.ipynb """ -import flox.xarray -import numpy as np -import rioxarray -import xarray as xr +from tests.geospatial.workloads.zonal_average import nwm def test_nwm( @@ -24,35 +21,5 @@ def test_nwm( with client_factory( **scale_kwargs[scale], **cluster_kwargs ) as client: # noqa: F841 - ds = xr.open_zarr( - "s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", consolidated=True - ) - - if scale == "small": - # 6.03 TiB - time_range = slice("2020-01-01", "2020-12-31") - else: - # 252.30 TiB - time_range = slice("1979-02-01", "2020-12-31") - subset = ds.zwattablrt.sel(time=time_range) - - counties = rioxarray.open_rasterio( - "s3://nwm-250m-us-counties/Counties_on_250m_grid.tif", - chunks="auto", - opener=s3.open, - ).squeeze() - - # Remove any small floating point error in coordinate locations - _, counties_aligned = xr.align(subset, counties, join="override") - counties_aligned = counties_aligned.persist() - - county_id = np.unique(counties_aligned.data).compute() - county_id = county_id[county_id != 0] - county_mean = flox.xarray.xarray_reduce( - subset, - counties_aligned.rename("county"), - func="mean", - expected_groups=(county_id,), - ) - - county_mean.compute() + result = nwm(scale=scale, s3fs=s3) + result.compute() diff --git a/tests/geospatial/utils.py b/tests/geospatial/utils.py new file mode 100644 index 0000000000..a9bd9fe8aa --- /dev/null +++ b/tests/geospatial/utils.py @@ -0,0 +1,13 @@ +import xarray as xr + + +def load_era5() -> xr.Dataset: + return xr.open_zarr( + "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", + chunks={ + "longitude": "auto", + "latitude": "auto", + "levels": "auto", + "time": "auto", + }, + ) diff --git a/tests/geospatial/workloads/__init__.py b/tests/geospatial/workloads/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/geospatial/workloads/atmospheric_circulation.py b/tests/geospatial/workloads/atmospheric_circulation.py new file mode 100644 index 0000000000..61a308fc6b --- /dev/null +++ b/tests/geospatial/workloads/atmospheric_circulation.py @@ -0,0 +1,62 @@ +from typing import Any, Literal + +import xarray as xr +from dask.delayed import Delayed + + +def atmospheric_circulation( + scale: Literal["small", "medium", "large"], + storage_url: str, + storage_options: dict[str, Any], +) -> Delayed: + ds = xr.open_zarr( + "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", + chunks={}, + ) + if scale == "small": + # 852.56 GiB (small) + time_range = slice("2020-01-01", "2020-02-01") + elif scale == "medium": + # 28.54 TiB (medium) + time_range = slice("2020-01-01", "2023-01-01") + else: + # 608.42 TiB (large) + time_range = slice(None) + ds = ds.sel(time=time_range) + + ds = ds[ + [ + "u_component_of_wind", + "v_component_of_wind", + "temperature", + "vertical_velocity", + ] + ].rename( + { + "u_component_of_wind": "U", + "v_component_of_wind": "V", + "temperature": "T", + "vertical_velocity": "W", + } + ) + + zonal_means = ds.mean("longitude") + anomaly = ds - zonal_means + + anomaly["uv"] = anomaly.U * anomaly.V + anomaly["vt"] = anomaly.V * anomaly.T + anomaly["uw"] = anomaly.U * anomaly.W + + temdiags = zonal_means.merge(anomaly[["uv", "vt", "uw"]].mean("longitude")) + + # This is incredibly slow, takes a while for flox to construct the graph + daily = temdiags.resample(time="D").mean() + + # # Users often rework things via a rechunk to make this a blockwise problem + # daily = ( + # temdiags.chunk(time=24) + # .resample(time="D") + # .mean() + # ) + + return daily.to_zarr(storage_url, storage_options=storage_options, compute=False) diff --git a/tests/geospatial/workloads/climatology.py b/tests/geospatial/workloads/climatology.py new file mode 100644 index 0000000000..552b8f5cd6 --- /dev/null +++ b/tests/geospatial/workloads/climatology.py @@ -0,0 +1,173 @@ +from typing import Any, Literal + +import numpy as np +import xarray as xr +from dask.delayed import Delayed + + +def compute_hourly_climatology( + ds: xr.Dataset, +) -> xr.Dataset: + hours = xr.DataArray(range(0, 24, 6), dims=["hour"]) + window_weights = create_window_weights(61) + return xr.concat( + [compute_rolling_mean(select_hour(ds, hour), window_weights) for hour in hours], + dim=hours, + ) + + +def compute_rolling_mean(ds: xr.Dataset, window_weights: xr.DataArray) -> xr.Dataset: + window_size = len(window_weights) + half_window_size = window_size // 2 # For padding + ds = xr.concat( + [ + replace_time_with_doy(ds.sel(time=str(y))) + for y in np.unique(ds.time.dt.year) + ], + dim="year", + ) + ds = ds.fillna(ds.sel(dayofyear=365)) + ds = ds.pad(pad_width={"dayofyear": half_window_size}, mode="wrap") + ds = ds.rolling(dayofyear=window_size, center=True).construct("window") + ds = ds.weighted(window_weights).mean(dim=("window", "year")) + return ds.isel(dayofyear=slice(half_window_size, -half_window_size)) + + +def create_window_weights(window_size: int) -> xr.DataArray: + """Create linearly decaying window weights.""" + assert window_size % 2 == 1, "Window size must be odd." + half_window_size = window_size // 2 + window_weights = np.concatenate( + [ + np.linspace(0, 1, half_window_size + 1), + np.linspace(1, 0, half_window_size + 1)[1:], + ] + ) + window_weights = window_weights / window_weights.mean() + window_weights = xr.DataArray(window_weights, dims=["window"]) + return window_weights + + +def replace_time_with_doy(ds: xr.Dataset) -> xr.Dataset: + """Replace time coordinate with days of year.""" + return ds.assign_coords({"time": ds.time.dt.dayofyear}).rename( + {"time": "dayofyear"} + ) + + +def select_hour(ds: xr.Dataset, hour: int) -> xr.Dataset: + """Select given hour of day from dataset.""" + # Select hour + ds = ds.isel(time=ds.time.dt.hour == hour) + # Adjust time dimension + ds = ds.assign_coords({"time": ds.time.astype("datetime64[D]")}) + return ds + + +def rechunk_map_blocks( + scale: Literal["small", "medium", "large"], + storage_url: str, + storage_options: dict[str, Any], +) -> Delayed: + # Load dataset + ds = xr.open_zarr( + "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr", + ) + + if scale == "small": + # 101.83 GiB (small) + time_range = slice("2020-01-01", "2022-12-31") + variables = ["sea_surface_temperature"] + elif scale == "medium": + # 2.12 TiB (medium) + time_range = slice("1959-01-01", "2022-12-31") + variables = ["sea_surface_temperature"] + else: + # 4.24 TiB (large) + # This currently doesn't complete successfully. + time_range = slice("1959-01-01", "2022-12-31") + variables = ["sea_surface_temperature", "snow_depth"] + ds = ds[variables].sel(time=time_range) + original_chunks = ds.chunks + + ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims]) + pencil_chunks = {"time": -1, "longitude": "auto", "latitude": "auto"} + + working = ds.chunk(pencil_chunks) + hours = xr.DataArray(range(0, 24, 6), dims=["hour"]) + daysofyear = xr.DataArray(range(1, 367), dims=["dayofyear"]) + template = ( + working.isel(time=0) + .drop_vars("time") + .expand_dims(hour=hours, dayofyear=daysofyear) + .assign_coords(hour=hours, dayofyear=daysofyear) + ) + working = working.map_blocks(compute_hourly_climatology, template=template) + + pancake_chunks = { + "hour": 1, + "dayofyear": 1, + "latitude": original_chunks["latitude"], + "longitude": original_chunks["longitude"], + } + result = working.chunk(pancake_chunks) + return result.to_zarr(storage_url, storage_options=storage_options, compute=False) + + +def highlevel_api( + scale: Literal["small", "medium", "large"], + storage_url: str, + storage_options: dict[str, Any], +) -> Delayed: + # Load dataset + ds = xr.open_zarr( + "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr", + ) + + if scale == "small": + # 101.83 GiB (small) + time_range = slice("2020-01-01", "2022-12-31") + variables = ["sea_surface_temperature"] + elif scale == "medium": + # 2.12 TiB (medium) + time_range = slice("1959-01-01", "2022-12-31") + variables = ["sea_surface_temperature"] + else: + # 4.24 TiB (large) + # This currently doesn't complete successfully. + time_range = slice("1959-01-01", "2022-12-31") + variables = ["sea_surface_temperature", "snow_depth"] + ds = ds[variables].sel(time=time_range) + original_chunks = ds.chunks + + # Drop all static variables + ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims]) + + # Split time dimension into three dimensions + ds["dayofyear"] = ds.time.dt.dayofyear + ds["hour"] = ds.time.dt.hour + ds["year"] = ds.time.dt.year + ds = ds.set_index(time=["year", "dayofyear", "hour"]).unstack() + + # Fill empty values for non-leap years + ds = ds.ffill(dim="dayofyear", limit=1) + + # Calculate climatology + window_size = 61 + window_weights = create_window_weights(window_size) + half_window_size = window_size // 2 + ds = ds.pad(pad_width={"dayofyear": half_window_size}, mode="wrap") + # FIXME: https://github.com/pydata/xarray/issues/9550 + ds = ds.chunk(latitude=128, longitude=128) + ds = ds.rolling(dayofyear=window_size, center=True).construct("window") + ds = ds.weighted(window_weights).mean(dim=("window", "year")) + ds = ds.isel(dayofyear=slice(half_window_size, -half_window_size)) + + pancake_chunks = { + "hour": 1, + "dayofyear": 1, + "latitude": original_chunks["latitude"], + "longitude": original_chunks["longitude"], + } + result = ds.chunk(pancake_chunks) + return result.to_zarr(storage_url, storage_options=storage_options, compute=False) diff --git a/tests/geospatial/workloads/cloud_optimize.py b/tests/geospatial/workloads/cloud_optimize.py new file mode 100644 index 0000000000..93b08bb611 --- /dev/null +++ b/tests/geospatial/workloads/cloud_optimize.py @@ -0,0 +1,87 @@ +from typing import Literal + +import xarray as xr +from s3fs import S3FileSystem + + +def cloud_optimize( + scale: Literal["small", "medium", "large"], s3fs: S3FileSystem, storage_url: str +): + models = [ + "ACCESS-CM2", + "ACCESS-ESM1-5", + "CMCC-ESM2", + "CNRM-CM6-1", + "CNRM-ESM2-1", + "CanESM5", + "EC-Earth3", + "EC-Earth3-Veg-LR", + "FGOALS-g3", + "GFDL-ESM4", + "GISS-E2-1-G", + "INM-CM4-8", + "INM-CM5-0", + "KACE-1-0-G", + "MIROC-ES2L", + "MPI-ESM1-2-HR", + "MPI-ESM1-2-LR", + "MRI-ESM2-0", + "NorESM2-LM", + "NorESM2-MM", + "TaiESM1", + "UKESM1-0-LL", + ] + variables = [ + "hurs", + "huss", + "pr", + "rlds", + "rsds", + "sfcWind", + "tas", + "tasmax", + "tasmin", + ] + + if scale == "small": + # 130 files (152.83 GiB). One model and one variable. + models = models[:1] + variables = variables[:1] + elif scale == "medium": + # 390 files. Two models and two variables. + # Currently fails after hitting 20 minute idle timeout + # sending large graph to the scheduler. + models = models[:2] + variables = variables[:2] + else: + # 11635 files. All models and variables. + pass + + # Get netCDF data files -- see https://registry.opendata.aws/nex-gddp-cmip6 + # for dataset details. + file_list = [] + for model in models: + for variable in variables: + data_dir = f"s3://nex-gddp-cmip6/NEX-GDDP-CMIP6/{model}/historical/r1i1p1f1/{variable}/*.nc" + file_list += [f"s3://{path}" for path in s3fs.glob(data_dir)] + files = [s3fs.open(f) for f in file_list] + print(f"Processing {len(files)} NetCDF files") + + # Load input NetCDF data files + # TODO: Reduce explicit settings once https://github.com/pydata/xarray/issues/8778 is completed. + ds = xr.open_mfdataset( + files, + engine="h5netcdf", + combine="nested", + concat_dim="time", + data_vars="minimal", + coords="minimal", + compat="override", + parallel=True, + ) + + # Rechunk from "pancake" to "pencil" format + ds = ds.chunk({"time": -1, "lon": "auto", "lat": "auto"}) + + # Write out to a Zar dataset + return ds.to_zarr(storage_url, compute=False) diff --git a/tests/geospatial/workloads/rechunking.py b/tests/geospatial/workloads/rechunking.py new file mode 100644 index 0000000000..9a20994aee --- /dev/null +++ b/tests/geospatial/workloads/rechunking.py @@ -0,0 +1,35 @@ +from typing import Any, Literal + +import xarray as xr +from dask.delayed import Delayed + + +def era5_rechunking( + scale: Literal["small", "medium", "large"], + storage_url: str, + storage_options: dict[str, Any], +) -> Delayed: + ds = xr.open_zarr( + "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", + ).drop_encoding() + + if scale == "small": + # 101.83 GiB (small) + time_range = slice("2020-01-01", "2023-01-01") + variables = ["sea_surface_temperature"] + elif scale == "medium": + # 2.12 TiB (medium) + time_range = slice(None) + variables = ["sea_surface_temperature"] + else: + # 4.24 TiB (large) + # This currently doesn't complete successfully. + time_range = slice(None) + variables = ["sea_surface_temperature", "snow_depth"] + subset = ds[variables].sel(time=time_range) + + # Rechunk + result = subset.chunk({"time": -1, "longitude": "auto", "latitude": "auto"}) + + # Write result to cloud storage + return result.to_zarr(storage_url, storage_options=storage_options, compute=False) diff --git a/tests/geospatial/workloads/regridding.py b/tests/geospatial/workloads/regridding.py new file mode 100644 index 0000000000..9b63d02390 --- /dev/null +++ b/tests/geospatial/workloads/regridding.py @@ -0,0 +1,53 @@ +from typing import Any, Literal + +import numpy as np +import xarray as xr +import xesmf as xe +from dask.delayed import Delayed + + +def xesmf( + scale: Literal["small", "medium", "large"], + output_resolution: float, + storage_url: str, + storage_options: dict[str, Any], +) -> Delayed: + ds = xr.open_zarr( + "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", + ) + + if scale == "small": + # 101.83 GiB (small) + time_range = slice("2020-01-01", "2022-12-31") + variables = ["sea_surface_temperature"] + elif scale == "medium": + # 2.12 TiB (medium) + time_range = slice("1959-01-01", "2022-12-31") + variables = ["sea_surface_temperature"] + else: + # 4.24 TiB (large) + # This currently doesn't complete successfully. + time_range = slice("1959-01-01", "2022-12-31") + variables = ["sea_surface_temperature", "snow_depth"] + ds = ds[variables].sel(time=time_range) + + # 240x121 + out_grid = xr.Dataset( + { + "latitude": ( + ["latitude"], + np.arange(90, -90 - output_resolution, -output_resolution), + {"units": "degrees_north"}, + ), + "longitude": ( + ["longitude"], + np.arange(0, 360, output_resolution), + {"units": "degrees_east"}, + ), + } + ) + regridder = xe.Regridder(ds, out_grid, "bilinear", periodic=True) + regridded = regridder(ds, keep_attrs=True) + + result = regridded.chunk(time="auto") + return result.to_zarr(storage_url, storage_options=storage_options, compute=False) diff --git a/tests/geospatial/workloads/satellite_filtering.py b/tests/geospatial/workloads/satellite_filtering.py new file mode 100644 index 0000000000..6e6dc6dfa7 --- /dev/null +++ b/tests/geospatial/workloads/satellite_filtering.py @@ -0,0 +1,110 @@ +import datetime +from typing import Literal + +import fsspec +import geojson +import odc.stac +import planetary_computer +import pystac_client +import xarray as xr + + +def harmonize_to_old(data: xr.Dataset) -> xr.Dataset: + """ + Harmonize new Sentinel-2 data to the old baseline. + + Parameters + ---------- + data: + A Dataset with various bands as data variables and three dimensions: time, y, x + + Returns + ------- + harmonized: xarray.Dataset + A Dataset with all values harmonized to the old + processing baseline. + """ + cutoff = datetime.datetime(2022, 1, 25) + offset = 1000 + bands = [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ] + + to_process = list(set(bands) & set(list(data.data_vars))) + old = data.sel(time=slice(cutoff))[to_process] + + new = data.sel(time=slice(cutoff, None)).drop_vars(to_process) + + new_harmonized = data.sel(time=slice(cutoff, None))[to_process].clip(offset) + new_harmonized -= offset + + new = xr.merge([new, new_harmonized]) + return xr.concat([old, new], dim="time") + + +def satellite_filtering( + scale: Literal["small", "medium", "large"], + storage_url: str, +): + catalog = pystac_client.Client.open( + "https://planetarycomputer.microsoft.com/api/stac/v1", + modifier=planetary_computer.sign_inplace, + ) + + # GeoJSON for region of interest is from https://github.com/isellsoap/deutschlandGeoJSON/tree/main/1_deutschland + with fsspec.open( + "https://raw.githubusercontent.com/isellsoap/deutschlandGeoJSON/main/1_deutschland/3_mittel.geo.json" + ) as f: + gj = geojson.load(f) + + # Flatten MultiPolygon to single Polygon + coordinates = [] + for x in gj.features[0]["geometry"]["coordinates"]: + coordinates.extend(x) + area_of_interest = { + "type": "Polygon", + "coordinates": coordinates, + } + + # Get stack items + if scale == "small": + time_of_interest = "2024-01-01/2024-09-01" + else: + time_of_interest = "2015-01-01/2024-09-01" + + search = catalog.search( + collections=["sentinel-2-l2a"], + intersects=area_of_interest, + datetime=time_of_interest, + ) + items = search.item_collection() + + # Construct Xarray Dataset from stack items + ds = odc.stac.load( + items, + chunks={}, + patch_url=planetary_computer.sign, + resolution=40, + crs="EPSG:3857", + groupby="solar_day", + ) + # See https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a#Baseline-Change + ds = harmonize_to_old(ds) + + # Compute humidity index + humidity = (ds.B08 - ds.B11) / (ds.B08 + ds.B11) + + result = humidity.groupby("time.month").mean() + return result.to_zarr(storage_url, compute=False) diff --git a/tests/geospatial/workloads/zonal_average.py b/tests/geospatial/workloads/zonal_average.py new file mode 100644 index 0000000000..390890d7db --- /dev/null +++ b/tests/geospatial/workloads/zonal_average.py @@ -0,0 +1,44 @@ +from typing import Literal + +import flox +import numpy as np +import rioxarray +import xarray as xr +from s3fs import S3FileSystem + + +def nwm( + scale: Literal["small", "medium", "large"], + s3fs: S3FileSystem, +) -> xr.DataArray: + ds = xr.open_zarr( + "s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", consolidated=True + ) + + if scale == "small": + # 6.03 TiB + time_range = slice("2020-01-01", "2020-12-31") + else: + # 252.30 TiB + time_range = slice("1979-02-01", "2020-12-31") + subset = ds.zwattablrt.sel(time=time_range) + + counties = rioxarray.open_rasterio( + "s3://nwm-250m-us-counties/Counties_on_250m_grid.tif", + chunks="auto", + opener=s3fs.open, + ).squeeze() + + # Remove any small floating point error in coordinate locations + _, counties_aligned = xr.align(subset, counties, join="override") + counties_aligned = counties_aligned.persist() + + county_id = np.unique(counties_aligned.data).compute() + county_id = county_id[county_id != 0] + county_mean = flox.xarray.xarray_reduce( + subset, + counties_aligned.rename("county"), + func="mean", + expected_groups=(county_id,), + ) + return county_mean From 20a27703cfb001eb2fb7a9ed22ace122245cf6bd Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 5 Nov 2024 10:35:16 +0100 Subject: [PATCH 6/7] Remove dead code (#1582) --- tests/geospatial/test_climatology.py | 61 ---------------------------- 1 file changed, 61 deletions(-) diff --git a/tests/geospatial/test_climatology.py b/tests/geospatial/test_climatology.py index c3144e0389..201824b6e5 100644 --- a/tests/geospatial/test_climatology.py +++ b/tests/geospatial/test_climatology.py @@ -9,72 +9,11 @@ METHOD = "explicit" """ -import numpy as np -import xarray as xr from coiled.credentials.google import CoiledShippedCredentials from tests.geospatial.workloads.climatology import highlevel_api, rechunk_map_blocks -def compute_hourly_climatology( - ds: xr.Dataset, -) -> xr.Dataset: - hours = xr.DataArray(range(0, 24, 6), dims=["hour"]) - window_weights = create_window_weights(61) - return xr.concat( - [compute_rolling_mean(select_hour(ds, hour), window_weights) for hour in hours], - dim=hours, - ) - - -def compute_rolling_mean(ds: xr.Dataset, window_weights: xr.DataArray) -> xr.Dataset: - window_size = len(window_weights) - half_window_size = window_size // 2 # For padding - ds = xr.concat( - [ - replace_time_with_doy(ds.sel(time=str(y))) - for y in np.unique(ds.time.dt.year) - ], - dim="year", - ) - ds = ds.fillna(ds.sel(dayofyear=365)) - ds = ds.pad(pad_width={"dayofyear": half_window_size}, mode="wrap") - ds = ds.rolling(dayofyear=window_size, center=True).construct("window") - ds = ds.weighted(window_weights).mean(dim=("window", "year")) - return ds.isel(dayofyear=slice(half_window_size, -half_window_size)) - - -def create_window_weights(window_size: int) -> xr.DataArray: - """Create linearly decaying window weights.""" - assert window_size % 2 == 1, "Window size must be odd." - half_window_size = window_size // 2 - window_weights = np.concatenate( - [ - np.linspace(0, 1, half_window_size + 1), - np.linspace(1, 0, half_window_size + 1)[1:], - ] - ) - window_weights = window_weights / window_weights.mean() - window_weights = xr.DataArray(window_weights, dims=["window"]) - return window_weights - - -def replace_time_with_doy(ds: xr.Dataset) -> xr.Dataset: - """Replace time coordinate with days of year.""" - return ds.assign_coords({"time": ds.time.dt.dayofyear}).rename( - {"time": "dayofyear"} - ) - - -def select_hour(ds: xr.Dataset, hour: int) -> xr.Dataset: - """Select given hour of day from dataset.""" - # Select hour - ds = ds.isel(time=ds.time.dt.hour == hour) - # Adjust time dimension - ds = ds.assign_coords({"time": ds.time.astype("datetime64[D]")}) - return ds - - def test_rechunk_map_blocks( gcs_url, scale, From d1d776b326bbc7ed3a34dde75292c9649ac7b64f Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 5 Nov 2024 15:18:47 +0100 Subject: [PATCH 7/7] Recalibrate scales of regridding benchmark (#1581) --- tests/geospatial/test_regridding.py | 8 ++------ tests/geospatial/workloads/regridding.py | 26 +++++++++++------------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/tests/geospatial/test_regridding.py b/tests/geospatial/test_regridding.py index 18e7dff068..ff8df979fc 100644 --- a/tests/geospatial/test_regridding.py +++ b/tests/geospatial/test_regridding.py @@ -1,15 +1,12 @@ -import pytest from coiled.credentials.google import CoiledShippedCredentials from tests.geospatial.workloads.regridding import xesmf -@pytest.mark.parametrize("output_resolution", [1.5, 0.1]) def test_xesmf( gcs_url, scale, client_factory, - output_resolution, cluster_kwargs={ "workspace": "dask-benchmarks-gcp", "region": "us-central1", @@ -17,8 +14,8 @@ def test_xesmf( }, scale_kwargs={ "small": {"n_workers": 10}, - "medium": {"n_workers": 100}, - "large": {"n_workers": 100}, + "medium": {"n_workers": 10}, + "large": {"n_workers": 10}, }, ): with client_factory( @@ -26,7 +23,6 @@ def test_xesmf( ) as client: # noqa: F841 result = xesmf( scale=scale, - output_resolution=output_resolution, storage_url=gcs_url, storage_options={"token": CoiledShippedCredentials()}, ) diff --git a/tests/geospatial/workloads/regridding.py b/tests/geospatial/workloads/regridding.py index 9b63d02390..307c37f53d 100644 --- a/tests/geospatial/workloads/regridding.py +++ b/tests/geospatial/workloads/regridding.py @@ -8,30 +8,28 @@ def xesmf( scale: Literal["small", "medium", "large"], - output_resolution: float, storage_url: str, storage_options: dict[str, Any], ) -> Delayed: ds = xr.open_zarr( "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr", ) - + # Fixed time range and variable as the interesting part of this benchmark scales with the + # regridding matrix + ds = ds[["sea_surface_temperature"]].sel(time=slice("2020-01-01", "2021-12-31")) if scale == "small": - # 101.83 GiB (small) - time_range = slice("2020-01-01", "2022-12-31") - variables = ["sea_surface_temperature"] + # Regridding from a resolution of 0.25 degress to 1 degrees + # results in 4 MiB weight matrix + output_resolution = 1 elif scale == "medium": - # 2.12 TiB (medium) - time_range = slice("1959-01-01", "2022-12-31") - variables = ["sea_surface_temperature"] + # Regridding from a resolution of 0.25 degrees to 0.2 degrees + # results in 100 MiB weight matrix + output_resolution = 0.2 else: - # 4.24 TiB (large) - # This currently doesn't complete successfully. - time_range = slice("1959-01-01", "2022-12-31") - variables = ["sea_surface_temperature", "snow_depth"] - ds = ds[variables].sel(time=time_range) + # Regridding from a resolution of 0.25 degrees to 0.05 degrees + # results in 1.55 GiB weight matrix + output_resolution = 0.05 - # 240x121 out_grid = xr.Dataset( { "latitude": (