Skip to content

Commit

Permalink
Merge pull request #206 from NREL/gb/no_regrid
Browse files Browse the repository at this point in the history
Gb/no regrid
  • Loading branch information
grantbuster authored Apr 8, 2024
2 parents a446b80 + dd85e63 commit c754191
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 22 deletions.
8 changes: 0 additions & 8 deletions sup3r/preprocessing/data_handling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,14 +709,6 @@ def preflight(self):

start = self.temporal_slice.start
stop = self.temporal_slice.stop
n_steps = self.n_tsteps
msg = (f'Temporal slice step ({self.temporal_slice.step}) does not '
f'evenly divide the number of time steps ({n_steps})')
check = self.temporal_slice.step is None
check = check or n_steps % self.temporal_slice.step == 0
if not check:
logger.warning(msg)
warnings.warn(msg)

msg = (f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger '
'than the number of time steps in the raw data '
Expand Down
87 changes: 73 additions & 14 deletions sup3r/preprocessing/data_handling/dual_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self,
regrid_workers=1,
load_cached=True,
shuffle_time=False,
regrid_lr=True,
s_enhance=1,
t_enhance=1,
val_split=0.0):
Expand All @@ -61,6 +62,10 @@ def __init__(self,
is called.
shuffle_time : bool
Whether to shuffle time indices prior to training/validation split
regrid_lr : bool
Flag to regrid the low-res handler data to the high-res handler
grid. This will take care of any minor inconsistencies in different
projections. Disable this if the grids are known to be the same.
s_enhance : int
Spatial enhancement factor
t_enhance : int
Expand Down Expand Up @@ -95,6 +100,7 @@ def __init__(self,
self._means = None
self._stds = None
self._is_normalized = False
self._regrid_lr = regrid_lr
self._norm_workers = self.lr_dh.norm_workers

if self.try_load and self.load_cached:
Expand Down Expand Up @@ -248,14 +254,66 @@ def normalize(self, means=None, stds=None, max_workers=None):
if stds is None:
stds = self.stds

self._normalize_lr(means, stds)
self._normalize_hr(means, stds)

def _normalize_lr(self, means, stds):
"""Normalize the low-resolution data features included in the
low-res data handler

Note that self.lr_data is usually a unique regridded array but if
regridding was not performed then it is just a sliced *view* of
self.lr_dh.data and the super().normalize() operation will already
have been applied to that data.

Parameters
----------
means : dict | none
Dictionary of means for all features with keys: feature names and
values: mean values. If this is None, the self.means attribute will
be used. If this is not None, this DataHandler object means
attribute will be updated.
stds : dict | none
dictionary of standard deviation values for all features with keys:
feature names and values: standard deviations. If this is None, the
self.stds attribute will be used. If this is not None, this
DataHandler object stds attribute will be updated.
"""

logger.info('Normalizing low resolution data features='
f'{self.lr_dh.features}')
super().normalize(means=means, stds=stds,
features=self.lr_dh.features,
max_workers=self.lr_dh.norm_workers)
# NOTE(review): this unconditional lr_dh.normalize() call looks like the
# pre-change line of the diff rendered without its removal marker; the
# conditional call below appears to be its replacement -- confirm against
# the repository before relying on this listing.
self.lr_dh.normalize(means=means, stds=stds,
features=self.lr_dh.features,
max_workers=self.lr_dh.norm_workers)

# Compare the array backing self.lr_data with the handler's own array:
# if they differ the handler data is normalized separately, otherwise the
# handler is only flagged as normalized (presumably because the shared
# array was already normalized by super().normalize() -- see the note in
# the docstring).
if id(self.lr_dh.data) != id(self.lr_data.base):
self.lr_dh.normalize(means=means, stds=stds,
features=self.lr_dh.features,
max_workers=self.lr_dh.norm_workers)
else:
self.lr_dh._is_normalized = True

def _normalize_hr(self, means, stds):
"""Normalize the high-resolution data features included in the
high-res data handler
Note that self.hr_data is usually just a sliced *view* of
self.hr_dh.data but if the *view* is broken then it will have to be
normalized too
Parameters
----------
means : dict | none
Dictionary of means for all features with keys: feature names and
values: mean values. If this is None, the self.means attribute will
be used. If this is not None, this DataHandler object means
attribute will be updated.
stds : dict | none
dictionary of standard deviation values for all features with keys:
feature names and values: standard deviations. If this is None, the
self.stds attribute will be used. If this is not None, this
DataHandler object stds attribute will be updated.
"""

logger.info('Normalizing high resolution data features='
f'{self.hr_dh.features}')
Expand All @@ -264,8 +322,6 @@ def normalize(self, means=None, stds=None, max_workers=None):
max_workers=self.hr_dh.norm_workers)

if id(self.hr_data.base) != id(self.hr_dh.data):
# self.hr_data is usually just a sliced view of self.hr_dh.data
# but if the view is broken then it will have to be normalized too
mean_arr = np.array([means[fn] for fn in self.hr_dh.features])
std_arr = np.array([stds[fn] for fn in self.hr_dh.features])
self.hr_data = (self.hr_data - mean_arr) / std_arr
Expand Down Expand Up @@ -579,16 +635,19 @@ def get_lr_regridded_data(self):
"""Regrid low_res data for all requested noncached features. Load
cached features if available and overwrite=False"""

logger.info('Regridding low resolution feature data.')
regridder = self.get_regridder()
if self._regrid_lr:
logger.info('Regridding low resolution feature data.')
regridder = self.get_regridder()

fnames = set(self.noncached_features)
fnames = fnames.intersection(set(self.lr_dh.features))
for fname in fnames:
fidx = self.lr_dh.features.index(fname)
tmp = regridder(self.lr_input_data[..., fidx])
tmp = tmp.reshape(self.lr_required_shape)
self.lr_data[..., fidx] = tmp
fnames = set(self.noncached_features)
fnames = fnames.intersection(set(self.lr_dh.features))
for fname in fnames:
fidx = self.lr_dh.features.index(fname)
tmp = regridder(self.lr_input_data[..., fidx])
tmp = tmp.reshape(self.lr_required_shape)
self.lr_data[..., fidx] = tmp
else:
self.lr_data = self.lr_input_data

if self.load_cached:
fnames = set(self.cached_features)
Expand Down
57 changes: 57 additions & 0 deletions tests/data_handling/test_dual_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,63 @@ def test_normalization(cache,
rtol=rtol, atol=atol)


def test_no_regrid(log=False, full_shape=(20, 20), sample_shape=(10, 10, 4)):
    """Check that disabling the low-res regrid keeps normalization correct.

    Two dual handlers are built from deep copies of the same high-res and
    low-res handlers, one with ``regrid_lr=True`` and one with
    ``regrid_lr=False``. Each is passed to a ``DualBatchHandler`` with
    ``norm=True``. The mean and standard deviation of the hr and lr data
    (reduced over the first three axes) must then agree between the two
    handlers and be close to 0 and 1 respectively.
    """
    if log:
        init_logger('sup3r', log_level='DEBUG')

    s_enhance = 2
    t_enhance = 2

    # Low-res sample shape is the high-res one reduced by the enhancement
    # factors along each axis.
    lr_sample_shape = (sample_shape[0] // s_enhance,
                       sample_shape[1] // s_enhance,
                       sample_shape[2] // t_enhance)
    single_worker = {'max_workers': 1}

    hr_dh = DataHandlerH5(FP_WTK, FEATURES[0],
                          target=TARGET_COORD,
                          shape=full_shape,
                          sample_shape=sample_shape,
                          temporal_slice=slice(None, None, 10),
                          worker_kwargs=single_worker,
                          val_split=0.0)
    lr_handler = DataHandlerH5(FP_WTK, FEATURES[1],
                               target=TARGET_COORD,
                               shape=full_shape,
                               sample_shape=lr_sample_shape,
                               temporal_slice=slice(None, -10,
                                                    t_enhance * 10),
                               hr_spatial_coarsen=2,
                               cache_pattern=None,
                               worker_kwargs={'max_workers': 1},
                               val_split=0.0)

    # Independent copies so that each dual handler owns (and may normalize)
    # its own underlying data.
    hr_dh0, hr_dh1 = (copy.deepcopy(hr_dh) for _ in range(2))
    lr_dh0, lr_dh1 = (copy.deepcopy(lr_handler) for _ in range(2))

    ddh0 = DualDataHandler(hr_dh0, lr_dh0, s_enhance=s_enhance,
                           t_enhance=t_enhance, regrid_lr=True)
    ddh1 = DualDataHandler(hr_dh1, lr_dh1, s_enhance=s_enhance,
                           t_enhance=t_enhance, regrid_lr=False)

    # Building the batch handlers with norm=True is what triggers the
    # normalization checked below; the handlers themselves are not needed.
    for dual_handler in (ddh0, ddh1):
        _ = DualBatchHandler([dual_handler], norm=True)

    reduce_axes = (0, 1, 2)
    for statistic, expected in ((np.mean, 0), (np.std, 1)):
        hr_regrid = statistic(ddh0.hr_data, axis=reduce_axes)
        lr_regrid = statistic(ddh0.lr_data, axis=reduce_axes)
        hr_plain = statistic(ddh1.hr_data, axis=reduce_axes)
        lr_plain = statistic(ddh1.lr_data, axis=reduce_axes)

        assert np.allclose(hr_regrid, hr_plain)
        assert np.allclose(lr_regrid, lr_plain)
        assert np.allclose(hr_regrid, expected, atol=1e-3)
        assert np.allclose(lr_regrid, expected, atol=1e-6)


@pytest.mark.parametrize(['lr_features', 'hr_features', 'hr_exo_features'],
[(['U_100m'], ['U_100m', 'V_100m'], ['V_100m']),
(['U_100m'], ['U_100m', 'V_100m'], ('V_100m',)),
Expand Down

0 comments on commit c754191

Please sign in to comment.