Commit 9957800
Not caching exo data with time if time independent. This can be a huge cache if using hr topo in a temporal model.
bnb32 committed Aug 18, 2024
1 parent 6bbdf98 commit 9957800
Showing 8 changed files with 127 additions and 110 deletions.
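Context for the change, not part of the diff: a time-independent field like high-resolution topography was previously tiled along the time axis before caching. A rough back-of-the-envelope sketch (shapes illustrative) of why that cache gets huge:

```python
import numpy as np

# Illustrative shapes: a high-res topography grid and one year of hourly steps.
spatial_shape = (2000, 2000)   # (south_north, west_east)
n_times = 8760

topo_2d = np.zeros(spatial_shape, dtype=np.float32)

print(f'2D cache: {topo_2d.nbytes / 1e6:.0f} MB')                  # ~16 MB
print(f'Tiled 3D cache: {topo_2d.nbytes * n_times / 1e9:.0f} GB')  # ~140 GB

# Broadcasting at load time gives the same 3D view with no extra storage.
topo_3d = np.broadcast_to(topo_2d[..., None], (*spatial_shape, n_times))
```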
40 changes: 36 additions & 4 deletions sup3r/preprocessing/accessor.py
@@ -16,6 +16,7 @@
from sup3r.preprocessing.utilities import (
_lowered,
_mem_check,
compute_if_dask,
dims_array_tuple,
is_type_of,
ordered_array,
@@ -90,6 +91,7 @@ def __init__(self, ds: Union[xr.Dataset, Self]):
"""
self._ds = ds
self._features = None
self._meta = None
self.time_slice = None

def parse_keys(self, keys):
@@ -529,10 +531,20 @@ def grid_shape(self):

@property
def meta(self):
"""Return dataframe of flattened lat / lon values."""
return pd.DataFrame(
columns=Dimension.coords_2d(), data=self.lat_lon.reshape((-1, 2))
)
"""Return dataframe of flattened lat / lon values. Can also be set to
include additional data like elevation, country, state, etc."""
if self._meta is None:
self._meta = pd.DataFrame(
columns=Dimension.coords_2d(),
data=self.lat_lon.reshape((-1, 2)),
)
return self._meta

@meta.setter
def meta(self, meta):
"""Set meta data. Used to update meta with additional info from
datasets like WTK and NSRDB."""
self._meta = meta
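A short usage sketch for the cached, settable meta (the dataset variable and extra columns here are hypothetical; the column names assume Dimension.coords_2d() resolves to latitude/longitude):

```python
# ds_x is a hypothetical Sup3rX-wrapped dataset.
meta = ds_x.meta   # built from lat/lon on first access, then cached

# A loader can replace it with a richer table, e.g. from a WTK/NSRDB 'meta':
ds_x.meta = meta.assign(elevation=100.0, state='CO')   # illustrative columns
```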

def unflatten(self, grid_shape):
"""Convert flattened dataset into rasterized dataset with the given
@@ -550,6 +562,26 @@ def unflatten(self, grid_shape):
warn(msg)
return self

def _qa(self, feature):
"""Get qa info for given feature."""
info = {}
logger.info('Running qa on feature: %s', feature)
nan_count = np.isnan(self[feature].data).sum()
nan_perc = 100 * nan_count / self[feature].size
info['nan_perc'] = compute_if_dask(nan_perc)
info['std'] = compute_if_dask(self[feature].std().data)
info['mean'] = compute_if_dask(self[feature].mean().data)
info['min'] = compute_if_dask(self[feature].min().data)
info['max'] = compute_if_dask(self[feature].max().data)
return info

def qa(self):
"""Check NaNs and stats for all features."""
qa_info = {}
for f in self.features:
qa_info[f] = self._qa(f)
return qa_info
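A sketch of consuming the new qa() report (the dataset variable is hypothetical; the dict keys follow the code above):

```python
# ds_x is a hypothetical Sup3rX accessor with a 'topography' feature.
report = ds_x.qa()
stats = report['topography']

# Values are eagerly computed (compute_if_dask), so they are plain numbers.
if stats['nan_perc'] > 0:
    print(f"topography: {stats['nan_perc']:.3f}% NaNs, "
          f"mean={stats['mean']:.2f}, std={stats['std']:.2f}")
```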

def __mul__(self, other):
"""Multiply ``Sup3rX`` object by other. Used to compute weighted means
and stdevs."""
8 changes: 4 additions & 4 deletions sup3r/preprocessing/data_handlers/exo.py
@@ -430,12 +430,12 @@ def _get_single_step_enhance(self, step):
s_enhance = 1
t_enhance = 1
else:
s_enhance = np.prod(s_enhancements[:model_step])
t_enhance = np.prod(t_enhancements[:model_step])
s_enhance = int(np.prod(s_enhancements[:model_step]))
t_enhance = int(np.prod(t_enhancements[:model_step]))

else:
s_enhance = np.prod(s_enhancements[: model_step + 1])
t_enhance = np.prod(t_enhancements[: model_step + 1])
s_enhance = int(np.prod(s_enhancements[: model_step + 1]))
t_enhance = int(np.prod(t_enhancements[: model_step + 1]))
step.update({'s_enhance': s_enhance, 't_enhance': t_enhance})
return step

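One plausible motivation for the new int(...) casts, shown in isolation: np.prod returns a NumPy scalar (np.int64 on most platforms), which json.dumps and similarly strict consumers reject:

```python
import json
import numpy as np

s_enhancements = [2, 3]               # illustrative per-step factors
s_enhance = np.prod(s_enhancements)   # a NumPy scalar, not a plain int

try:
    json.dumps({'s_enhance': s_enhance})
except TypeError as err:
    print(err)                        # Object of type int64 is not JSON serializable

json.dumps({'s_enhance': int(s_enhance)})  # works after the cast
```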
10 changes: 7 additions & 3 deletions sup3r/preprocessing/loaders/base.py
@@ -70,12 +70,16 @@ def __init__(
self.chunks = chunks
BASE_LOADER = BaseLoader or self.BASE_LOADER
self.res = BASE_LOADER(self.file_paths, **self.res_kwargs)
data = self._load().astype(np.float32)
data = self._add_attrs(lower_names(data))
data = standardize_names(standardize_values(data), FEATURE_NAMES)
data = lower_names(self._load())
data = self._add_attrs(data)
data = standardize_values(data)
data = standardize_names(data, FEATURE_NAMES).astype(np.float32)
features = list(data.dims) if features == [] else features
self.data = data[features] if features != 'all' else data

if 'meta' in self.res:
self.data.meta = self.res.meta

def _parse_chunks(self, dims, feature=None):
"""Get chunks for given dimensions from ``self.chunks``."""
chunks = copy.deepcopy(self.chunks)
8 changes: 7 additions & 1 deletion sup3r/preprocessing/loaders/h5.py
@@ -148,7 +148,13 @@ def _get_data_vars(self, dims):
data_vars['elevation'] = (dims, elev)

feats = set(self.res.h5.datasets)
exclude = {'meta', 'time_index', 'coordinates'}
exclude = {
'meta',
'time_index',
'coordinates',
'latitude',
'longitude',
}
for f in feats - exclude:
data_vars[f] = self._get_dset_tuple(
dset=f, dims=dims, chunks=chunks
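The expanded exclude set keeps coordinate-like h5 datasets from being loaded as features; a minimal illustration with h5py (file name hypothetical):

```python
import h5py

EXCLUDE = {'meta', 'time_index', 'coordinates', 'latitude', 'longitude'}

with h5py.File('wtk_sample.h5', 'r') as res:   # hypothetical file
    # Only genuine feature datasets (e.g. windspeed_100m) survive the filter.
    features = sorted(set(res) - EXCLUDE)
```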
16 changes: 4 additions & 12 deletions sup3r/preprocessing/rasterizers/dual.py
@@ -212,20 +212,12 @@ def update_lr_data(self):
def check_regridded_lr_data(self):
"""Check for NaNs after regridding and do NN fill if needed."""
fill_feats = []
logger.info('Checking for NaNs after regridding')
qa_info = self.lr_data.qa()
for f in self.lr_data.features:
logger.info(
f'Checking for NaNs after regridding, for feature: {f}'
)
nan_perc = (
100
* np.isnan(self.lr_data[f].data).sum()
/ self.lr_data[f].size
)
nan_perc = qa_info[f]['nan_perc']
if nan_perc > 0:
msg = (
f'{f} data has {np.asarray(nan_perc):.3f}% NaN '
'values!'
)
msg = f'{f} data has {nan_perc:.3f}% NaN values!'
if nan_perc < 10:
fill_feats.append(f)
logger.warning(msg)
20 changes: 8 additions & 12 deletions sup3r/preprocessing/rasterizers/exo.py
@@ -20,6 +20,7 @@
from sup3r.postprocessing.writers.base import OutputHandler
from sup3r.preprocessing.accessor import Sup3rX
from sup3r.preprocessing.base import Sup3rMeta
from sup3r.preprocessing.cachers import Cacher
from sup3r.preprocessing.derivers.utilities import SolarZenith
from sup3r.preprocessing.loaders import Loader
from sup3r.preprocessing.names import Dimension
@@ -158,7 +159,6 @@ def coords(self):
coord: (Dimension.dims_2d(), self.hr_lat_lon[..., i])
for i, coord in enumerate(Dimension.coords_2d())
}
coords[Dimension.TIME] = self.hr_time_index
return coords

@property
@@ -265,8 +265,13 @@ def data(self):

if not os.path.exists(cache_fp):
tmp_fp = cache_fp + f'{generate_random_string(10)}.tmp'
data.load().to_netcdf(tmp_fp, format='NETCDF4', engine='h5netcdf')
Cacher.write_netcdf(tmp_fp, data)
shutil.move(tmp_fp, cache_fp)

if Dimension.TIME not in data.dims:
data = data.expand_dims(**{Dimension.TIME: self.hr_shape[-1]})
data = data.reindex({Dimension.TIME: self.hr_time_index})
data = Sup3rX(data.ffill(Dimension.TIME))
return data
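This branch is the heart of the commit: the exo field is now cached without a time dimension and only broadcast to the high-res time index after loading. A standalone xarray sketch of the same pattern (not the exact committed code; names illustrative):

```python
import numpy as np
import pandas as pd
import xarray as xr

# A time-independent exo field, as it now comes out of the cache.
ds = xr.Dataset(
    {'topography': (('south_north', 'west_east'), np.zeros((4, 4), np.float32))}
)
hr_time_index = pd.date_range('2024-01-01', periods=24, freq='h')

if 'time' not in ds.dims:
    ds = ds.expand_dims(time=len(hr_time_index))  # broadcast along a new time axis
    ds = ds.assign_coords(time=hr_time_index)     # label it with the model's index

assert ds['topography'].shape == (24, 4, 4)
```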

def get_data(self):
@@ -311,17 +316,8 @@ def get_data(self):
self.source_file,
self.feature,
)
arr = (
da.from_array(hr_data)
if hr_data.shape == self.hr_shape
else da.repeat(
da.from_array(hr_data[..., None]),
len(self.hr_time_index),
axis=-1,
)
)
data_vars = {
self.feature: (Dimension.dims_3d(), arr.astype(np.float32))
self.feature: (Dimension.dims_2d(), hr_data.astype(np.float32))
}
ds = xr.Dataset(coords=self.coords, data_vars=data_vars)
return Sup3rX(ds)
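Correspondingly, get_data no longer tiles the source field with da.repeat; it returns a purely spatial dataset and leaves the time handling to the data property above. A small check of the new contract (illustrative):

```python
import numpy as np
import xarray as xr

# Hypothetical check: exo data leaves get_data 2D.
hr_data = np.zeros((8, 8), dtype=np.float32)   # e.g. hr topography
ds = xr.Dataset({'topography': (('south_north', 'west_east'), hr_data)})
assert 'time' not in ds.dims                   # time is added later, in `data`
```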
(Diffs for the remaining 2 changed files are not shown.)
