Skip to content

Commit

Permalink
Merge branch 'main' into pyspy-profiles
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait committed Nov 5, 2024
2 parents e32b5cd + d1d776b commit 2073599
Show file tree
Hide file tree
Showing 20 changed files with 642 additions and 516 deletions.
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies:
- bokeh ==3.5.1
- gilknocker ==0.4.1
- openssl >1.1.0g
- rasterio >=1.4.0
- rioxarray ==0.17.0
- h5netcdf ==1.3.0
- xesmf ==0.8.7
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ def memray_profile(
tmp_path,
):
if not test_run_benchmark:
yield
yield contextlib.nullcontext
else:
memray_option = pytestconfig.getoption("--memray")

Expand Down Expand Up @@ -1006,7 +1006,7 @@ def performance_report(
tmp_path,
):
if not test_run_benchmark:
yield
yield contextlib.nullcontext
else:
if not pytestconfig.getoption("--performance-report"):
yield contextlib.nullcontext
Expand Down
58 changes: 7 additions & 51 deletions tests/geospatial/test_atmospheric_circulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import xarray as xr
from coiled.credentials.google import CoiledShippedCredentials

from tests.geospatial.workloads.atmospheric_circulation import atmospheric_circulation


def test_atmospheric_circulation(
gcs_url,
Expand All @@ -19,54 +20,9 @@ def test_atmospheric_circulation(
with client_factory(
**scale_kwargs[scale], **cluster_kwargs
) as client: # noqa: F841
ds = xr.open_zarr(
"gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr",
chunks={},
)
if scale == "small":
# 852.56 GiB (small)
time_range = slice("2020-01-01", "2020-02-01")
elif scale == "medium":
# 28.54 TiB (medium)
time_range = slice("2020-01-01", "2023-01-01")
else:
# 608.42 TiB (large)
time_range = slice(None)
ds = ds.sel(time=time_range)

ds = ds[
[
"u_component_of_wind",
"v_component_of_wind",
"temperature",
"vertical_velocity",
]
].rename(
{
"u_component_of_wind": "U",
"v_component_of_wind": "V",
"temperature": "T",
"vertical_velocity": "W",
}
result = atmospheric_circulation(
scale=scale,
storage_url=gcs_url,
storage_options={"token": CoiledShippedCredentials()},
)

zonal_means = ds.mean("longitude")
anomaly = ds - zonal_means

anomaly["uv"] = anomaly.U * anomaly.V
anomaly["vt"] = anomaly.V * anomaly.T
anomaly["uw"] = anomaly.U * anomaly.W

temdiags = zonal_means.merge(anomaly[["uv", "vt", "uw"]].mean("longitude"))

# This is incredibly slow, takes a while for flox to construct the graph
daily = temdiags.resample(time="D").mean()

# # Users often rework things via a rechunk to make this a blockwise problem
# daily = (
# temdiags.chunk(time=24)
# .resample(time="D")
# .mean()
# )

daily.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})
result.compute()
164 changes: 11 additions & 153 deletions tests/geospatial/test_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,68 +9,9 @@
METHOD = "explicit"
"""

import numpy as np
import xarray as xr
from coiled.credentials.google import CoiledShippedCredentials


def compute_hourly_climatology(
ds: xr.Dataset,
) -> xr.Dataset:
hours = xr.DataArray(range(0, 24, 6), dims=["hour"])
window_weights = create_window_weights(61)
return xr.concat(
[compute_rolling_mean(select_hour(ds, hour), window_weights) for hour in hours],
dim=hours,
)


def compute_rolling_mean(ds: xr.Dataset, window_weights: xr.DataArray) -> xr.Dataset:
window_size = len(window_weights)
half_window_size = window_size // 2 # For padding
ds = xr.concat(
[
replace_time_with_doy(ds.sel(time=str(y)))
for y in np.unique(ds.time.dt.year)
],
dim="year",
)
ds = ds.fillna(ds.sel(dayofyear=365))
ds = ds.pad(pad_width={"dayofyear": half_window_size}, mode="wrap")
ds = ds.rolling(dayofyear=window_size, center=True).construct("window")
ds = ds.weighted(window_weights).mean(dim=("window", "year"))
return ds.isel(dayofyear=slice(half_window_size, -half_window_size))


def create_window_weights(window_size: int) -> xr.DataArray:
"""Create linearly decaying window weights."""
assert window_size % 2 == 1, "Window size must be odd."
half_window_size = window_size // 2
window_weights = np.concatenate(
[
np.linspace(0, 1, half_window_size + 1),
np.linspace(1, 0, half_window_size + 1)[1:],
]
)
window_weights = window_weights / window_weights.mean()
window_weights = xr.DataArray(window_weights, dims=["window"])
return window_weights


def replace_time_with_doy(ds: xr.Dataset) -> xr.Dataset:
"""Replace time coordinate with days of year."""
return ds.assign_coords({"time": ds.time.dt.dayofyear}).rename(
{"time": "dayofyear"}
)


def select_hour(ds: xr.Dataset, hour: int) -> xr.Dataset:
"""Select given hour of day from dataset."""
# Select hour
ds = ds.isel(time=ds.time.dt.hour == hour)
# Adjust time dimension
ds = ds.assign_coords({"time": ds.time.astype("datetime64[D]")})
return ds
from tests.geospatial.workloads.climatology import highlevel_api, rechunk_map_blocks


def test_rechunk_map_blocks(
Expand All @@ -90,49 +31,12 @@ def test_rechunk_map_blocks(
with client_factory(
**scale_kwargs[scale], **cluster_kwargs
) as client: # noqa: F841
# Load dataset
ds = xr.open_zarr(
"gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr",
)

if scale == "small":
# 101.83 GiB (small)
time_range = slice("2020-01-01", "2022-12-31")
variables = ["sea_surface_temperature"]
elif scale == "medium":
# 2.12 TiB (medium)
time_range = slice("1959-01-01", "2022-12-31")
variables = ["sea_surface_temperature"]
else:
# 4.24 TiB (large)
# This currently doesn't complete successfully.
time_range = slice("1959-01-01", "2022-12-31")
variables = ["sea_surface_temperature", "snow_depth"]
ds = ds[variables].sel(time=time_range)
original_chunks = ds.chunks

ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims])
pencil_chunks = {"time": -1, "longitude": "auto", "latitude": "auto"}

working = ds.chunk(pencil_chunks)
hours = xr.DataArray(range(0, 24, 6), dims=["hour"])
daysofyear = xr.DataArray(range(1, 367), dims=["dayofyear"])
template = (
working.isel(time=0)
.drop_vars("time")
.expand_dims(hour=hours, dayofyear=daysofyear)
.assign_coords(hour=hours, dayofyear=daysofyear)
result = rechunk_map_blocks(
scale=scale,
storage_url=gcs_url,
storage_options={"token": CoiledShippedCredentials()},
)
working = working.map_blocks(compute_hourly_climatology, template=template)

pancake_chunks = {
"hour": 1,
"dayofyear": 1,
"latitude": original_chunks["latitude"],
"longitude": original_chunks["longitude"],
}
result = working.chunk(pancake_chunks)
result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})
result.compute()


def test_highlevel_api(
Expand All @@ -153,55 +57,9 @@ def test_highlevel_api(
with client_factory(
**scale_kwargs[scale], **cluster_kwargs
) as client: # noqa: F841
# Load dataset
ds = xr.open_zarr(
"gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr",
result = highlevel_api(
scale=scale,
storage_url=gcs_url,
storage_options={"token": CoiledShippedCredentials()},
)

if scale == "small":
# 101.83 GiB (small)
time_range = slice("2020-01-01", "2022-12-31")
variables = ["sea_surface_temperature"]
elif scale == "medium":
# 2.12 TiB (medium)
time_range = slice("1959-01-01", "2022-12-31")
variables = ["sea_surface_temperature"]
else:
# 4.24 TiB (large)
# This currently doesn't complete successfully.
time_range = slice("1959-01-01", "2022-12-31")
variables = ["sea_surface_temperature", "snow_depth"]
ds = ds[variables].sel(time=time_range)
original_chunks = ds.chunks

# Drop all static variables
ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims])

# Split time dimension into three dimensions
ds["dayofyear"] = ds.time.dt.dayofyear
ds["hour"] = ds.time.dt.hour
ds["year"] = ds.time.dt.year
ds = ds.set_index(time=["year", "dayofyear", "hour"]).unstack()

# Fill empty values for non-leap years
ds = ds.ffill(dim="dayofyear", limit=1)

# Calculate climatology
window_size = 61
window_weights = create_window_weights(window_size)
half_window_size = window_size // 2
ds = ds.pad(pad_width={"dayofyear": half_window_size}, mode="wrap")
# FIXME: https://github.com/pydata/xarray/issues/9550
ds = ds.chunk(latitude=128, longitude=128)
ds = ds.rolling(dayofyear=window_size, center=True).construct("window")
ds = ds.weighted(window_weights).mean(dim=("window", "year"))
ds = ds.isel(dayofyear=slice(half_window_size, -half_window_size))

pancake_chunks = {
"hour": 1,
"dayofyear": 1,
"latitude": original_chunks["latitude"],
"longitude": original_chunks["longitude"],
}
result = ds.chunk(pancake_chunks)
result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})
result.compute()
83 changes: 3 additions & 80 deletions tests/geospatial/test_cloud_optimize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import xarray as xr
from tests.geospatial.workloads.cloud_optimize import cloud_optimize


def test_cloud_optimize(
Expand All @@ -19,82 +19,5 @@ def test_cloud_optimize(
with client_factory(
**scale_kwargs[scale], **cluster_kwargs
) as client: # noqa: F841
# Define models and variables of interest
models = [
"ACCESS-CM2",
"ACCESS-ESM1-5",
"CMCC-ESM2",
"CNRM-CM6-1",
"CNRM-ESM2-1",
"CanESM5",
"EC-Earth3",
"EC-Earth3-Veg-LR",
"FGOALS-g3",
"GFDL-ESM4",
"GISS-E2-1-G",
"INM-CM4-8",
"INM-CM5-0",
"KACE-1-0-G",
"MIROC-ES2L",
"MPI-ESM1-2-HR",
"MPI-ESM1-2-LR",
"MRI-ESM2-0",
"NorESM2-LM",
"NorESM2-MM",
"TaiESM1",
"UKESM1-0-LL",
]
variables = [
"hurs",
"huss",
"pr",
"rlds",
"rsds",
"sfcWind",
"tas",
"tasmax",
"tasmin",
]

if scale == "small":
# 130 files (152.83 GiB). One model and one variable.
models = models[:1]
variables = variables[:1]
elif scale == "medium":
# 390 files. Two models and two variables.
# Currently fails after hitting 20 minute idle timeout
# sending large graph to the scheduler.
models = models[:2]
variables = variables[:2]
else:
# 11635 files. All models and variables.
pass

# Get netCDF data files -- see https://registry.opendata.aws/nex-gddp-cmip6
# for dataset details.
file_list = []
for model in models:
for variable in variables:
data_dir = f"s3://nex-gddp-cmip6/NEX-GDDP-CMIP6/{model}/historical/r1i1p1f1/{variable}/*.nc"
file_list += [f"s3://{path}" for path in s3.glob(data_dir)]
files = [s3.open(f) for f in file_list]
print(f"Processing {len(files)} NetCDF files")

# Load input NetCDF data files
# TODO: Reduce explicit settings once https://github.com/pydata/xarray/issues/8778 is completed.
ds = xr.open_mfdataset(
files,
engine="h5netcdf",
combine="nested",
concat_dim="time",
data_vars="minimal",
coords="minimal",
compat="override",
parallel=True,
)

# Rechunk from "pancake" to "pencil" format
ds = ds.chunk({"time": -1, "lon": "auto", "lat": "auto"})

# Write out to a Zar dataset
ds.to_zarr(s3_url)
result = cloud_optimize(scale, s3fs=s3, storage_url=s3_url)
result.compute()
Loading

0 comments on commit 2073599

Please sign in to comment.