diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py
index 80cb0a5f7..84ab7b156 100644
--- a/sup3r/preprocessing/accessor.py
+++ b/sup3r/preprocessing/accessor.py
@@ -16,6 +16,7 @@ from sup3r.preprocessing.utilities import (
     _lowered,
     _mem_check,
+    compute_if_dask,
     dims_array_tuple,
     is_type_of,
     ordered_array,
@@ -90,6 +91,7 @@ def __init__(self, ds: Union[xr.Dataset, Self]):
         """
         self._ds = ds
         self._features = None
+        self._meta = None
         self.time_slice = None

     def parse_keys(self, keys):
@@ -529,10 +531,20 @@ def grid_shape(self):

     @property
     def meta(self):
-        """Return dataframe of flattened lat / lon values."""
-        return pd.DataFrame(
-            columns=Dimension.coords_2d(), data=self.lat_lon.reshape((-1, 2))
-        )
+        """Return dataframe of flattened lat / lon values. Can also be set to
+        include additional data like elevation, country, state, etc."""
+        if self._meta is None:
+            self._meta = pd.DataFrame(
+                columns=Dimension.coords_2d(),
+                data=self.lat_lon.reshape((-1, 2)),
+            )
+        return self._meta
+
+    @meta.setter
+    def meta(self, meta):
+        """Set meta data. Used to update meta with additional info from
+        datasets like WTK and NSRDB."""
+        self._meta = meta

     def unflatten(self, grid_shape):
         """Convert flattened dataset into rasterized dataset with the given
@@ -550,6 +562,26 @@
             warn(msg)
         return self

+    def _qa(self, feature):
+        """Get qa info for given feature."""
+        info = {}
+        logger.info('Running qa on feature: %s', feature)
+        nan_count = np.isnan(self[feature].data).sum()
+        nan_perc = 100 * nan_count / self[feature].size
+        info['nan_perc'] = compute_if_dask(nan_perc)
+        info['std'] = compute_if_dask(self[feature].std().data)
+        info['mean'] = compute_if_dask(self[feature].mean().data)
+        info['min'] = compute_if_dask(self[feature].min().data)
+        info['max'] = compute_if_dask(self[feature].max().data)
+        return info
+
+    def qa(self):
+        """Check NaNs and stats for all features."""
+        qa_info = {}
+        for f in self.features:
+            qa_info[f] = self._qa(f)
+        return qa_info
+
     def __mul__(self, other):
         """Multiply ``Sup3rX`` object by other.
        Used to compute weighted means and stdevs."""
diff --git a/sup3r/preprocessing/data_handlers/exo.py b/sup3r/preprocessing/data_handlers/exo.py
index 1a5871e24..03aea6ee1 100644
--- a/sup3r/preprocessing/data_handlers/exo.py
+++ b/sup3r/preprocessing/data_handlers/exo.py
@@ -430,12 +430,12 @@ def _get_single_step_enhance(self, step):
                 s_enhance = 1
                 t_enhance = 1
             else:
-                s_enhance = np.prod(s_enhancements[:model_step])
-                t_enhance = np.prod(t_enhancements[:model_step])
+                s_enhance = int(np.prod(s_enhancements[:model_step]))
+                t_enhance = int(np.prod(t_enhancements[:model_step]))
         else:
-            s_enhance = np.prod(s_enhancements[: model_step + 1])
-            t_enhance = np.prod(t_enhancements[: model_step + 1])
+            s_enhance = int(np.prod(s_enhancements[: model_step + 1]))
+            t_enhance = int(np.prod(t_enhancements[: model_step + 1]))
         step.update({'s_enhance': s_enhance, 't_enhance': t_enhance})
         return step

diff --git a/sup3r/preprocessing/loaders/base.py b/sup3r/preprocessing/loaders/base.py
index 4a2a5e635..520bae861 100644
--- a/sup3r/preprocessing/loaders/base.py
+++ b/sup3r/preprocessing/loaders/base.py
@@ -70,12 +70,16 @@ def __init__(
         self.chunks = chunks
         BASE_LOADER = BaseLoader or self.BASE_LOADER
         self.res = BASE_LOADER(self.file_paths, **self.res_kwargs)
-        data = self._load().astype(np.float32)
-        data = self._add_attrs(lower_names(data))
-        data = standardize_names(standardize_values(data), FEATURE_NAMES)
+        data = lower_names(self._load())
+        data = self._add_attrs(data)
+        data = standardize_values(data)
+        data = standardize_names(data, FEATURE_NAMES).astype(np.float32)
         features = list(data.dims) if features == [] else features
         self.data = data[features] if features != 'all' else data
+
+        if 'meta' in self.res:
+            self.data.meta = self.res.meta

     def _parse_chunks(self, dims, feature=None):
         """Get chunks for given dimensions from ``self.chunks``."""
         chunks = copy.deepcopy(self.chunks)
diff --git a/sup3r/preprocessing/loaders/h5.py b/sup3r/preprocessing/loaders/h5.py
index 227f4cafe..8f61a3eb3 100644
--- a/sup3r/preprocessing/loaders/h5.py
+++ b/sup3r/preprocessing/loaders/h5.py
@@ -148,7 +148,13 @@ def _get_data_vars(self, dims):
             data_vars['elevation'] = (dims, elev)

         feats = set(self.res.h5.datasets)
-        exclude = {'meta', 'time_index', 'coordinates'}
+        exclude = {
+            'meta',
+            'time_index',
+            'coordinates',
+            'latitude',
+            'longitude',
+        }
         for f in feats - exclude:
             data_vars[f] = self._get_dset_tuple(
                 dset=f, dims=dims, chunks=chunks
diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py
index c937032ce..777f6b5dc 100644
--- a/sup3r/preprocessing/rasterizers/dual.py
+++ b/sup3r/preprocessing/rasterizers/dual.py
@@ -212,20 +212,12 @@ def update_lr_data(self):
     def check_regridded_lr_data(self):
         """Check for NaNs after regridding and do NN fill if needed."""
         fill_feats = []
+        logger.info('Checking for NaNs after regridding')
+        qa_info = self.lr_data.qa()
         for f in self.lr_data.features:
-            logger.info(
-                f'Checking for NaNs after regridding, for feature: {f}'
-            )
-            nan_perc = (
-                100
-                * np.isnan(self.lr_data[f].data).sum()
-                / self.lr_data[f].size
-            )
+            nan_perc = qa_info[f]['nan_perc']
             if nan_perc > 0:
-                msg = (
-                    f'{f} data has {np.asarray(nan_perc):.3f}% NaN '
-                    'values!'
-                )
+                msg = f'{f} data has {nan_perc:.3f}% NaN values!'
                if nan_perc < 10:
                    fill_feats.append(f)
                    logger.warning(msg)
diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py
index 406385d41..d4b62e4ad 100644
--- a/sup3r/preprocessing/rasterizers/exo.py
+++ b/sup3r/preprocessing/rasterizers/exo.py
@@ -20,6 +20,7 @@
 from sup3r.postprocessing.writers.base import OutputHandler
 from sup3r.preprocessing.accessor import Sup3rX
 from sup3r.preprocessing.base import Sup3rMeta
+from sup3r.preprocessing.cachers import Cacher
 from sup3r.preprocessing.derivers.utilities import SolarZenith
 from sup3r.preprocessing.loaders import Loader
 from sup3r.preprocessing.names import Dimension
@@ -158,7 +159,6 @@ def coords(self):
             coord: (Dimension.dims_2d(), self.hr_lat_lon[..., i])
             for i, coord in enumerate(Dimension.coords_2d())
         }
-        coords[Dimension.TIME] = self.hr_time_index
         return coords

     @property
@@ -265,8 +265,13 @@ def data(self):

         if not os.path.exists(cache_fp):
             tmp_fp = cache_fp + f'{generate_random_string(10)}.tmp'
-            data.load().to_netcdf(tmp_fp, format='NETCDF4', engine='h5netcdf')
+            Cacher.write_netcdf(tmp_fp, data)
             shutil.move(tmp_fp, cache_fp)
+
+        if Dimension.TIME not in data.dims:
+            data = data.expand_dims(**{Dimension.TIME: self.hr_shape[-1]})
+            data = data.reindex({Dimension.TIME: self.hr_time_index})
+            data = Sup3rX(data.ffill(Dimension.TIME))
         return data

     def get_data(self):
@@ -311,17 +316,8 @@
             self.source_file,
             self.feature,
         )
-        arr = (
-            da.from_array(hr_data)
-            if hr_data.shape == self.hr_shape
-            else da.repeat(
-                da.from_array(hr_data[..., None]),
-                len(self.hr_time_index),
-                axis=-1,
-            )
-        )
         data_vars = {
-            self.feature: (Dimension.dims_3d(), arr.astype(np.float32))
+            self.feature: (Dimension.dims_2d(), hr_data.astype(np.float32))
         }
         ds = xr.Dataset(coords=self.coords, data_vars=data_vars)
         return Sup3rX(ds)
diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py
index b1c7e2e27..1526c543f 100644
--- a/sup3r/utilities/era_downloader.py
+++ b/sup3r/utilities/era_downloader.py
@@ -9,12 +9,14 @@

 import logging
 import os
+import pprint
 from calendar import monthrange
 from warnings import warn

 import dask
 import dask.array as da
 import numpy as np
+from rex import init_logger

 from sup3r.preprocessing import Loader
 from sup3r.preprocessing.loaders.utilities import (
@@ -43,7 +45,7 @@ def __init__(
         month,
         area,
         levels,
-        combined_out_pattern,
+        monthly_file_pattern,
         overwrite=False,
         variables=None,
         product_type='reanalysis',
@@ -61,7 +63,7 @@
             [max_lat, min_lon, min_lat, max_lon]
         levels : list
             List of pressure levels to download.
-        combined_out_pattern : str
+        monthly_file_pattern : str
             Pattern for combined monthly output file. Must include year and
             month format keys. e.g.
            'era5_{year}_{month}_combined.nc'
        overwrite : bool
@@ -78,8 +80,7 @@
         self.area = area
         self.levels = levels
         self.overwrite = overwrite
-        self.combined_out_pattern = combined_out_pattern
-        self._combined_file = None
+        self.monthly_file_pattern = monthly_file_pattern
         self._variables = variables
         self.sfc_file_variables = []
         self.level_file_variables = []
@@ -112,43 +113,28 @@ def days(self):
         ]

     @property
-    def combined_file(self):
-        """Get name of file from combined surface and level files"""
-        if self._combined_file is None:
-            if '{var}' in self.combined_out_pattern:
-                self._combined_file = self.combined_out_pattern.format(
-                    year=self.year,
-                    month=str(self.month).zfill(2),
-                    var='_'.join(self.variables),
-                )
-            else:
-                self._combined_file = self.combined_out_pattern.format(
-                    year=self.year, month=str(self.month).zfill(2)
-                )
-        os.makedirs(os.path.dirname(self._combined_file), exist_ok=True)
-        return self._combined_file
+    def monthly_file(self):
+        """Name of file with all surface and level variables for a given month
+        and year."""
+        monthly_file = self.monthly_file_pattern.replace(
+            '{var}', '_'.join(self.variables)
+        ).format(year=self.year, month=str(self.month).zfill(2))
+        os.makedirs(os.path.dirname(monthly_file), exist_ok=True)
+        return monthly_file

     @property
     def surface_file(self):
         """Get name of file with variables from single level download"""
-        basedir = os.path.dirname(self.combined_file)
-        basename = ''
-        if '{var}' in self.combined_out_pattern:
-            basename += '_'.join(self.variables) + '_'
-        basename += f'sfc_{self.year}_'
-        basename += f'{str(self.month).zfill(2)}.nc'
-        return os.path.join(basedir, basename)
+        basedir = os.path.dirname(self.monthly_file)
+        basename = os.path.basename(self.monthly_file)
+        return os.path.join(basedir, f'sfc_{basename}')

     @property
     def level_file(self):
         """Get name of file with variables from pressure level download"""
-        basedir = os.path.dirname(self.combined_file)
-        basename = ''
-        if '{var}' in self.combined_out_pattern:
-            basename += '_'.join(self.variables) + '_'
-        basename += f'levels_{self.year}_'
-        basename += f'{str(self.month).zfill(2)}.nc'
-        return os.path.join(basedir, basename)
+        basedir = os.path.dirname(self.monthly_file)
+        basename = os.path.basename(self.monthly_file)
+        return os.path.join(basedir, f'level_{basename}')

     @classmethod
     def get_tmp_file(cls, file):
@@ -432,7 +418,7 @@ def _write_dsets(cls, files, out_file, kwargs=None):

     def process_and_combine(self):
         """Process variables and combine."""
-        if not os.path.exists(self.combined_file) or self.overwrite:
+        if not os.path.exists(self.monthly_file) or self.overwrite:
             files = []
             if os.path.exists(self.level_file):
                 logger.info(f'Processing {self.level_file}.')
@@ -443,31 +429,23 @@
                 self.process_surface_file()
                 files.append(self.surface_file)

-            logger.info(f'Combining {files} to {self.combined_file}.')
             kwargs = {'compat': 'override'}
-            try:
-                self._write_dsets(
-                    files, out_file=self.combined_file, kwargs=kwargs
-                )
-            except Exception as e:
-                msg = f'Error combining {files}.'
-                logger.error(msg)
-                raise RuntimeError(msg) from e
+            self._combine_files(files, self.monthly_file, kwargs)

             if os.path.exists(self.level_file):
                 os.remove(self.level_file)
             if os.path.exists(self.surface_file):
                 os.remove(self.surface_file)
         else:
-            logger.info(f'{self.combined_file} already exists.')
+            logger.info(f'{self.monthly_file} already exists.')

     def get_monthly_file(self):
         """Download level and surface files, process variables, and combine
         processed files.
        Includes checks for shape and variables."""
-        if os.path.exists(self.combined_file) and self.overwrite:
-            os.remove(self.combined_file)
+        if os.path.exists(self.monthly_file) and self.overwrite:
+            os.remove(self.monthly_file)

-        if not os.path.exists(self.combined_file):
+        if not os.path.exists(self.monthly_file):
             self.download_process_combine()

     @classmethod
@@ -535,7 +513,7 @@ def run_month(
         month,
         area,
         levels,
-        combined_out_pattern,
+        monthly_file_pattern,
         overwrite=False,
         variables=None,
         product_type='reanalysis',
@@ -553,7 +531,7 @@
             [max_lat, min_lon, min_lat, max_lon]
         levels : list
             List of pressure levels to download.
-        combined_out_pattern : str
+        monthly_file_pattern : str
             Pattern for combined monthly output file. Must include year and
             month format keys. e.g. 'era5_{year}_{month}_combined.nc'
         overwrite : bool
@@ -572,7 +550,7 @@
                 month=month,
                 area=area,
                 levels=levels,
-                combined_out_pattern=combined_out_pattern,
+                monthly_file_pattern=monthly_file_pattern,
                 overwrite=overwrite,
                 variables=[var],
                 product_type=product_type,
@@ -585,8 +563,8 @@ def run_year(
         year,
         area,
         levels,
-        combined_out_pattern,
-        combined_yearly_file=None,
+        monthly_file_pattern,
+        yearly_file=None,
         overwrite=False,
         max_workers=None,
         variables=None,
@@ -603,10 +581,10 @@
             [max_lat, min_lon, min_lat, max_lon]
         levels : list
             List of pressure levels to download.
-        combined_out_pattern : str
+        monthly_file_pattern : str
             Pattern for combined monthly output file. Must include year and
             month format keys. e.g. 'era5_{year}_{month}_combined.nc'
-        combined_yearly_file : str
+        yearly_file : str
             Name of yearly file made from monthly combined files.
         overwrite : bool
             Whether to overwrite existing files.
@@ -620,12 +598,19 @@
            Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
            'ensemble_members'
        """
+        if (
+            yearly_file is not None
+            and os.path.exists(yearly_file)
+            and not overwrite
+        ):
+            logger.info('%s already exists and overwrite=False.', yearly_file)
+            return
         msg = (
-            'combined_out_pattern must have {year}, {month}, and {var} '
+            'monthly_file_pattern must have {year}, {month}, and {var} '
             'format keys'
         )
         assert all(
-            key in combined_out_pattern
+            key in monthly_file_pattern
             for key in ('{year}', '{month}', '{var}')
         ), msg

@@ -637,7 +621,7 @@
                     month=month,
                     area=area,
                     levels=levels,
-                    combined_out_pattern=combined_out_pattern,
+                    monthly_file_pattern=monthly_file_pattern,
                     overwrite=overwrite,
                     variables=[var],
                     product_type=product_type,
@@ -650,12 +634,10 @@
         dask.compute(*tasks, scheduler='threads', num_workers=max_workers)

         for month in range(1, 13):
-            cls.make_monthly_file(year, month, combined_out_pattern, variables)
+            cls.make_monthly_file(year, month, monthly_file_pattern, variables)

-        if combined_yearly_file is not None:
-            cls.make_yearly_file(
-                year, combined_out_pattern, combined_yearly_file
-            )
+        if yearly_file is not None:
+            cls.make_yearly_file(year, monthly_file_pattern, yearly_file)

     @classmethod
     def make_monthly_file(cls, year, month, file_pattern, variables):
@@ -687,11 +669,15 @@
         outfile = file_pattern.replace('_{var}', '').format(
             year=year, month=str(month).zfill(2)
         )
+        cls._combine_files(files, outfile)

+    @classmethod
+    def _combine_files(cls, files, outfile, kwargs=None):
+        """Combine given files into a single output file, if needed."""
         if not os.path.exists(outfile):
             logger.info(f'Combining {files} into {outfile}.')
             try:
-                cls._write_dsets(files, out_file=outfile)
+                cls._write_dsets(files, out_file=outfile, kwargs=kwargs)
             except Exception as e:
                msg = f'Error combining {files}.'
                logger.error(msg)
@@ -725,14 +710,15 @@
             )
             for month in range(1, 13)
         ]
+        kwargs = {'combine': 'nested', 'concat_dim': 'time'}
+        cls._combine_files(files, yearly_file, kwargs)

-        if not os.path.exists(yearly_file):
-            kwargs = {'combine': 'nested', 'concat_dim': 'time'}
-            try:
-                cls._write_dsets(files, out_file=yearly_file, kwargs=kwargs)
-            except Exception as e:
-                msg = f'Error combining {files}'
-                logger.error(msg)
-                raise RuntimeError(msg) from e
-        else:
-            logger.info(f'{yearly_file} already exists.')
+    @classmethod
+    def run_qa(cls, file, res_kwargs=None, log_file=None):
+        """Check for NaN values and log min / max / mean / std for all
+        variables."""
+
+        logger = init_logger(__name__, log_level='DEBUG', log_file=log_file)
+        with Loader(file, res_kwargs=res_kwargs) as res:
+            logger.info('Running qa on file: %s', file)
+            logger.info('\n%s', pprint.pformat(res.qa(), indent=2))
diff --git a/tests/loaders/test_file_loading.py b/tests/loaders/test_file_loading.py
index 15c54b674..72d16b7d9 100644
--- a/tests/loaders/test_file_loading.py
+++ b/tests/loaders/test_file_loading.py
@@ -202,7 +202,7 @@ def test_load_nc():
 def test_load_h5():
     """Test simple h5 file loading. Also checks renaming elevation ->
     topography. Also makes sure that general loader matches type specific
-    loader"""
+    loader. Also checks that metadata is carried into the loader object."""

     chunks = {'space': 200, 'time': 200}
     loader = LoaderH5(pytest.FP_WTK, chunks=chunks)
@@ -224,6 +224,7 @@
     assert np.array_equal(loader.as_array(), gen_loader.as_array())
     loader_attrs = {f: loader[f].attrs for f in feats}
     resource_attrs = Resource(pytest.FP_WTK).attrs
+    assert np.array_equal(loader.meta, loader.res.meta)
     matching_feats = set(Resource(pytest.FP_WTK).datasets).intersection(feats)
     assert all(loader_attrs[f] == resource_attrs[f] for f in matching_feats)
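
Note: a minimal usage sketch for the new `Sup3rX.qa()` method added in accessor.py, mirroring how `run_qa` in era_downloader.py drives it. The file path is a placeholder, not a value from this diff; the per-feature dict keys (`nan_perc`, `std`, `mean`, `min`, `max`) come from `_qa` above.

    # Hypothetical usage of the new qa() accessor method; './era5_2015_01.nc'
    # is a placeholder path. qa() returns {feature: {'nan_perc': ..., 'std':
    # ..., 'mean': ..., 'min': ..., 'max': ...}} with dask arrays computed.
    from sup3r.preprocessing import Loader

    with Loader('./era5_2015_01.nc') as res:
        stats = res.qa()
        for feature, info in stats.items():
            if info['nan_perc'] > 0:
                print(f"{feature}: {info['nan_perc']:.3f}% NaN values")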
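
Note on the `int(np.prod(...))` casts in data_handlers/exo.py: presumably the motivation is that `np.prod` returns a NumPy scalar rather than a plain `int` (and a `float64` value of `1.0` for an empty slice), which can trip strict type checks or `json.dumps` downstream. A quick illustration of the assumed behavior:

    import numpy as np

    s_enhancements = [2, 3]
    type(np.prod(s_enhancements[:0]))   # numpy.float64 -- empty product is 1.0
    type(np.prod(s_enhancements))       # numpy.int64 -- not JSON serializable
    type(int(np.prod(s_enhancements)))  # int -- safe for json.dumps, range(), etc.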
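
For reference, a hedged sketch of driving the renamed downloader API after this change. The class name `EraDownloader` is assumed from the module, and the area, levels, variables, and paths are illustrative only. Per the assert in `run_year`, the monthly pattern must contain the {year}, {month}, and {var} format keys, and the yearly step is now skipped when the file exists and overwrite=False.

    from sup3r.utilities.era_downloader import EraDownloader  # assumed class name

    EraDownloader.run_year(
        year=2015,
        area=[50, -110, 30, -90],  # [max_lat, min_lon, min_lat, max_lon]
        levels=[1000, 900],
        monthly_file_pattern='./era5_{year}_{month}_{var}.nc',
        yearly_file='./era5_2015.nc',  # skipped if it exists and overwrite=False
        variables=['u_10m', 'v_10m'],  # placeholder variable names
        max_workers=2,
    )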