Commit
added collection test with more chunks
bnb32 committed Aug 31, 2024
1 parent af2336b commit 7e90cb0
Showing 4 changed files with 132 additions and 8 deletions.
4 changes: 2 additions & 2 deletions sup3r/postprocessing/writers/base.py
@@ -28,15 +28,15 @@
     'u': {
         'scale_factor': 100.0,
         'units': 'm s-1',
-        'dtype': 'uint16',
+        'dtype': 'int16',
         'chunks': (2000, 500),
         'min': -120,
         'max': 120,
     },
     'v': {
         'scale_factor': 100.0,
         'units': 'm s-1',
-        'dtype': 'uint16',
+        'dtype': 'int16',
         'chunks': (2000, 500),
         'min': -120,
         'max': 120,
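The dtype fix matters because u/v wind components are signed (the attrs above allow values from -120 to 120 m s-1): with a scale factor of 100 the stored integers span roughly -12000 to 12000, which fits in int16 but not in unsigned uint16. A minimal sketch of the round trip, assuming the usual convention of multiplying by the scale factor before casting and dividing it back out on read:

import numpy as np

scale_factor = 100.0
u = np.array([-120.0, -0.5, 0.0, 119.99])  # m s-1, signed wind component

# Encode: scale and round to the on-disk integer dtype.
stored = np.round(u * scale_factor).astype(np.int16)  # [-12000, -50, 0, 11999]

# uint16 cannot represent the negative entries (its valid range is 0..65535),
# so the old attrs would have corrupted any negative u/v values.
assert stored.min() < np.iinfo(np.uint16).min

# Decode: divide by the scale factor to recover physical units.
recovered = stored.astype(np.float32) / scale_factor  # [-120.0, -0.5, 0.0, 119.99]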
13 changes: 9 additions & 4 deletions sup3r/postprocessing/writers/h5.py
@@ -80,6 +80,7 @@ def invert_uv_features(cls, data, features, lat_lon, max_workers=None):
             for f in features
             if re.match('u_(.*?)m'.lower(), f.lower())
         ]
+
         if heights:
             logger.info(
                 'Converting u/v to ws/wd for H5 output with max_workers=%s',
@@ -144,10 +145,14 @@ def _transform_output(cls, data, features, lat_lon, max_workers=None):
             Max workers to use for inverse transform. If None the max_workers
             will be estimated based on memory limits.
         """
-
-        cls.invert_uv_features(
-            data, features, lat_lon, max_workers=max_workers
-        )
+        if any(
+            re.match('u_(.*?)m'.lower(), f.lower())
+            or re.match('v_(.*?)m'.lower(), f.lower())
+            for f in features
+        ):
+            cls.invert_uv_features(
+                data, features, lat_lon, max_workers=max_workers
+            )
         features = cls.get_renamed_features(features)
         data = cls.enforce_limits(features=features, data=data)
         return data, features
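The new guard simply skips the inverse u/v transform when no u_*m / v_*m features are present. The same check in isolation (needs_uv_inversion is a hypothetical name, not part of the sup3r API):

import re


def needs_uv_inversion(features):
    """Return True if any u_*m / v_*m feature names are present, mirroring
    the guard added to _transform_output above."""
    return any(
        re.match('u_(.*?)m', f.lower()) or re.match('v_(.*?)m', f.lower())
        for f in features
    )


assert needs_uv_inversion(['u_100m', 'v_100m', 'temperature_2m'])
assert not needs_uv_inversion(['windspeed_100m', 'winddirection_100m'])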
86 changes: 85 additions & 1 deletion sup3r/utilities/pytest/helpers.py
@@ -1,6 +1,7 @@
"""Testing helpers."""

import os
from itertools import product

import dask.array as da
import numpy as np
@@ -257,6 +258,89 @@ def sample_batch(self):
    return BatchHandlerTester


def make_collect_chunks(td):
    """Make fake h5 chunked output files for collection tests.

    Parameters
    ----------
    td : tempfile.TemporaryDirectory
        Test TemporaryDirectory

    Returns
    -------
    out_files : list
        List of filepaths to chunked files.
    data : ndarray
        (spatial_1, spatial_2, temporal, features)
        High resolution forward pass output
    ws_true : ndarray
        Windspeed between 0 and 20 in shape (spatial_1, spatial_2, temporal, 1)
    wd_true : ndarray
        Wind direction between 0 and 360 in shape
        (spatial_1, spatial_2, temporal, 1)
    features : list
        List of feature names corresponding to the last dimension of data
        ['windspeed_100m', 'winddirection_100m']
    hr_lat_lon : ndarray
        Array of lat/lon for hr data. (spatial_1, spatial_2, 2)
        Last dimension has ordering (lat, lon)
    hr_times : list
        List of np.datetime64 objects for hr data.
    """

    features = ['windspeed_100m', 'winddirection_100m']
    model_meta_data = {'foo': 'bar'}
    shape = (50, 50, 96, 1)
    ws_true = RANDOM_GENERATOR.uniform(0, 20, shape)
    wd_true = RANDOM_GENERATOR.uniform(0, 360, shape)
    data = np.concatenate((ws_true, wd_true), axis=3)
    lat = np.linspace(90, 0, 50)
    lon = np.linspace(-180, 0, 50)
    lon, lat = np.meshgrid(lon, lat)
    hr_lat_lon = np.dstack((lat, lon))

    gids = np.arange(np.prod(shape[:2]))
    gids = gids.reshape(shape[:2])

    hr_times = pd_date_range(
        '20220101', '20220103', freq='1800s', inclusive='left'
    )

    t_slices_hr = np.array_split(np.arange(len(hr_times)), 4)
    t_slices_hr = [slice(s[0], s[-1] + 1) for s in t_slices_hr]
    s_slices_hr = np.array_split(np.arange(shape[0]), 4)
    s_slices_hr = [slice(s[0], s[-1] + 1) for s in s_slices_hr]

    out_pattern = os.path.join(td, 'fp_out_{t}_{s}.h5')
    out_files = []
    for t, slice_hr in enumerate(t_slices_hr):
        for s, (s1_hr, s2_hr) in enumerate(product(s_slices_hr, s_slices_hr)):
            out_file = out_pattern.format(
                t=str(t).zfill(6),
                s=str(s).zfill(6)
            )
            out_files.append(out_file)
            OutputHandlerH5._write_output(
                data[s1_hr, s2_hr, slice_hr, :],
                features,
                hr_lat_lon[s1_hr, s2_hr],
                hr_times[slice_hr],
                out_file,
                meta_data=model_meta_data,
                max_workers=1,
                gids=gids[s1_hr, s2_hr],
            )

    return (
        out_files,
        data,
        ws_true,
        wd_true,
        features,
        hr_lat_lon,
        hr_times
    )


def make_fake_h5_chunks(td):
    """Make fake h5 chunked output files for a 5x spatial 2x temporal
    multi-node forward pass output.
@@ -352,7 +436,7 @@ def make_fake_h5_chunks(td):
         s_slices_lr,
         s_slices_hr,
         low_res_lat_lon,
-        low_res_times,
+        low_res_times
     )


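The helper above splits the 96 time steps into 4 temporal chunks and the 50x50 grid into 4x4 spatial chunks, so collection has 64 chunked files to stitch back together. A standalone sketch of just that enumeration, using the same np.array_split / itertools.product pattern (the file-name pattern is copied from the helper):

from itertools import product

import numpy as np

n_times, n_space = 96, 50
t_slices = [slice(s[0], s[-1] + 1) for s in np.array_split(np.arange(n_times), 4)]
s_slices = [slice(s[0], s[-1] + 1) for s in np.array_split(np.arange(n_space), 4)]

# One file per (temporal chunk, spatial chunk) pair, zero-padded like the helper.
names = [
    'fp_out_{t}_{s}.h5'.format(t=str(t).zfill(6), s=str(s).zfill(6))
    for t, _ in enumerate(t_slices)
    for s, _ in enumerate(product(s_slices, s_slices))
]
assert len(names) == 4 * 16  # 4 temporal chunks x (4 x 4) spatial chunks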
37 changes: 36 additions & 1 deletion tests/output/test_output_handling.py
@@ -17,7 +17,10 @@
     invert_uv,
     transform_rotate_wind,
 )
-from sup3r.utilities.pytest.helpers import make_fake_h5_chunks
+from sup3r.utilities.pytest.helpers import (
+    make_collect_chunks,
+    make_fake_h5_chunks,
+)
 from sup3r.utilities.utilities import RANDOM_GENERATOR


@@ -125,6 +128,38 @@ def test_invert_uv_inplace():
    assert np.allclose(data[..., 1], wd)


def test_general_collect():
    """Make sure general file collection gives complete meta, time_index, and
    data array."""

    with tempfile.TemporaryDirectory() as td:
        fp_out = os.path.join(td, 'out_combined.h5')

        out = make_collect_chunks(td)
        out_files, data, features, hr_lat_lon, hr_times = (
            out[0],
            out[1],
            out[-3],
            out[-2],
            out[-1],
        )

        CollectorH5.collect(out_files, fp_out, features=features)

        with ResourceX(fp_out) as res:
            lat_lon = res['meta'][['latitude', 'longitude']].values
            time_index = res['time_index'].values
            collect_data = np.dstack([res[f, :, :] for f in features])
            base_data = data.transpose(2, 0, 1, 3).reshape(
                (len(hr_times), -1, len(features))
            )
            base_data = np.around(base_data.astype(np.float32), 2)
            hr_lat_lon = hr_lat_lon.astype(np.float32)
            assert np.array_equal(hr_times, time_index)
            assert np.array_equal(hr_lat_lon.reshape((-1, 2)), lat_lon)
            assert np.array_equal(base_data, collect_data)


def test_h5_out_and_collect(collect_check):
    """Test h5 file output writing and collection with dummy data"""

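For context on the assertions in test_general_collect: the transpose(2, 0, 1, 3) plus reshape converts the (spatial_1, spatial_2, temporal, features) forward-pass array into the (time, sites, features) layout the collected h5 exposes, with sites flattened row-major over the grid. A tiny self-contained check of that equivalence on synthetic shapes:

import numpy as np

rng = np.random.default_rng(0)
data = rng.random((5, 4, 3, 2))  # (spatial_1, spatial_2, temporal, features)

flat = data.transpose(2, 0, 1, 3).reshape((3, -1, 2))  # (time, sites, features)

# Site index s1 * n_cols + s2 matches row-major flattening of the grid.
assert np.allclose(flat[:, 2 * 4 + 3, :], data[2, 3, :, :])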
