
Maca update #220

Merged
17 commits merged Jun 30, 2022
9 changes: 5 additions & 4 deletions .pre-commit-config.yaml
@@ -4,7 +4,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v4.3.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@@ -15,14 +15,14 @@ repos:
- id: mixed-line-ending

- repo: https://github.com/asottile/pyupgrade
rev: v2.32.1
rev: v2.34.0
hooks:
- id: pyupgrade
args:
- '--py38-plus'

- repo: https://github.com/psf/black
rev: 22.3.0
rev: 22.6.0
hooks:
- id: black
- id: black-jupyter
@@ -47,9 +47,10 @@ repos:
- id: isort

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.6.2
rev: v2.7.1
hooks:
- id: prettier
language_version: system
additional_dependencies:
- prettier
- '@carbonplan/prettier'
8 changes: 7 additions & 1 deletion cmip6_downscaling/config.py
@@ -53,7 +53,13 @@
},
'gcm_obs_weights': {'uri': 'az://static/xesmf_weights/gcm_obs/weights.csv'},
},
'run_options': {'runtime': "pangeo", 'use_cache': False, 'generate_pyramids': False},
'run_options': {
'runtime': "pangeo",
'use_cache': True,
'generate_pyramids': False,
'construct_analogs': True,
'combine_regions': False,
},
"runtime": {
"cloud": {
"storage_prefix": "az://",
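The new run_options flags (use_cache now defaulting to True, plus construct_analogs and combine_regions for MACA) are read elsewhere through the package's config.get accessor, as the tasks.py changes below show. A minimal sketch of how a task might branch on them; the import path is an assumption based on this file living at cmip6_downscaling/config.py:

from cmip6_downscaling import config  # assumed import; tasks.py uses the same config.get accessor

use_cache = config.get('run_options.use_cache')        # now defaults to True
if config.get('run_options.construct_analogs'):        # new MACA-oriented switch
    print('analog construction enabled for this run')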
13 changes: 6 additions & 7 deletions cmip6_downscaling/data/cmip.py
@@ -186,14 +186,13 @@ def get_gcm(
variable_ids=variable,
time_slice=time_slice,
)
if float(time_slice.stop) < 2015:
# you're working with historical data
ds_gcm = load_cmip(activity_ids='CMIP', experiment_ids='historical', **kws)
elif float(time_slice.start) > 2014:
# you're working with future data
ds_gcm = load_cmip(activity_ids='ScenarioMIP', experiment_ids=scenario, **kws)

if scenario == 'historical':
activity_id = 'CMIP'
else:
raise ValueError(f'time slice {time_slice} not supported')
activity_id = 'ScenarioMIP'

ds_gcm = load_cmip(activity_ids=activity_id, experiment_ids=scenario, **kws)
ds_gcm = ds_gcm.reindex(time=sorted(ds_gcm.time.values))

return ds_gcm
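With the time_slice branching removed, the CMIP activity is derived from the scenario name alone, and requests that straddle 2014/2015 no longer raise. A hypothetical call, with model, member, and period values chosen purely for illustration:

ds = get_gcm(
    scenario='ssp370',                 # anything but 'historical' -> activity_ids='ScenarioMIP'
    member_id='r1i1p1f1',
    table_id='day',
    grid_label='gn',
    source_id='MIROC6',
    variable='tasmax',
    time_slice=slice('2015', '2099'),
)
# scenario='historical' follows the same path with activity_ids='CMIP'; time slices
# spanning the historical/future boundary no longer raise a ValueError.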
12 changes: 7 additions & 5 deletions cmip6_downscaling/methods/common/containers.py
@@ -61,17 +61,19 @@ class RunParameters:
table_id: str
scenario: str
variable: str
features: list
latmin: float
latmax: float
lonmin: float
lonmax: float
train_dates: list
predict_dates: list
bias_correction_method: str
bias_correction_kwargs: dict
model_type: str
model_params: dict
features: list = None # gard only
bias_correction_method: str = None # gard only
bias_correction_kwargs: dict = None # gard only
model_type: str = None # gard only
model_params: dict = None # gard only
year_rolling_window: int = None # maca only
day_rolling_window: int = None # maca only

@property
def bbox(self) -> BBox:
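The method-specific fields now default to None, so a MACA run can omit the GARD-only parameters and vice versa. A stripped-down sketch of the pattern (illustrative fields only, not the full RunParameters signature):

from dataclasses import dataclass

@dataclass
class ExampleParams:
    model: str                        # shared parameters stay required
    scenario: str
    features: list = None             # gard only
    year_rolling_window: int = None   # maca only

params = ExampleParams(model='MIROC6', scenario='ssp370', year_rolling_window=21)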
101 changes: 63 additions & 38 deletions cmip6_downscaling/methods/common/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import contextlib
import datetime
import functools
import json
import os
import warnings
@@ -28,15 +27,15 @@
from ...data.cmip import get_gcm
from ...data.observations import open_era5
from ...utils import str_to_hash
from .containers import RunParameters
from .containers import RunParameters, TimePeriod
from .utils import (
blocking_to_zarr,
calc_auspicious_chunks_dict,
is_cached,
resample_wrapper,
set_zarr_encoding,
subset_dataset,
validate_zarr_store,
zmetadata_exists,
)

xr.set_options(keep_attrs=True)
@@ -52,8 +51,6 @@
results_dir = UPath(config.get("storage.results.uri")) / version
use_cache = config.get('run_options.use_cache')

is_cached = functools.partial(validate_zarr_store, raise_on_error=False)


@task(log_stdout=True)
def make_run_parameters(**kwargs) -> RunParameters:
@@ -83,7 +80,7 @@ def get_obs(run_parameters: RunParameters) -> UPath:
ds_hash = str_to_hash(frmt_str)
target = intermediate_dir / 'get_obs' / ds_hash

if use_cache and zmetadata_exists(target):
if use_cache and is_cached(target):
print(f'found existing target: {target}')
return target
print(run_parameters)
@@ -100,7 +97,6 @@ def get_obs(run_parameters: RunParameters) -> UPath:
subset[key].encoding = {}

subset.attrs.update({'title': title}, **get_cf_global_attrs(version=version))
print(f'writing {target}', subset)
subset = set_zarr_encoding(subset)
blocking_to_zarr(ds=subset, target=target, validate=True, write_empty_chunks=True)

@@ -116,49 +112,75 @@ def get_experiment(run_parameters: RunParameters, time_subset: str) -> UPath:
run_parameters : RunParameters
RunParameter dataclass defined in common/containers.py. Constructed from prefect parameters.
time_subset : str
String describing time subset request. Either 'train_period' or 'predict_period'
String describing time subset request. Either 'train_period', 'predict_period', or 'both'.

Returns
-------
UPath
UPath to experiment dataset.
"""
time_period = getattr(run_parameters, time_subset)
feature_string = '_'.join(run_parameters.features)

frmt_str = "{model}_{member}_{scenario}_{feature_string}_{latmin}_{latmax}_{lonmin}_{lonmax}_{time_period.start}_{time_period.stop}".format(
time_period=time_period, **asdict(run_parameters), feature_string=feature_string
)
if time_subset == 'both':
time_period = TimePeriod(
start=str(
min(
int(run_parameters.train_period.start), int(run_parameters.predict_period.start)
)
),
stop=str(
max(int(run_parameters.train_period.stop), int(run_parameters.predict_period.stop))
),
)
else:
time_period = getattr(run_parameters, time_subset)

features = getattr(run_parameters, 'features')
if features:
feature_string = '_'.join(features)
frmt_str = "{model}_{member}_{scenario}_{feature_string}_{latmin}_{latmax}_{lonmin}_{lonmax}_{time_period.start}_{time_period.stop}".format(
time_period=time_period, **asdict(run_parameters), feature_string=feature_string
)

else:
frmt_str = "{model}_{member}_{scenario}_{variable}_{latmin}_{latmax}_{lonmin}_{lonmax}_{time_period.start}_{time_period.stop}".format(
time_period=time_period, **asdict(run_parameters)
)

if int(time_period.start) < 2015 and run_parameters.scenario != 'historical':
scenarios = ['historical', run_parameters.scenario]
else:
scenarios = [run_parameters.scenario]

title = f"experiment ds: {frmt_str}"
ds_hash = str_to_hash(frmt_str)
target = intermediate_dir / 'get_experiment' / ds_hash

print(target)
if use_cache and is_cached(target):
print(f'found existing target: {target}')
return target

# The for loop is a workaround for github issue: https://github.com/pydata/xarray/issues/6709
mode = 'w'
for feature in run_parameters.features:
ds = get_gcm(
scenario=run_parameters.scenario,
member_id=run_parameters.member,
table_id=run_parameters.table_id,
grid_label=run_parameters.grid_label,
source_id=run_parameters.model,
variable=feature,
time_slice=time_period.time_slice,
)
ds_list = []
for s in scenarios:
ds_list.append(
get_gcm(
scenario=s,
member_id=run_parameters.member,
table_id=run_parameters.table_id,
grid_label=run_parameters.grid_label,
source_id=run_parameters.model,
variable=feature,
time_slice=time_period.time_slice,
)
)
ds = xr.concat(ds_list, dim='time')
subset = subset_dataset(ds, feature, time_period.time_slice, run_parameters.bbox)
# Note: dataset is chunked into time:365 chunks to standardize leap-year chunking.
subset = subset.chunk({'time': 365})
for key in subset.variables:
subset[key].encoding = {}

subset.attrs.update({'title': title}, **get_cf_global_attrs(version=version))

subset = set_zarr_encoding(subset)
subset[[feature]].to_zarr(target, mode=mode)
mode = 'a'
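Two behaviors introduced here are worth spelling out: with time_subset='both' the task spans the union of the train and predict periods, and when that span starts before 2015 under a future scenario the historical and ScenarioMIP members are concatenated along time before subsetting. A tiny self-contained sketch of the concatenation step, using synthetic stand-ins for the datasets returned by get_gcm:

import pandas as pd
import xarray as xr

hist = xr.Dataset({'tasmax': ('time', [280.0, 281.0])},
                  coords={'time': pd.date_range('2014-12-30', periods=2)})
futr = xr.Dataset({'tasmax': ('time', [282.0, 283.0])},
                  coords={'time': pd.date_range('2015-01-01', periods=2)})

ds = xr.concat([hist, futr], dim='time')   # one continuous time axis across 2014/2015
ds = ds.chunk({'time': 365})               # 365-day chunks standardize leap-year handling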
@@ -206,8 +228,7 @@ def rechunk(
task_hash = str_to_hash(str(path) + pattern_string + str(template) + max_mem)
target = intermediate_dir / 'rechunk' / task_hash
path_tmp = scratch_dir / 'rechunk' / task_hash
print(f'writing rechunked dataset to {target}')
print(target)

target_store = fsspec.get_mapper(str(target))
temp_store = fsspec.get_mapper(str(path_tmp))

@@ -243,7 +264,6 @@ def rechunk(
chunk_dims = config.get(f"chunk_dims.{pattern}")
for dim in chunk_def:
if dim not in chunk_dims:
print('correcting dim')
# override the chunksize of those unchunked dimensions to be the complete length (like passing chunksize=-1)
chunk_def[dim] = len(ds[dim])
elif pattern is not None:
@@ -316,7 +336,7 @@ def time_summary(ds_path: UPath, freq: str) -> UPath:

ds_hash = str_to_hash(str(ds_path) + freq)
target = results_dir / 'time_summary' / ds_hash
print(target)

if use_cache and is_cached(target):
print(f'found existing target: {target}')
return target
@@ -362,7 +382,6 @@ def get_weights(*, run_parameters, direction, regrid_method="bilinear"):
.iloc[0]
.path
)
print(path)
return path


@@ -385,16 +404,19 @@ def get_pyramid_weights(*, run_parameters, levels: int, regrid_method: str = "bi
Path to pyramid weights file.
"""
weights = pd.read_csv(config.get('weights.downscaled_pyramid_weights.uri'))
print(weights)
path = (
weights[(weights.regrid_method == regrid_method) & (weights.levels == levels)].iloc[0].path
)
print(path)
return path
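Both weight-lookup tasks follow the same pattern: read a small catalog CSV of pre-generated xESMF weights and return the path from the first matching row. A hedged sketch with an illustrative URI and filter values:

import pandas as pd

weights = pd.read_csv('az://static/xesmf_weights/downscaled_pyramid/weights.csv')  # illustrative URI
path = (
    weights[(weights.regrid_method == 'bilinear') & (weights.levels == 2)].iloc[0].path
)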


@task(log_stdout=True)
def regrid(source_path: UPath, target_grid_path: UPath, weights_path: UPath = None) -> UPath:
def regrid(
source_path: UPath,
target_grid_path: UPath,
weights_path: UPath = None,
pre_chunk_def: dict = None,
) -> UPath:
"""Task to regrid a dataset to target grid.

Parameters
@@ -416,18 +438,21 @@ def regrid(source_path: UPath, target_grid_path: UPath, weights_path: UPath = No

ds_hash = str_to_hash(str(source_path) + str(target_grid_path))
target = intermediate_dir / 'regrid' / ds_hash
print(target)

if use_cache and zmetadata_exists(target):
if use_cache and is_cached(target):
print(f'found existing target: {target}')
return target

source_ds = xr.open_zarr(source_path)
target_grid_ds = xr.open_zarr(target_grid_path)

if pre_chunk_def is not None:
source_ds = source_ds.chunk(**pre_chunk_def)

if weights_path:
from ndpyramid.regrid import _reconstruct_xesmf_weights

weights = _reconstruct_xesmf_weights(xr.open_zarr(weights_path))
print(weights_path)
regridder = xe.Regridder(
source_ds,
target_grid_ds,
@@ -450,9 +475,9 @@ def regrid(source_path: UPath, target_grid_path: UPath, weights_path: UPath = No
regridded_ds.attrs.update(
{'title': source_ds.attrs['title']}, **get_cf_global_attrs(version=version)
)

regridded_ds = set_zarr_encoding(regridded_ds)
blocking_to_zarr(ds=regridded_ds, target=target, validate=True, write_empty_chunks=True)

return target
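The new pre_chunk_def argument is unpacked directly into Dataset.chunk, letting callers rechunk the source store before the regridder is built. A hypothetical wiring inside a flow definition; the paths and chunk size below are illustrative and would normally come from upstream tasks:

from upath import UPath

experiment_path = UPath('az://scratch/intermediate/get_experiment/abc123')  # placeholder path
obs_path = UPath('az://scratch/intermediate/get_obs/def456')                # placeholder path

target = regrid(
    source_path=experiment_path,
    target_grid_path=obs_path,
    pre_chunk_def={'time': 365},   # expands to source_ds.chunk(time=365)
)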


4 changes: 4 additions & 0 deletions cmip6_downscaling/methods/common/utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import pathlib
import re

@@ -66,6 +67,9 @@ def validate_zarr_store(target: str, raise_on_error=True) -> bool:
return True


is_cached = functools.partial(validate_zarr_store, raise_on_error=False)


def zmetadata_exists(path: UPath):
'''temporary workaround until path.exists() works'''

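The module-level partial replaces the private copy that previously lived in tasks.py, so every flow now imports the same boolean check: is_cached(target) simply calls validate_zarr_store(target, raise_on_error=False) and returns False instead of raising when a store is missing or incomplete. A short usage sketch with an illustrative store path:

from cmip6_downscaling.methods.common.utils import is_cached  # mirrors the relative import in tasks.py

if is_cached('az://scratch/intermediate/get_obs/abc123'):
    print('found existing target, skipping write')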