Skip to content

Commit

Permalink
flatten method added to Sup3rX and monthly averaged product types added to era downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Sep 27, 2024
1 parent 95b0038 commit 2037101
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 19 deletions.
20 changes: 20 additions & 0 deletions sup3r/preprocessing/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,26 @@ def unflatten(self, grid_shape):
warn(msg)
return self

def flatten(self):
    """Collapse the two horizontal dimensions of a rasterized dataset
    into a single flattened spatial dimension.

    Returns
    -------
    self
        This accessor, with ``self._ds`` stacked along
        ``Dimension.FLATTENED_SPATIAL``. If the dataset is already
        flattened this is a no-op that logs and raises a warning.
    """
    if self.flattened:
        # Nothing to do — tell the caller and return unchanged.
        msg = 'Dataset is already flattened'
        logger.warning(msg)
        warn(msg)
        return self

    stacked = self._ds.stack(
        {Dimension.FLATTENED_SPATIAL: Dimension.dims_2d()}
    )
    # Swap the stacked coordinate for a plain integer index so the
    # flattened dimension is a simple 0..n-1 range.
    flat_index = np.arange(len(stacked[Dimension.FLATTENED_SPATIAL]))
    self._ds = stacked.assign({Dimension.FLATTENED_SPATIAL: flat_index})
    return self

def _qa(self, feature):
"""Get qa info for given feature."""
info = {}
Expand Down
65 changes: 49 additions & 16 deletions sup3r/utilities/era_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __init__(
and wind components.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
"""
self.year = year
self.month = month
Expand All @@ -95,7 +96,15 @@ def __init__(
def get_hours(self):
"""ERA5 is hourly and EDA is 3-hourly. Check and warn for incompatible
requests."""
if self.product_type == 'reanalysis':
if self.product_type in (
'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members',
):
hours = ['00:00']
elif self.product_type in (
'reanalysis',
'monthly_averaged_reanalysis_by_hour_of_day',
):
hours = [str(n).zfill(2) + ':00' for n in range(0, 24)]
else:
hours = [str(n).zfill(2) + ':00' for n in range(0, 24, 3)]
Expand Down Expand Up @@ -241,9 +250,11 @@ def download_process_combine(self):
time_dict = {
'year': self.year,
'month': self.month,
'day': self.days,
'time': self.hours,
}
if 'monthly' not in self.product_type:
time_dict['day'] = self.days

if sfc_check:
tmp_file = self.get_tmp_file(self.surface_file)
self.download_file(
Expand Down Expand Up @@ -305,16 +316,23 @@ def download_file(
List of pressure levels to download, if level_type == 'pressure'
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
overwrite : bool
Whether to overwrite existing file
"""
if os.path.exists(out_file) and not cls._can_skip_file(out_file):
os.remove(out_file)

if not cls._can_skip_file(out_file) or overwrite:
msg = (
f'Downloading {variables} to {out_file} with levels '
f'= {levels}.'
)
logger.info(msg)
dataset = f'reanalysis-era5-{level_type}-levels'
if 'monthly' in product_type:
dataset += '-monthly-means'
entry = {
'product_type': product_type,
'format': 'netcdf',
Expand All @@ -324,11 +342,11 @@ def download_file(
entry.update(time_dict)
if level_type == 'pressure':
entry['pressure_level'] = levels
logger.info(f'Calling CDS-API with {entry}.')
cds_api_client = cls.get_cds_client()
cds_api_client.retrieve(
f'reanalysis-era5-{level_type}-levels', entry, out_file
logger.info(
'Calling CDS-API with dataset=%s, entry=%s.', dataset, entry
)
cds_api_client = cls.get_cds_client()
cds_api_client.retrieve(dataset, entry, out_file)
else:
logger.info(f'File already exists: {out_file}.')

Expand Down Expand Up @@ -413,6 +431,12 @@ def process_level_file(self):

def process_and_combine(self):
"""Process variables and combine."""

if os.path.exists(self.monthly_file) and not self._can_skip_file(
self.monthly_file
):
os.remove(self.monthly_file)

if not self._can_skip_file(self.monthly_file) or self.overwrite:
files = []
if os.path.exists(self.level_file):
Expand Down Expand Up @@ -505,7 +529,8 @@ def run_month(
and wind components.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
"""
variables = variables if isinstance(variables, list) else [variables]
for var in variables:
Expand Down Expand Up @@ -567,16 +592,20 @@ def run_for_var(
Variable to download.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'.
"""
yearly_var_file = yearly_file_pattern.format(year=year, var=variable)
if os.path.exists(yearly_var_file) and not overwrite:
logger.info(
'%s already exists and overwrite=False.', yearly_var_file
if yearly_file_pattern is not None:
yearly_var_file = yearly_file_pattern.format(
year=year, var=variable
)
if os.path.exists(yearly_var_file) and not overwrite:
logger.info(
'%s already exists and overwrite=False.', yearly_var_file
)
msg = 'file_pattern must have {year}, {month}, and {var} format keys'
assert all(
key in monthly_file_pattern
Expand Down Expand Up @@ -660,7 +689,8 @@ def run(
and wind components.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'
Expand All @@ -684,7 +714,7 @@ def run(
res_kwargs=res_kwargs,
)

if (
if yearly_file_pattern is not None and (
cls.all_vars_exist(
year=year,
file_pattern=yearly_file_pattern,
Expand Down Expand Up @@ -758,6 +788,9 @@ def _can_skip_file(cls, file):
if not os.path.exists(file):
return False

logger.info(
'%s already exists. Making sure it downloaded successfully.', file
)
openable = True
try:
_ = Loader(file)
Expand Down
7 changes: 4 additions & 3 deletions sup3r/utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def merge_datasets(files, **kwargs):
out = xr.merge(dsets, **get_class_kwargs(xr.merge, kwargs))
msg = ('Merged time index does not have the same number of time steps '
'(%s) as the sum of the individual time index steps (%s).')
merged_size = out.time.size
summed_size = pd.concat(time_indices).drop_duplicates().size
assert merged_size == summed_size, msg % (merged_size, summed_size)
if hasattr(out, 'time'):
merged_size = out.time.size
summed_size = pd.concat(time_indices).drop_duplicates().size
assert merged_size == summed_size, msg % (merged_size, summed_size)
return out


Expand Down

0 comments on commit 2037101

Please sign in to comment.