multivariate gard implementation (#229)
* as committed here, ran a test case for the western US using three features to predict precip; a couple of TODOs remain, but it is functional at this point

* allow access to wind and standardize names/coordinates

* tweaks for multivariate run

* sample config for multivariate run

* update config files for multivariate runs and the detrend boolean

* variable switching

* param configs with wind

* split get_gcm by variable to work around a dask-related xarray issue: pydata/xarray#6709

* switched the order of postprocess and added method='nearest' to .sel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: norland r hagen <norlandrhagen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jun 30, 2022
1 parent 127a719 commit 895fac3
Showing 79 changed files with 967 additions and 136 deletions.
2 changes: 1 addition & 1 deletion cmip6_downscaling/config.py
@@ -73,7 +73,7 @@
"pangeo": {
"storage_prefix": "az://",
"storage_options": {'directory': './'},
'n_workers': 16,
'n_workers': 8,
'threads_per_worker': 1,
},
},
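The "pangeo" block above carries the runtime's dask settings (this commit halves n_workers from 16 to 8). A minimal sketch of how such settings might be applied, assuming a local cluster; LocalCluster here is illustrative, not the project's actual runtime code:

```python
from dask.distributed import Client, LocalCluster

# Hypothetical mapping of the 'pangeo' settings above onto a local cluster;
# the project may construct its scheduler differently.
cluster = LocalCluster(n_workers=8, threads_per_worker=1)
client = Client(cluster)
print(client.dashboard_link)
```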
10 changes: 6 additions & 4 deletions cmip6_downscaling/data/cmip.py
@@ -45,7 +45,6 @@ def postprocess(ds: xr.Dataset, to_standard_calendar: bool = True) -> xr.Dataset
# drop height variable
if 'height' in ds:
ds = ds.drop('height')

# squeeze length 1 dimensions
ds = ds.squeeze(drop=True)

@@ -121,14 +120,17 @@ def load_cmip(
grid_label=grid_labels,
variable_id=variable_ids,
)

keys = list(col_subset.keys())
if len(keys) != 1:
raise ValueError(f'intake-esm search returned {len(keys)}, expected exactly 1.')

ds = col_subset[keys[0]]().to_dask()

if time_slice:
ds = ds.sel(time=time_slice)

if 'plev' in ds.coords:
# select the 500 mb level for winds
ds = ds.sel(plev=50000.0, method='nearest').drop('plev')
ds = ds.pipe(postprocess)

# convert to mm/day - helpful to prevent rounding errors from very tiny numbers
@@ -147,7 +149,7 @@ def get_gcm(
table_id: str,
grid_label: str,
source_id: str,
variable: str,
variable: str | list[str],
time_slice: slice,
) -> xr.Dataset:
"""
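The new plev handling in load_cmip selects the 500 mb wind level with method='nearest'. A self-contained sketch of the same pattern on a toy dataset:

```python
import numpy as np
import xarray as xr

# Toy stand-in for a CMIP6 wind field with a pressure coordinate in Pa.
ds = xr.Dataset(
    {'ua': (('plev', 'lat', 'lon'), np.zeros((3, 2, 2)))},
    coords={'plev': [85000.0, 50012.3, 25000.0], 'lat': [0.0, 1.0], 'lon': [0.0, 1.0]},
)

# method='nearest' tolerates models whose levels are not exactly 50000 Pa,
# which is why the commit adds it to .sel before dropping the coordinate.
ds_500 = ds.sel(plev=50000.0, method='nearest').drop_vars('plev')
```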
16 changes: 14 additions & 2 deletions cmip6_downscaling/data/observations.py
@@ -33,8 +33,19 @@ def open_era5(variables: str | list[str], time_period: slice) -> xr.Dataset
variables = [variables]

years = range(int(time_period.start), int(time_period.stop) + 1)

ds = xr.concat([cat.era5(year=year).to_dask()[variables] for year in years], dim='time')
wind_vars, non_wind_vars = [], []
for variable in variables:
if variable in ['ua', 'va']:
wind_vars.append(variable)
else:
non_wind_vars.append(variable)
ds = xr.concat([cat.era5(year=year).to_dask()[non_wind_vars] for year in years], dim='time')
for wind_var in wind_vars:
era5_winds = xr.open_zarr('az://training/ERA5_daily_winds').rename(
{'latitude': 'lat', 'longitude': 'lon'}
)
name_dict = {'ua': 'U', 'va': 'V'}
ds[wind_var] = era5_winds[name_dict[wind_var]].drop('level')

if 'pr' in variables:
# convert to mm/day - helpful to prevent rounding errors from very tiny numbers
@@ -73,6 +84,7 @@ def open_era5(variables: str | list[str], time_period: slice) -> xr.Dataset
'nameCDM': 'Maximum_temperature_at_2_metres_since_previous_post-processing_surface_1_Hour_2',
'product_type': 'forecast',
}
# TODO adjust attrs of other variables

ds = lon_to_180(ds)

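open_era5 now routes 'ua'/'va' to a separate daily-winds zarr store (where they are named 'U'/'V') and pulls everything else from the yearly ERA5 catalog. A sketch of just the partitioning step; the store path ('az://training/ERA5_daily_winds') is environment-specific:

```python
# 'ua'/'va' -> 'U'/'V' mirrors the name_dict in the diff above.
WIND_NAMES = {'ua': 'U', 'va': 'V'}

def split_wind_vars(variables: list[str]) -> tuple[list[str], list[str]]:
    """Partition requested variables into wind and non-wind groups."""
    wind_vars = [v for v in variables if v in WIND_NAMES]
    non_wind_vars = [v for v in variables if v not in WIND_NAMES]
    return wind_vars, non_wind_vars

assert split_wind_vars(['pr', 'ua', 'tasmax']) == (['ua'], ['pr', 'tasmax'])
```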
3 changes: 2 additions & 1 deletion cmip6_downscaling/methods/common/containers.py
@@ -100,7 +100,8 @@ def experiment(self) -> CMIP6Experiment:

@property
def run_id(self):
return f"{self.method}_{self.obs}_{self.model}_{self.member}_{self.scenario}_{self.variable}_{self.latmin}_{self.latmax}_{self.lonmin}_{self.lonmax}_{self.train_dates[0]}_{self.train_dates[1]}_{self.predict_dates[0]}_{self.predict_dates[1]}"
feature_string = '_'.join(self.features)
return f"{self.method}_{self.obs}_{self.model}_{self.member}_{self.scenario}_{self.variable}_{feature_string}_{self.latmin}_{self.latmax}_{self.lonmin}_{self.lonmax}_{self.train_dates[0]}_{self.train_dates[1]}_{self.predict_dates[0]}_{self.predict_dates[1]}"

@property
def run_id_hash(self):
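run_id now interleaves a joined feature string, so runs with different predictor sets hash to different cache targets. A sketch of the idea with hashlib standing in for the project's str_to_hash helper (an assumption), and hypothetical parameter values:

```python
import hashlib

# Hypothetical values; only the '_'.join(features) step comes from the diff.
features = ['pr', 'tasmax', 'tasmin']
feature_string = '_'.join(features)
run_id = f"gard_era5_MODEL_MEMBER_ssp245_pr_{feature_string}_30_50_-125_-100_1981_2010_2015_2099"

# str_to_hash is project-internal; sha1 is just an illustrative digest here.
print(hashlib.sha1(run_id.encode()).hexdigest()[:12])
```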
67 changes: 33 additions & 34 deletions cmip6_downscaling/methods/common/tasks.py
@@ -72,26 +72,22 @@ def get_obs(run_parameters: RunParameters) -> UPath:
UPath
Path to subset observation dataset.
"""

title = "obs ds: {obs}_{variable}_{latmin}_{latmax}_{lonmin}_{lonmax}_{train_dates[0]}_{train_dates[1]}".format(
**asdict(run_parameters)
)
ds_hash = str_to_hash(
"{obs}_{variable}_{latmin}_{latmax}_{lonmin}_{lonmax}_{train_dates[0]}_{train_dates[1]}".format(
**asdict(run_parameters)
)
feature_string = '_'.join(run_parameters.features)
frmt_str = "{obs}_{feature_string}_{latmin}_{latmax}_{lonmin}_{lonmax}_{train_dates[0]}_{train_dates[1]}".format(
**asdict(run_parameters), feature_string=feature_string
)
title = f"obs ds: {frmt_str}"
ds_hash = str_to_hash(frmt_str)
target = intermediate_dir / 'get_obs' / ds_hash

if use_cache and zmetadata_exists(target):
print(f'found existing target: {target}')
return target

ds = open_era5(run_parameters.variable, run_parameters.train_period)

print(run_parameters)
ds = open_era5(run_parameters.features, run_parameters.train_period)
subset = subset_dataset(
ds,
run_parameters.variable,
run_parameters.features,
run_parameters.train_period.time_slice,
run_parameters.bbox,
chunking_schema={'time': 365, 'lat': 150, 'lon': 150},
@@ -125,8 +121,10 @@ def get_experiment(run_parameters: RunParameters, time_subset: str) -> UPath:
UPath to experiment dataset.
"""
time_period = getattr(run_parameters, time_subset)
frmt_str = "{model}_{member}_{scenario}_{variable}_{latmin}_{latmax}_{lonmin}_{lonmax}_{time_period.start}_{time_period.stop}".format(
time_period=time_period, **asdict(run_parameters)
feature_string = '_'.join(run_parameters.features)

frmt_str = "{model}_{member}_{scenario}_{feature_string}_{latmin}_{latmax}_{lonmin}_{lonmax}_{time_period.start}_{time_period.stop}".format(
time_period=time_period, **asdict(run_parameters), feature_string=feature_string
)

title = f"experiment ds: {frmt_str}"
@@ -138,29 +136,30 @@ def get_experiment(run_parameters: RunParameters, time_subset: str) -> UPath:
print(f'found existing target: {target}')
return target

ds = get_gcm(
scenario=run_parameters.scenario,
member_id=run_parameters.member,
table_id=run_parameters.table_id,
grid_label=run_parameters.grid_label,
source_id=run_parameters.model,
variable=run_parameters.variable,
time_slice=time_period.time_slice,
)

subset = subset_dataset(
ds, run_parameters.variable, time_period.time_slice, run_parameters.bbox
)
# The for loop is a workaround for github issue: https://github.com/pydata/xarray/issues/6709
mode = 'w'
for feature in run_parameters.features:
ds = get_gcm(
scenario=run_parameters.scenario,
member_id=run_parameters.member,
table_id=run_parameters.table_id,
grid_label=run_parameters.grid_label,
source_id=run_parameters.model,
variable=feature,
time_slice=time_period.time_slice,
)
subset = subset_dataset(ds, feature, time_period.time_slice, run_parameters.bbox)
# Note: dataset is chunked into time:365 chunks to standardize leap-year chunking.
subset = subset.chunk({'time': 365})
for key in subset.variables:
subset[key].encoding = {}

# Note: dataset is chunked into time:365 chunks to standardize leap-year chunking.
subset = subset.chunk({'time': 365})
for key in subset.variables:
subset[key].encoding = {}
subset.attrs.update({'title': title}, **get_cf_global_attrs(version=version))

subset.attrs.update({'title': title}, **get_cf_global_attrs(version=version))
subset = set_zarr_encoding(subset)
subset[[feature]].to_zarr(target, mode=mode)
mode = 'a'

subset = set_zarr_encoding(subset)
blocking_to_zarr(ds=subset, target=target, validate=True, write_empty_chunks=True)
return target


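get_experiment now fetches and writes one feature at a time into the same zarr store, creating it with mode='w' and appending subsequent variables with mode='a', to sidestep pydata/xarray#6709. A self-contained sketch of that write pattern on toy data; the store path is illustrative:

```python
import numpy as np
import xarray as xr

target = 'example_experiment.zarr'  # illustrative path

# First variable creates the store (mode='w'); later variables are appended
# alongside it (mode='a'), mirroring the per-feature loop above.
mode = 'w'
for name in ['pr', 'tasmax']:
    ds = xr.Dataset({name: (('time', 'lat', 'lon'), np.zeros((4, 2, 2)))})
    ds.to_zarr(target, mode=mode)
    mode = 'a'

print(list(xr.open_zarr(target).data_vars))  # ['pr', 'tasmax']
```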
14 changes: 10 additions & 4 deletions cmip6_downscaling/methods/common/utils.py
@@ -3,6 +3,7 @@
import pathlib
import re

import dask
import fsspec
import geopandas as gpd
import numpy as np
@@ -11,7 +12,7 @@
import xarray as xr
import zarr
from upath import UPath
from xarray_schema import DataArraySchema
from xarray_schema import DataArraySchema, DatasetSchema
from xarray_schema.base import SchemaError

from . import containers
@@ -80,6 +81,7 @@ def blocking_to_zarr(

for variable in ds.data_vars:
ds[variable].encoding['write_empty_chunks'] = True
ds = dask.optimize(ds)[0]
t = ds.to_zarr(target, mode='w', compute=False)
t.compute(retries=5)
zarr.consolidate_metadata(target)
@@ -90,7 +92,7 @@

def subset_dataset(
ds: xr.Dataset,
variable: str,
features: str | list,
time_period: slice,
bbox: containers.BBox,
chunking_schema: dict = None,
@@ -120,9 +122,13 @@
lat=bbox.lat_slice,
)
if chunking_schema is not None:
target_schema = DataArraySchema(chunks=chunking_schema)
target_schema_array = DataArraySchema(chunks=chunking_schema)
schema_dict = {}
for feature in features:
schema_dict[feature] = target_schema_array
target_schema_dataset = DatasetSchema(schema_dict)
try:
target_schema.validate(subset_ds[variable])
target_schema_dataset.validate(subset_ds[features])
except SchemaError:
subset_ds = subset_ds.chunk(chunking_schema)

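subset_dataset now validates chunking per feature: one DataArraySchema per variable, wrapped in a DatasetSchema, with a rechunk only when validation fails. A minimal sketch of that check against a deliberately mis-chunked toy dataset:

```python
import numpy as np
import xarray as xr
from xarray_schema import DataArraySchema, DatasetSchema
from xarray_schema.base import SchemaError

chunking_schema = {'time': 365, 'lat': 150, 'lon': 150}
features = ['pr', 'tasmax']

# Toy dataset, chunked differently on purpose so that validation fails.
ds = xr.Dataset(
    {f: (('time', 'lat', 'lon'), np.zeros((730, 150, 150))) for f in features}
).chunk({'time': 100})

schema = DatasetSchema({f: DataArraySchema(chunks=chunking_schema) for f in features})
try:
    schema.validate(ds[features])
except SchemaError:
    ds = ds.chunk(chunking_schema)
```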
56 changes: 34 additions & 22 deletions cmip6_downscaling/methods/gard/tasks.py
@@ -78,35 +78,46 @@ def _fit_and_predict_wrapper(xtrain, ytrain, xpred, scrf, run_parameters, dim='t
# transformed gcm is the interpolated GCM for the prediction period transformed
# w.r.t. the interpolated obs used in the training (because that transformation
# is essentially part of the model)
bias_corrected_gcm_pred = (
bias_correct_gcm_by_method(
gcm_pred=xpred[run_parameters.variable],
method=run_parameters.bias_correction_method,
bc_kwargs=kws,
obs=xtrain[run_parameters.variable],
bias_corrected_gcm_pred = xr.Dataset()
for feature in run_parameters.features:
bias_corrected_gcm_pred[feature] = (
bias_correct_gcm_by_method(
gcm_pred=xpred[feature],
method=run_parameters.bias_correction_method,
bc_kwargs=kws[feature],
obs=xtrain[feature],
)
.sel(variable='variable_0')
.drop('variable')
)
.to_dataset(dim='variable')
.rename({'variable_0': run_parameters.variable})
)
# model definition
model = PointWiseDownscaler(
model=get_gard_model(run_parameters.model_type, run_parameters.model_params), dim=dim
)
# model fitting
if run_parameters.variable == 'pr':
model.fit(cbrt(xtrain[run_parameters.variable]), cbrt(ytrain[run_parameters.variable]))
out = model.predict(cbrt(bias_corrected_gcm_pred[run_parameters.variable])).to_dataset(
dim='variable'
)
# # TODO need to fix this to only transform some variables
if 'pr' in run_parameters.features:
bias_corrected_gcm_pred['pr'] = cbrt(bias_corrected_gcm_pred['pr'])
xtrain['pr'] = cbrt(xtrain['pr'])
if 'pr' == run_parameters.variable:
ytrain['pr'] = cbrt(ytrain['pr'])
# TODO: at this point there is negative precip in some chunks - why?
# <xarray.Dataset>
# Dimensions: (time: 23376, lat: 5, lon: 48)
# Coordinates:
# * lat (lat) float32 49.0 49.25 49.5 49.75 50.0
# * lon (lon) float32 -113.0 -112.8 -112.5 -112.2 ... -101.8 -101.5 -101.2
# * time (time) datetime64[ns] 1950-01-01 1950-01-02 ... 2013-12-31
# Data variables:
# pr (time, lat, lon) float32 0.4851 0.2508 0.1828 ... -0.5607 -0.5607
# tasmax (time, lat, lon) float32 270.3 270.3 270.1 ... 257.0 256.3 256.3
# tasmin (time, lat, lon) float32 261.5 261.3 261.1 ... 254.1 253.4 253.4
model.fit(xtrain[run_parameters.features], ytrain[run_parameters.variable])
out = model.predict(bias_corrected_gcm_pred[run_parameters.features]).to_dataset(dim='variable')
if 'pr' == run_parameters.variable:
out['pred'] = out['pred'] ** 3

else:
model.fit(xtrain[run_parameters.variable], ytrain[run_parameters.variable])
out = model.predict(bias_corrected_gcm_pred[run_parameters.variable]).to_dataset(
dim='variable'
)

# model prediction
# # model prediction
downscaled = add_random_effects(out, scrf.scrf, run_parameters)
return downscaled

@@ -258,7 +269,8 @@ def read_scrf(prediction_path: UPath, run_parameters: RunParameters):
scrf = scrf.assign_coords(
{'lat': prediction_ds.lat, 'lon': prediction_ds.lon, 'time': prediction_ds.time}
)

if (scrf.chunks['lon'][0] != 48) or (scrf.chunks['lat'][0] != 48):
scrf = scrf.chunk({'lon': 48, 'lat': 48, 'time': 3652})
scrf = dask.optimize(scrf)[0]
scrf = set_zarr_encoding(scrf)
blocking_to_zarr(ds=scrf, target=target, validate=True, write_empty_chunks=True)
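For precipitation, _fit_and_predict_wrapper fits GARD in cube-root space and cubes the prediction on the way out (out['pred'] ** 3). A tiny sketch of the round trip, with numpy's cbrt standing in for the project's cbrt helper (an assumption):

```python
import numpy as np
import xarray as xr

pr = xr.DataArray(np.array([0.0, 1.0, 8.0, 27.0]), dims='time', name='pr')

pr_transformed = np.cbrt(pr)        # applied before model.fit / model.predict
pr_recovered = pr_transformed ** 3  # applied to out['pred'] after prediction

assert np.allclose(pr, pr_recovered)
```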
@@ -13,7 +13,14 @@
"lonmin": "-180",
"lonmax": "180",
"bias_correction_method": "quantile_mapper",
"bias_correction_kwargs": { "detrend": false },
"bias_correction_kwargs": {
"pr": { "detrend": false },
"tasmin": { "detrend": true },
"tasmax": { "detrend": true },
"psl": { "detrend": false },
"ua": { "detrend": false },
"va": { "detrend": false }
},
"model_type": "PureRegression",
"model_params": { "thresh": 0 },
"train_dates": ["1981", "2010"],
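The per-variable bias_correction_kwargs above are consumed one feature at a time: the GARD wrapper passes bc_kwargs=kws[feature], so each predictor carries its own detrend flag. A sketch of that lookup against a trimmed copy of the config:

```python
import json

# Trimmed copy of the config above; the lookup mirrors bc_kwargs=kws[feature].
config = json.loads("""
{
  "bias_correction_kwargs": {
    "pr": {"detrend": false},
    "tasmin": {"detrend": true},
    "tasmax": {"detrend": true}
  }
}
""")

for feature in ['pr', 'tasmin', 'tasmax']:
    kws = config['bias_correction_kwargs'][feature]
    print(feature, 'detrend:', kws['detrend'])
```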
@@ -13,7 +13,14 @@
"lonmin": "-180",
"lonmax": "180",
"bias_correction_method": "quantile_mapper",
"bias_correction_kwargs": { "detrend": true },
"bias_correction_kwargs": {
"pr": { "detrend": false },
"tasmin": { "detrend": true },
"tasmax": { "detrend": true },
"psl": { "detrend": false },
"ua": { "detrend": false },
"va": { "detrend": false }
},
"model_type": "PureRegression",
"model_params": {},
"train_dates": ["1981", "2010"],
@@ -13,7 +13,14 @@
"lonmin": "-180",
"lonmax": "180",
"bias_correction_method": "quantile_mapper",
"bias_correction_kwargs": { "detrend": true },
"bias_correction_kwargs": {
"pr": { "detrend": false },
"tasmin": { "detrend": true },
"tasmax": { "detrend": true },
"psl": { "detrend": false },
"ua": { "detrend": false },
"va": { "detrend": false }
},
"model_type": "PureRegression",
"model_params": {},
"train_dates": ["1981", "2010"],
@@ -13,7 +13,14 @@
"lonmin": "-180",
"lonmax": "180",
"bias_correction_method": "quantile_mapper",
"bias_correction_kwargs": { "detrend": false },
"bias_correction_kwargs": {
"pr": { "detrend": false },
"tasmin": { "detrend": true },
"tasmax": { "detrend": true },
"psl": { "detrend": false },
"ua": { "detrend": false },
"va": { "detrend": false }
},
"model_type": "PureRegression",
"model_params": { "thresh": 0 },
"train_dates": ["1981", "2010"],
@@ -13,7 +13,14 @@
"lonmin": "-180",
"lonmax": "180",
"bias_correction_method": "quantile_mapper",
"bias_correction_kwargs": { "detrend": true },
"bias_correction_kwargs": {
"pr": { "detrend": false },
"tasmin": { "detrend": true },
"tasmax": { "detrend": true },
"psl": { "detrend": false },
"ua": { "detrend": false },
"va": { "detrend": false }
},
"model_type": "PureRegression",
"model_params": {},
"train_dates": ["1981", "2010"],