Skip to content

Commit

Permalink
added method to automatically determine exogenous data steps from lr_…
Browse files Browse the repository at this point in the history
…features and hr_exo_features
  • Loading branch information
bnb32 committed Aug 18, 2024
1 parent 0f34e35 commit 3117212
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 87 deletions.
14 changes: 9 additions & 5 deletions sup3r/pipeline/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,17 @@ class ForwardPassStrategy:
exo_handler_kwargs : dict | None
Dictionary of args to pass to
:class:`~sup3r.preprocessing.data_handlers.ExoDataHandler` for
extracting exogenous features for multistep foward pass. This should be
extracting exogenous features for forward passes. This should be
a nested dictionary with keys for each exogenous feature. The
dictionaries corresponding to the feature names should include the path
to exogenous data source, the resolution of the exogenous data, and how
the exogenous data should be used in the model. e.g. ``{'topography':
{'file_paths': 'path to input files', 'source_file': 'path to exo
data', 'steps': [..]}``.
to exogenous data source and the files used for input to the forward
passes, at minimum. Can also provide a dictionary of
``input_handler_kwargs`` used for the handler which opens the
exogenous data. e.g.::
{'topography': {
'source_file': ...,
'file_paths': ...,
'input_handler_kwargs': {'target': ..., 'shape': ...}}}
bias_correct_method : str | None
Optional bias correction function name that can be imported from the
:mod:`sup3r.bias.bias_transforms` module. This will transform the
Expand Down
61 changes: 41 additions & 20 deletions sup3r/preprocessing/data_handlers/exo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import logging
import pathlib
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import Optional, Union

import numpy as np

from sup3r.preprocessing.rasterizers import ExoRasterizer
from sup3r.preprocessing.utilities import log_args
from sup3r.preprocessing.utilities import _lowered, log_args

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -261,11 +261,9 @@ class ExoDataHandler:
Multiple topography arrays at different resolutions for multiple spatial
enhancement steps.
This takes a list of models and information about model steps and uses that
info to compute needed enhancement factors for each step. The requested
feature is then retrieved and rasterized according to the requested target
coordinate and grid shape, for each step. The list of steps are then
updated with the cooresponding exo data.
This takes a list of models and uses the different sets of model features
to retrieve and rasterize exogenous data according to the requested target
coordinate and grid shape, for each model step.
Parameters
----------
Expand All @@ -277,11 +275,13 @@ class ExoDataHandler:
feature : str
Exogenous feature to extract from file_paths
models : list
List of models used with the given steps list. This list of models is
used to determine the input and output resolution and enhancement
factors for each model step which is then used to determine the target
shape for rasterized exo data. If enhancement factors are provided in
the steps list the model list is not needed.
List of models used to get exogenous data. For each model in the list
``lr_features``, ``hr_exo_features``, and ``hr_out_features`` will be
checked and exogenous data will be retrieved based on the resolution
required for that type of feature. e.g. If a model has topography as
a lr and hr_exo feature, and the model performs 5x spatial enhancement
with an input resolution of 30km then topography at 30km and at 6km
will be retrieved. Either this or a list of steps needs to be provided.
steps : list
List of dictionaries containing info on which models to use for a given
step index and what type of exo data the step requires. e.g.::
Expand Down Expand Up @@ -318,8 +318,8 @@ class ExoDataHandler:

file_paths: Union[str, list, pathlib.Path]
feature: str
steps: List[dict]
models: Optional[list] = None
steps: Optional[list] = None
source_file: Optional[str] = None
input_handler_name: Optional[str] = None
input_handler_kwargs: Optional[dict] = None
Expand All @@ -328,9 +328,9 @@ class ExoDataHandler:

@log_args
def __post_init__(self):
"""Initialize `self.data`, perform checks on enhancement factors, and
update `self.data` for each model step with rasterized exo data for the
corresponding enhancement factors."""
"""Get list of steps with types of exogenous data needed for retrieval,
initialize `self.data`, and update `self.data` for each model step with
rasterized exo data."""
self.data = {self.feature: {'steps': []}}
en_check = all('s_enhance' in v for v in self.steps)
en_check = en_check and all('t_enhance' in v for v in self.steps)
Expand All @@ -340,17 +340,38 @@ def __post_init__(self):
'provided in each step in steps list or models'
)
assert en_check, msg
if self.steps is None:
self.steps = self.get_exo_steps(self.models)
self.s_enhancements, self.t_enhancements = self._get_all_enhancement()
msg = (
'Need to provide s_enhance and t_enhance for each model'
'step. If the step is temporal only (spatial only) then '
's_enhance = 1 (t_enhance = 1).'
)
assert not any(s is None for s in self.s_enhancements), msg
assert not any(t is None for t in self.t_enhancements), msg

self.get_all_step_data()

def get_exo_steps(self, models):
    """Get the list of steps describing how exogenous data for
    ``self.feature`` is used by each model in the given list.

    Parameters
    ----------
    models : list
        Model objects exposing ``lr_features``, ``hr_exo_features`` and
        ``hr_out_features``.

    Returns
    -------
    list
        Dictionaries of the form ``{'model': i, 'combine_type': ...}``
        where ``i`` is the index of the model in ``models`` and
        ``combine_type`` is ``'input'``, ``'layer'`` or ``'output'``,
        depending on which feature list the (case-insensitive) feature
        name was found in. Models whose class name is
        ``SurfaceSpatialMetModel`` always get an ``'input'`` and an
        ``'output'`` step.
    """
    # (model attribute to search, resulting combine type, whether a
    # SurfaceSpatialMetModel gets this step regardless of its features)
    checks = (
        ('lr_features', 'input', True),
        ('hr_exo_features', 'layer', False),
        ('hr_out_features', 'output', True),
    )
    steps = []
    for idx, model in enumerate(models):
        sfc_model = model.__class__.__name__ == 'SurfaceSpatialMetModel'
        name = self.feature.lower()
        for attr, combine_type, sfc_always in checks:
            # the attribute is always read, in the same order as before
            found = name in _lowered(getattr(model, attr))
            if found or (sfc_always and sfc_model):
                steps.append({'model': idx, 'combine_type': combine_type})
    return steps

def get_single_step_data(self, s_enhance, t_enhance):
"""Get exo data for a single model step, with specific enhancement
factors."""
Expand Down Expand Up @@ -440,7 +461,7 @@ def _get_single_step_enhance(self, step):
return step

def _get_all_enhancement(self):
"""Compute enhancement factors for all model steps for all features.
"""Compute enhancement factors for all model steps.
Returns
-------
Expand Down
99 changes: 37 additions & 62 deletions tests/forward_pass/test_forward_pass_exo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
SurfaceSpatialMetModel,
)
from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
from sup3r.preprocessing import Dimension
from sup3r.preprocessing import Dimension, ExoDataHandler
from sup3r.utilities.pytest.helpers import make_fake_nc_file
from sup3r.utilities.utilities import RANDOM_GENERATOR, xr_open_mfdataset

