Skip to content

Commit

Permalink
refactor: interp_method -> interp_kwargs (includes method and run_lev…
Browse files Browse the repository at this point in the history
…el_check keys. The latter defaults to False, since the check is slow: it has to load arrays into memory to compute min / max levels.) Modified the linear interpolation method to use the two closest levels, rather than the two closest levels that also happen to lie above and below the requested level. This speeds up the interpolation by orders of magnitude.
  • Loading branch information
bnb32 committed Nov 2, 2024
1 parent 7977a7b commit d8c98c1
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 112 deletions.
12 changes: 7 additions & 5 deletions sup3r/preprocessing/data_handlers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
nan_method_kwargs: Optional[dict] = None,
BaseLoader: Optional[Callable] = None,
FeatureRegistry: Optional[dict] = None,
interp_method: str = 'linear',
interp_kwargs: Optional[dict] = None,
cache_kwargs: Optional[dict] = None,
**kwargs,
):
Expand Down Expand Up @@ -120,9 +120,11 @@ def __init__(
Dictionary of
:class:`~sup3r.preprocessing.derivers.methods.DerivedFeature`
objects used for derivations
interp_method : str
Interpolation method to use for height interpolation. e.g. Deriving
u_20m from u_10m and u_100m. Options are "linear" and "log". See
interp_kwargs : dict | None
Dictionary of kwargs for level interpolation. Can include "method"
and "run_level_check" keys. Method specifies how to perform height
interpolation. e.g. Deriving u_20m from u_10m and u_100m. Options
are "linear" and "log". See
:py:meth:`sup3r.preprocessing.derivers.Deriver.do_level_interpolation`
cache_kwargs : dict | None
Dictionary with kwargs for caching wrangled data. This should at
Expand Down Expand Up @@ -183,7 +185,7 @@ def __init__(
hr_spatial_coarsen=hr_spatial_coarsen,
nan_method_kwargs=nan_method_kwargs,
FeatureRegistry=FeatureRegistry,
interp_method=interp_method,
interp_kwargs=interp_kwargs,
)

if self.cache is not None:
Expand Down
42 changes: 24 additions & 18 deletions sup3r/preprocessing/derivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
data: Union[Sup3rX, Sup3rDataset],
features,
FeatureRegistry=None,
interp_method='linear',
interp_kwargs=None,
):
"""
Parameters
Expand All @@ -54,15 +54,18 @@ def __init__(
lookups. When the :class:`Deriver` is asked to derive a feature
that is not found in the :class:`Rasterizer` data it will look for
a method to derive the feature in the registry.
interp_method : str
Interpolation method to use for height interpolation. e.g. Deriving
u_20m from u_10m and u_100m. Options are "linear" and "log"
"""
interp_kwargs : dict | None
Dictionary of kwargs for level interpolation. Can include "method"
and "run_level_check" keys. Method specifies how to perform height
interpolation. e.g. Deriving u_20m from u_10m and u_100m. Options
are "linear" and "log". See
:py:meth:`sup3r.preprocessing.derivers.Deriver.do_level_interpolation`
""" # pylint: disable=line-too-long
if FeatureRegistry is not None:
self.FEATURE_REGISTRY = FeatureRegistry

super().__init__(data=data)
self.interp_method = interp_method
self.interp_kwargs = interp_kwargs
features = parse_to_list(data=data, features=features)
new_features = [f for f in features if f not in self.data]
for f in new_features:
Expand Down Expand Up @@ -228,7 +231,7 @@ def derive(self, feature) -> Union[np.ndarray, da.core.Array]:
'Attempting level interpolation for "%s"', feature
)
return self.do_level_interpolation(
feature, interp_method=self.interp_method
feature, interp_kwargs=self.interp_kwargs
)

msg = (
Expand Down Expand Up @@ -278,7 +281,7 @@ def get_multi_level_data(self, feature):
lev_array = None

if fstruct.basename in self.data:
var_array = self.data[fstruct.basename][...]
var_array = self.data[fstruct.basename].data.astype(np.float32)

if fstruct.height is not None and var_array is not None:
msg = (
Expand All @@ -298,7 +301,7 @@ def get_multi_level_data(self, feature):
lev_array = lev_array[..., 0] - lev_array[..., 1]
else:
lev_array = da.broadcast_to(
self.data[Dimension.HEIGHT][...].astype(np.float32),
self.data[Dimension.HEIGHT].astype(np.float32),
var_array.shape,
)
elif var_array is not None:
Expand All @@ -309,13 +312,13 @@ def get_multi_level_data(self, feature):
)
assert Dimension.PRESSURE_LEVEL in self.data, msg
lev_array = da.broadcast_to(
self.data[Dimension.PRESSURE_LEVEL][...].astype(np.float32),
self.data[Dimension.PRESSURE_LEVEL].astype(np.float32),
var_array.shape,
)
return var_array, lev_array

def do_level_interpolation(
self, feature, interp_method='linear'
self, feature, interp_kwargs=None
) -> xr.DataArray:
"""Interpolate over height or pressure to derive the given feature."""
ml_var, ml_levs = self.get_multi_level_data(feature)
Expand Down Expand Up @@ -351,7 +354,7 @@ def do_level_interpolation(
lev_array=lev_array,
var_array=var_array,
level=np.float32(level),
interp_method=interp_method,
interp_kwargs=interp_kwargs,
)
return xr.DataArray(
data=_rechunk_if_dask(out),
Expand All @@ -373,7 +376,7 @@ def __init__(
hr_spatial_coarsen=1,
nan_method_kwargs=None,
FeatureRegistry=None,
interp_method='linear',
interp_kwargs=None,
):
"""
Parameters
Expand All @@ -398,16 +401,19 @@ def __init__(
will be passed to :meth:`Sup3rX.interpolate_na`.
FeatureRegistry : dict
Dictionary of :class:`DerivedFeature` objects used for derivations
interp_method : str
Interpolation method to use for height interpolation. e.g. Deriving
u_20m from u_10m and u_100m. Options are "linear" and "log"
"""
interp_kwargs : dict | None
Dictionary of kwargs for level interpolation. Can include "method"
and "run_level_check" keys. Method specifies how to perform height
interpolation. e.g. Deriving u_20m from u_10m and u_100m. Options
are "linear" and "log". See
:py:meth:`sup3r.preprocessing.derivers.Deriver.do_level_interpolation`
""" # pylint: disable=line-too-long

super().__init__(
data=data,
features=features,
FeatureRegistry=FeatureRegistry,
interp_method=interp_method,
interp_kwargs=interp_kwargs,
)

if time_roll != 0:
Expand Down
1 change: 1 addition & 0 deletions sup3r/preprocessing/loaders/nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_coords(res):
lats = res[Dimension.LATITUDE].data.astype(np.float32)
lons = res[Dimension.LONGITUDE].data.astype(np.float32)

# remove time dimension if there's a single time step
if lats.ndim == 3:
lats = lats.squeeze()
if lons.ndim == 3:
Expand Down
102 changes: 22 additions & 80 deletions sup3r/utilities/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import dask.array as da
import numpy as np

from sup3r.utilities.utilities import RANDOM_GENERATOR

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -47,34 +45,29 @@ def get_level_masks(cls, lev_array, level):
to the one requested.
(lat, lon, time, level)
"""
over_mask = lev_array > level
under_levs = (
da.ma.masked_array(lev_array, over_mask)
if ~over_mask.sum() >= lev_array[..., 0].size
else lev_array
)
argmin1 = da.argmin(da.abs(under_levs - level), axis=-1, keepdims=True)
argmin1 = da.argmin(da.abs(lev_array - level), axis=-1, keepdims=True)
lev_indices = da.broadcast_to(
da.arange(lev_array.shape[-1]), lev_array.shape
)
mask1 = lev_indices == argmin1

over_levs = (
da.ma.masked_array(lev_array, ~over_mask)
if over_mask.sum() >= lev_array[..., 0].size
else da.ma.masked_array(lev_array, mask1)
)
argmin2 = da.argmin(da.abs(over_levs - level), axis=-1, keepdims=True)
other_levs = da.ma.masked_array(lev_array, mask1)
argmin2 = da.argmin(da.abs(other_levs - level), axis=-1, keepdims=True)
mask2 = lev_indices == argmin2
return mask1, mask2

@classmethod
def _lin_interp(cls, lev_samps, var_samps, level):
"""Linearly interpolate between levels."""
diff = lev_samps[1] - lev_samps[0]
alpha = (level - lev_samps[0]) / diff
diff = da.map_blocks(lambda x, y: x - y, lev_samps[1], lev_samps[0])
alpha = da.map_blocks(lambda x, d: (level - x) / d, lev_samps[0], diff)
alpha = da.where(diff == 0, 0, alpha)
return var_samps[0] * (1 - alpha) + var_samps[1] * alpha
return da.map_blocks(
lambda x, y, a: x * (1 - a) + y * a,
var_samps[0],
var_samps[1],
alpha,
)

@classmethod
def _log_interp(cls, lev_samps, var_samps, level):
Expand Down Expand Up @@ -102,7 +95,7 @@ def interp_to_level(
lev_array: Union[np.ndarray, da.core.Array],
var_array: Union[np.ndarray, da.core.Array],
level,
interp_method='linear',
interp_kwargs=None,
):
"""Interpolate var_array to the given level.
Expand All @@ -121,14 +114,21 @@ def interp_to_level(
level : float
level or levels to interpolate to (e.g. final desired hub height
above surface elevation)
interp_kwargs : dict | None
Dictionary of kwargs for level interpolation. Can include "method"
and "run_level_check" keys
Returns
-------
out : Union[np.ndarray, da.core.Array]
Interpolated var_array
(lat, lon, time)
Interpolated var_array (lat, lon, time)
"""
cls._check_lev_array(lev_array, levels=[level])
interp_kwargs = interp_kwargs or {}
interp_method = interp_kwargs.get('method', 'linear')
run_level_check = interp_kwargs.get('run_level_check', False)

if run_level_check:
cls._check_lev_array(lev_array, levels=[level])
levs = da.ma.masked_array(lev_array, da.isnan(lev_array))
mask1, mask2 = cls.get_level_masks(levs, level)
lev1 = da.where(mask1, lev_array, np.nan)
Expand All @@ -148,7 +148,6 @@ def interp_to_level(
out = cls._lin_interp(
lev_samps=[lev1, lev2], var_samps=[var1, var2], level=level
)

return out

@classmethod
Expand Down Expand Up @@ -214,60 +213,3 @@ def _check_lev_array(cls, lev_array, levels):
)
logger.warning(msg)
warn(msg)

@classmethod
def prep_level_interp(cls, var_array, lev_array, levels):
"""Prepare var_array interpolation. Check level ranges and add noise to
mask locations.
Parameters
----------
var_array : Union[np.ndarray, da.core.Array]
Array of variable data, for example u-wind in a 4D array of shape
(time, vertical, lat, lon)
lev_array : Union[np.ndarray, da.core.Array]
Array of height or pressure values corresponding to the wrf source
data in the same shape as var_array. If this is height and the
requested levels are hub heights above surface, lev_array should be
the geopotential height corresponding to every var_array index
relative to the surface elevation (subtract the elevation at the
surface from the geopotential height)
levels : float | list
level or levels to interpolate to (e.g. final desired hub heights
above surface elevation)
Returns
-------
lev_array : Union[np.ndarray, da.core.Array]
Array of levels with noise added to mask locations.
levels : list
List of levels to interpolate to.
"""

msg = (
'Input arrays must be the same shape.'
f'\nvar_array: {var_array.shape}'
f'\nh_array: {lev_array.shape}'
)
assert var_array.shape == lev_array.shape, msg

levels = (
[levels]
if isinstance(levels, (int, float, np.float32))
else levels
)

cls._check_lev_array(lev_array, levels)

# if multiple vertical levels have identical heights at the desired
# interpolation level, interpolation to that value will fail because
# linear slope will be NaN. This is most common if you have multiple
# pressure levels at zero height at the surface in the case that the
# data didn't provide underground data.
for level in levels:
mask = lev_array == level
random = RANDOM_GENERATOR.uniform(-1e-5, 0, size=mask.sum())
lev_array = da.ma.masked_array(lev_array, mask)
lev_array = da.ma.filled(lev_array, random)

return lev_array, levels
Loading

0 comments on commit d8c98c1

Please sign in to comment.