Skip to content

Commit

Permalink
test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Aug 18, 2024
1 parent 3117212 commit 78e69ce
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 52 deletions.
30 changes: 16 additions & 14 deletions sup3r/preprocessing/data_handlers/exo.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,16 +332,17 @@ def __post_init__(self):
initialize `self.data`, and update `self.data` for each model step with
rasterized exo data."""
self.data = {self.feature: {'steps': []}}
en_check = all('s_enhance' in v for v in self.steps)
en_check = en_check and all('t_enhance' in v for v in self.steps)
en_check = en_check or self.models is not None
msg = (
f'{self.__class__.__name__} needs s_enhance and t_enhance '
'provided in each step in steps list or models'
)
assert en_check, msg
if self.steps is None:
self.steps = self.get_exo_steps(self.models)
self.steps = self.get_exo_steps(self.feature, self.models)
else:
en_check = all('s_enhance' in v for v in self.steps)
en_check = en_check and all('t_enhance' in v for v in self.steps)
en_check = en_check or self.models is not None
msg = (
f'{self.__class__.__name__} needs s_enhance and t_enhance '
'provided in each step in steps list or models'
)
assert en_check, msg
self.s_enhancements, self.t_enhancements = self._get_all_enhancement()
msg = (
'Need to provide s_enhance and t_enhance for each model'
Expand All @@ -350,23 +351,24 @@ def __post_init__(self):
)
self.get_all_step_data()

def get_exo_steps(self, models):
"""Get list of steps describing how to exogenous data for the given
@classmethod
def get_exo_steps(cls, feature, models):
"""Get list of steps describing how to use exogenous data for the given
feature in the list of given models. This checks the input and
exo feature lists for each model step and adds that step if the
given feature is found in the list."""
steps = []
for i, model in enumerate(models):
is_sfc_model = model.__class__.__name__ == 'SurfaceSpatialMetModel'
if (
self.feature.lower() in _lowered(model.lr_features)
feature.lower() in _lowered(model.lr_features)
or is_sfc_model
):
steps.append({'model': i, 'combine_type': 'input'})
if self.feature.lower() in _lowered(model.hr_exo_features):
if feature.lower() in _lowered(model.hr_exo_features):
steps.append({'model': i, 'combine_type': 'layer'})
if (
self.feature.lower() in _lowered(model.hr_out_features)
feature.lower() in _lowered(model.hr_out_features)
or is_sfc_model
):
steps.append({'model': i, 'combine_type': 'output'})
Expand Down
27 changes: 6 additions & 21 deletions tests/forward_pass/test_forward_pass_exo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,33 +1034,18 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(
)
forward_pass = ForwardPass(handler)

exo_handler = ExoDataHandler(
**{
'feature': 'topography',
'models': forward_pass.model.models,
'file_paths': input_files,
'source_file': pytest.FP_WTK,
'input_handler_kwargs': {'target': target, 'shape': shape},
'cache_dir': td,
}
)
assert exo_handler.get_exo_steps(forward_pass.model.models) == [
assert ExoDataHandler.get_exo_steps(
'topography', forward_pass.model.models
) == [
{'model': 0, 'combine_type': 'input'},
{'model': 0, 'combine_type': 'layer'},
{'model': 1, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'layer'},
]

exo_handler = ExoDataHandler(
**{
'feature': 'sza',
'models': forward_pass.model.models,
'file_paths': input_files,
'input_handler_kwargs': {'target': target, 'shape': shape},
'cache_dir': td,
}
)
assert exo_handler.get_exo_steps(forward_pass.model.models) == [
assert ExoDataHandler.get_exo_steps(
'sza', forward_pass.model.models
) == [
{'model': 0, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'input'},
{'model': 2, 'combine_type': 'input'},
Expand Down
27 changes: 10 additions & 17 deletions tests/utilities/test_era_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ class EraDownloaderTester(EraDownloader):
# pylint: disable=unused-argument
@classmethod
def download_file(
cls,
variables,
out_file,
level_type,
levels=None,
**kwargs
cls, variables, out_file, level_type, levels=None, **kwargs
):
"""Download either single-level or pressure-level file"""
shape = (10, 10, 100)
Expand All @@ -37,16 +32,14 @@ def download_file(
'100m_u_component_of_wind': 'u100',
'100m_v_component_of_wind': 'v100',
'u_component_of_wind': 'u',
'v_component_of_wind': 'v'}
'v_component_of_wind': 'v',
}

if 'geopotential' in variables:
features.append('z')
features.extend([v for f, v in name_map.items() if f in variables])

nc = make_fake_dset(
shape=shape,
features=features
)
nc = make_fake_dset(shape=shape, features=features)
if 'z' in nc:
if level_type == 'single':
nc['z'] = (nc['z'].dims, np.zeros(nc['z'].shape))
Expand All @@ -62,7 +55,7 @@ def test_era_dl(tmpdir_factory):
"""Test basic post proc for era downloader."""

variables = ['zg', 'orog', 'u', 'v', 'pressure']
combined_out_pattern = os.path.join(
file_pattern = os.path.join(
tmpdir_factory.mktemp('tmp'), 'era5_{year}_{month}_{var}.nc'
)
year = 2000
Expand All @@ -74,13 +67,13 @@ def test_era_dl(tmpdir_factory):
month=month,
area=area,
levels=levels,
combined_out_pattern=combined_out_pattern,
monthly_file_pattern=file_pattern,
variables=variables,
)
for v in variables:
standard_name = FEATURE_NAMES.get(v, v)
tmp = xr_open_mfdataset(
combined_out_pattern.format(year=2000, month='01', var=v)
file_pattern.format(year=2000, month='01', var=v)
)
assert standard_name in tmp

Expand All @@ -90,7 +83,7 @@ def test_era_dl_year(tmpdir_factory):
year."""

variables = ['zg', 'orog', 'u', 'v', 'pressure']
combined_out_pattern = os.path.join(
file_pattern = os.path.join(
tmpdir_factory.mktemp('tmp'), 'era5_{year}_{month}_{var}.nc'
)
yearly_file = os.path.join(tmpdir_factory.mktemp('tmp'), 'era5_final.nc')
Expand All @@ -99,8 +92,8 @@ def test_era_dl_year(tmpdir_factory):
area=[50, -130, 23, -65],
levels=[1000, 900, 800],
variables=variables,
combined_out_pattern=combined_out_pattern,
combined_yearly_file=yearly_file,
monthly_file_pattern=file_pattern,
yearly_file=yearly_file,
max_workers=1,
)

Expand Down

0 comments on commit 78e69ce

Please sign in to comment.