updates to era_downloader from masked_fwp branch
bnb32 committed Oct 25, 2024
1 parent 157c102 commit 524dd11
Showing 1 changed file with 90 additions and 21 deletions.
111 changes: 90 additions & 21 deletions sup3r/utilities/era_downloader.py
@@ -16,6 +16,7 @@
import dask
import dask.array as da
import numpy as np
import pandas as pd
from rex import init_logger

from sup3r.preprocessing import Cacher, Loader
@@ -77,7 +78,8 @@ def __init__(
and wind components.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
"""
self.year = year
self.month = month
@@ -95,7 +97,15 @@ def __init__
def get_hours(self):
"""ERA5 is hourly and EDA is 3-hourly. Check and warn for incompatible
requests."""
if self.product_type == 'reanalysis':
if self.product_type in (
'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members',
):
hours = ['00:00']
elif self.product_type in (
'reanalysis',
'monthly_averaged_reanalysis_by_hour_of_day',
):
hours = [str(n).zfill(2) + ':00' for n in range(0, 24)]
else:
hours = [str(n).zfill(2) + ':00' for n in range(0, 24, 3)]
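As a standalone sanity check, the hour-selection rules above can be reproduced outside the class. A minimal sketch (it mirrors the diff; nothing here is imported from sup3r):

def hours_for(product_type):
    """Return the CDS 'time' list for a given ERA5 product type."""
    if product_type in (
        'monthly_averaged_reanalysis',
        'monthly_averaged_ensemble_members',
    ):
        return ['00:00']  # monthly means carry a single placeholder hour
    if product_type in (
        'reanalysis',
        'monthly_averaged_reanalysis_by_hour_of_day',
    ):
        return [f'{n:02d}:00' for n in range(24)]  # hourly ERA5
    return [f'{n:02d}:00' for n in range(0, 24, 3)]  # 3-hourly EDA/ensemble

assert hours_for('monthly_averaged_reanalysis') == ['00:00']
assert len(hours_for('reanalysis')) == 24
assert hours_for('ensemble_members')[:3] == ['00:00', '03:00', '06:00']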
@@ -241,30 +251,38 @@ def download_process_combine(self):
time_dict = {
'year': self.year,
'month': self.month,
'day': self.days,
'time': self.hours,
}
if 'monthly' not in self.product_type:
time_dict['day'] = self.days

if sfc_check:
tmp_file = self.get_tmp_file(self.surface_file)
self.download_file(
self.sfc_file_variables,
time_dict=time_dict,
area=self.area,
out_file=self.surface_file,
out_file=tmp_file,
level_type='single',
overwrite=self.overwrite,
product_type=self.product_type,
)
os.replace(tmp_file, self.surface_file)
logger.info('Moved %s to %s', tmp_file, self.surface_file)
if level_check:
tmp_file = self.get_tmp_file(self.level_file)
self.download_file(
self.level_file_variables,
time_dict=time_dict,
area=self.area,
out_file=self.level_file,
out_file=tmp_file,
level_type='pressure',
levels=self.levels,
overwrite=self.overwrite,
product_type=self.product_type,
)
os.replace(tmp_file, self.level_file)
logger.info('Moved %s to %s', tmp_file, self.level_file)
if sfc_check or level_check:
self.process_and_combine()
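Both branches above now download to a temporary file and move it into place only after the request completes, so an interrupted download never leaves a partial file at the final path for later runs to mistake for a cached result. The pattern in miniature (the '.tmp' suffix is an assumption for the sketch; the real naming comes from get_tmp_file, which is not shown in this hunk):

import os

def download_atomically(download_fn, out_file):
    tmp_file = out_file + '.tmp'    # hypothetical tmp naming
    download_fn(tmp_file)           # may fail partway through
    os.replace(tmp_file, out_file)  # atomic rename on one filesystem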

@@ -299,16 +317,23 @@ def download_file(
List of pressure levels to download, if level_type == 'pressure'
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
overwrite : bool
Whether to overwrite existing file
"""
if not os.path.exists(out_file) or overwrite:
if os.path.exists(out_file) and not cls._can_skip_file(out_file):
os.remove(out_file)

if not cls._can_skip_file(out_file) or overwrite:
msg = (
f'Downloading {variables} to {out_file} with levels '
f'= {levels}.'
)
logger.info(msg)
dataset = f'reanalysis-era5-{level_type}-levels'
if 'monthly' in product_type:
dataset += '-monthly-means'
entry = {
'product_type': product_type,
'format': 'netcdf',
@@ -318,11 +343,11 @@
entry.update(time_dict)
if level_type == 'pressure':
entry['pressure_level'] = levels
logger.info(f'Calling CDS-API with {entry}.')
cds_api_client = cls.get_cds_client()
cds_api_client.retrieve(
f'reanalysis-era5-{level_type}-levels', entry, out_file
logger.info(
'Calling CDS-API with dataset=%s, entry=%s.', dataset, entry
)
cds_api_client = cls.get_cds_client()
cds_api_client.retrieve(dataset, entry, out_file)
else:
logger.info(f'File already exists: {out_file}.')
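For monthly products the dataset name gains a '-monthly-means' suffix and, per the time_dict change earlier in the commit, the request carries no 'day' key. An illustrative request of the shape download_file() builds; the middle of the entry dict is collapsed in this hunk, so the 'variable' and 'area' keys below are assumptions based on the CDS API, and all values are examples:

import cdsapi

dataset = 'reanalysis-era5-pressure-levels-monthly-means'
entry = {
    'product_type': 'monthly_averaged_reanalysis',
    'format': 'netcdf',
    'variable': ['u_component_of_wind', 'v_component_of_wind'],
    'pressure_level': ['1000', '850'],
    'year': '2023',
    'month': '01',
    'time': ['00:00'],            # no 'day' key for monthly products
    'area': [50, -110, 30, -90],  # north, west, south, east
}
cdsapi.Client().retrieve(dataset, entry, './era5_2023_01.nc')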

@@ -334,6 +359,11 @@ def process_surface_file(self):
ds = self.convert_z(ds, name='orog')
ds = standardize_names(ds, ERA_NAME_MAP)
ds = standardize_values(ds)

if 'monthly' in self.product_type:
ds['time'] = pd.DatetimeIndex(
[f'{self.year}-{str(self.month).zfill(2)}-01']
)
ds.compute().to_netcdf(tmp_file, format='NETCDF4', engine='h5netcdf')
os.replace(tmp_file, self.surface_file)
logger.info(
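For monthly products, process_surface_file (and process_level_file below) collapses the time coordinate to a single first-of-the-month timestamp. The stamping on a toy xarray dataset:

import numpy as np
import pandas as pd
import xarray as xr

year, month = 2023, 1
ds = xr.Dataset(
    {'u': ('time', np.zeros(1))},
    coords={'time': [np.datetime64('2023-01-16')]},  # arbitrary raw stamp
)
ds['time'] = pd.DatetimeIndex([f'{year}-{str(month).zfill(2)}-01'])
print(ds['time'].values)  # ['2023-01-01T00:00:00.000000000']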
@@ -398,6 +428,10 @@ def process_level_file(self):
ds = standardize_names(ds, ERA_NAME_MAP)
ds = standardize_values(ds)
ds = self.add_pressure(ds)
if 'monthly' in self.product_type:
ds['time'] = pd.DatetimeIndex(
[f'{self.year}-{str(self.month).zfill(2)}-01']
)
ds.compute().to_netcdf(tmp_file, format='NETCDF4', engine='h5netcdf')
os.replace(tmp_file, self.level_file)
logger.info(
@@ -407,7 +441,13 @@

def process_and_combine(self):
"""Process variables and combine."""
if not os.path.exists(self.monthly_file) or self.overwrite:

if os.path.exists(self.monthly_file) and not self._can_skip_file(
self.monthly_file
):
os.remove(self.monthly_file)

if not self._can_skip_file(self.monthly_file) or self.overwrite:
files = []
if os.path.exists(self.level_file):
logger.info(f'Processing {self.level_file}.')
@@ -431,7 +471,9 @@ def process_and_combine(self):
def get_monthly_file(self):
"""Download level and surface files, process variables, and combine
processed files. Includes checks for shape and variables."""
if os.path.exists(self.monthly_file) and self.overwrite:
if os.path.exists(self.monthly_file) and (
not self._can_skip_file(self.monthly_file) or self.overwrite
):
os.remove(self.monthly_file)

if not os.path.exists(self.monthly_file):
@@ -497,7 +539,8 @@ def run_month(
and wind components.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
"""
variables = variables if isinstance(variables, list) else [variables]
for var in variables:
@@ -559,16 +602,20 @@ def run_for_var(
Variable to download.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'.
"""
yearly_var_file = yearly_file_pattern.format(year=year, var=variable)
if os.path.exists(yearly_var_file) and not overwrite:
logger.info(
'%s already exists and overwrite=False.', yearly_var_file
if yearly_file_pattern is not None:
yearly_var_file = yearly_file_pattern.format(
year=year, var=variable
)
if os.path.exists(yearly_var_file) and not overwrite:
logger.info(
'%s already exists and overwrite=False.', yearly_var_file
)
msg = 'file_pattern must have {year}, {month}, and {var} format keys'
assert all(
key in monthly_file_pattern
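The assertion above requires all three format keys in the monthly file pattern before any paths are built. With an illustrative pattern (not from the commit):

monthly_file_pattern = 'era5_{var}_{year}_{month}.nc'
assert all(
    key in monthly_file_pattern for key in ('{year}', '{month}', '{var}')
)
print(monthly_file_pattern.format(year=2023, month='01', var='u_wind'))
# era5_u_wind_2023_01.nc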
@@ -652,7 +699,8 @@ def run(
and wind components.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
'ensemble_members', 'monthly_averaged_reanalysis',
'monthly_averaged_ensemble_members'
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'
@@ -676,7 +724,7 @@
res_kwargs=res_kwargs,
)

if (
if yearly_file_pattern is not None and (
cls.all_vars_exist(
year=year,
file_pattern=yearly_file_pattern,
@@ -743,6 +791,27 @@ def make_yearly_var_file(
files, outfile, chunks=chunks, res_kwargs=res_kwargs
)

@classmethod
def _can_skip_file(cls, file):
"""Make sure existing file has successfully downloaded and can be
opened."""
if not os.path.exists(file):
return False

logger.info(
'%s already exists. Making sure it downloaded successfully.', file
)
openable = True
try:
_ = Loader(file)
except Exception as e:
msg = 'Could not open %s. %s Will redownload.'
logger.warning(msg, file, e)
warn(msg % (file, e))
openable = False

return openable
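_can_skip_file upgrades the earlier existence-only checks to an openability check: a cached file is reused only if it exists and loads cleanly, and anything else is treated as a failed download. The gating pattern used by download_file and process_and_combine, sketched generically (illustrative; not the sup3r API):

import os

def ensure_file(path, download_fn, can_skip_fn, overwrite=False):
    """Sketch of the gating built around _can_skip_file."""
    if os.path.exists(path) and not can_skip_fn(path):
        os.remove(path)  # present but unreadable: discard it
    if not can_skip_fn(path) or overwrite:
        download_fn(path)  # (re)download missing/removed/forced files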

@classmethod
def _combine_files(cls, files, outfile, chunks='auto', res_kwargs=None):
if not os.path.exists(outfile):