Expand Down Expand Up @@ -107,10 +107,6 @@ def test_fwp_multi_step_model_topo_exoskip(input_files):
'target': target,
'shape': shape,
'cache_dir': td,
'steps': [
{'model': 0, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'input'},
],
}
}

Expand Down Expand Up @@ -661,39 +657,6 @@ def test_fwp_multi_step_wind_hi_res_topo(input_files, gen_config_with_topo):
'time_slice': time_slice,
}

with pytest.raises(RuntimeError):
# should raise error since steps doesn't include
# {'model': 2, 'combine_type': 'input'}
steps = [
{'model': 0, 'combine_type': 'input'},
{'model': 0, 'combine_type': 'layer'},
{'model': 1, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'layer'},
]
exo_handler_kwargs['topography']['steps'] = steps
handler = ForwardPassStrategy(
input_files,
model_kwargs=model_kwargs,
model_class='MultiStepGan',
fwp_chunk_shape=(4, 4, 8),
spatial_pad=1,
temporal_pad=1,
input_handler_kwargs=input_handler_kwargs,
out_pattern=out_files,
exo_handler_kwargs=exo_handler_kwargs,
max_nodes=1,
)
forward_pass = ForwardPass(handler)
forward_pass.run(handler, node_index=0)

steps = [
{'model': 0, 'combine_type': 'input'},
{'model': 0, 'combine_type': 'layer'},
{'model': 1, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'layer'},
{'model': 2, 'combine_type': 'input'},
]
exo_handler_kwargs['topography']['steps'] = steps
handler = ForwardPassStrategy(
input_files,
model_kwargs=model_kwargs,
Expand Down Expand Up @@ -756,10 +719,6 @@ def test_fwp_wind_hi_res_topo_plus_linear(input_files, gen_config_with_topo):
'target': target,
'shape': shape,
'cache_dir': td,
'steps': [
{'model': 0, 'combine_type': 'input'},
{'model': 0, 'combine_type': 'layer'},
],
}
}

Expand Down Expand Up @@ -851,17 +810,12 @@ def test_fwp_multi_step_model_multi_exo(input_files):
'target': target,
'shape': shape,
'cache_dir': td,
'steps': [
{'model': 0, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'input'},
],
},
'sza': {
'file_paths': input_files,
'target': target,
'shape': shape,
'cache_dir': td,
'steps': [{'model': 2, 'combine_type': 'input'}],
},
}

Expand Down Expand Up @@ -999,7 +953,8 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(
_ = s1_model.generate(np.ones((4, 10, 10, 4)), exogenous_data=exo_tmp)

s2_model = Sup3rGan(
gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4)
gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4
)
s2_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography', 'sza']
s2_model.meta['hr_out_features'] = ['u_100m', 'v_100m']
s2_model.meta['s_enhance'] = 2
Expand Down Expand Up @@ -1048,26 +1003,12 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(
'target': target,
'shape': shape,
'cache_dir': td,
'steps': [
{'model': 0, 'combine_type': 'input'},
{'model': 0, 'combine_type': 'layer'},
{'model': 1, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'layer'},
],
},
'sza': {
'file_paths': input_files,
'target': target,
'shape': shape,
'cache_dir': td,
'steps': [
{'model': 0, 'combine_type': 'input'},
{'model': 0, 'combine_type': 'layer'},
{'model': 1, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'layer'},
{'model': 2, 'combine_type': 'input'},
{'model': 2, 'combine_type': 'layer'},
],
},
}

Expand All @@ -1092,6 +1033,40 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(
max_nodes=1,
)
forward_pass = ForwardPass(handler)

exo_handler = ExoDataHandler(
**{
'feature': 'topography',
'models': forward_pass.model.models,
'file_paths': input_files,
'source_file': pytest.FP_WTK,
'input_handler_kwargs': {'target': target, 'shape': shape},
'cache_dir': td,
}
)
assert exo_handler.get_exo_steps(forward_pass.model.models) == [
{'model': 0, 'combine_type': 'input'},
{'model': 0, 'combine_type': 'layer'},
{'model': 1, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'layer'},
]

exo_handler = ExoDataHandler(
**{
'feature': 'sza',
'models': forward_pass.model.models,
'file_paths': input_files,
'input_handler_kwargs': {'target': target, 'shape': shape},
'cache_dir': td,
}
)
assert exo_handler.get_exo_steps(forward_pass.model.models) == [
{'model': 0, 'combine_type': 'input'},
{'model': 1, 'combine_type': 'input'},
{'model': 2, 'combine_type': 'input'},
{'model': 2, 'combine_type': 'layer'},
]

forward_pass.run(handler, node_index=0)

for fp in handler.out_files:
Expand Down

0 comments on commit 3117212

Please sign in to comment.