Skip to content

Commit

Permalink
Merge pull request #438 from pnuu/bugfix-bilinear-cache-loading
Browse files Browse the repository at this point in the history
Fix using cached LUTs in bilinear resampler
  • Loading branch information
pnuu authored Sep 12, 2022
2 parents 363167f + 1f9f20d commit 1bc2995
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 13 deletions.
6 changes: 3 additions & 3 deletions pyresample/bilinear/xarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _slice_data(self, data, fill_value):
def from_delayed(delayeds, shp):
return [da.from_delayed(d, shp, np.float32) for d in delayeds]

data = _check_data_shape(data, self._valid_input_index)
data = _check_data_shape(data, self._source_geo_def.shape)
if data.ndim == 2:
shp = self.bilinear_s.shape
else:
Expand Down Expand Up @@ -264,10 +264,10 @@ def _get_valid_input_index(source_geo_def,
return valid_input_index, source_lons, source_lats


def _check_data_shape(data, input_idxs):
def _check_data_shape(data, input_xy_shape):
"""Check data shape and adjust if necessary."""
# Handle multiple datasets
if data.ndim > 2 and data.shape[0] * data.shape[1] == input_idxs.shape[0]:
if data.ndim > 2 and data.shape[0] * data.shape[1] == input_xy_shape[0]:
# Move the "channel" dimension first
data = da.moveaxis(data, -1, 0)

Expand Down
4 changes: 3 additions & 1 deletion pyresample/ewa/ewa.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def fornav(cols, rows, area_def, data_in,
"""
if isinstance(data_in, (tuple, list)):
# we can only support one data type per call at this time
assert(in_arr.dtype == data_in[0].dtype for in_arr in data_in[1:])
for in_arr in data_in[1:]:
if in_arr.dtype != data_in[0].dtype:
raise ValueError("All input arrays must be the same dtype")
else:
# assume they gave us a single numpy array-like object
data_in = [data_in]
Expand Down
6 changes: 3 additions & 3 deletions pyresample/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ def intersection(self, other_arc):
ab_ = a__.hdistance(b__)
cd_ = c__.hdistance(d__)

if(((i in (a__, b__)) or
if (((i in (a__, b__)) or
(abs(a__.hdistance(i) + b__.hdistance(i) - ab_) < EPSILON)) and
((i in (c__, d__)) or
(abs(c__.hdistance(i) + d__.hdistance(i) - cd_) < EPSILON))):
((i in (c__, d__)) or
(abs(c__.hdistance(i) + d__.hdistance(i) - cd_) < EPSILON))):
return i
return None

Expand Down
10 changes: 5 additions & 5 deletions pyresample/spherical_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, lon=None, lat=None,
x__=None, y__=None, z__=None, R__=1):
self.R__ = R__
if lat is not None and lon is not None:
if not(-180 <= lon <= 180 and -90 <= lat <= 90):
if not (-180 <= lon <= 180 and -90 <= lat <= 90):
raise ValueError('Illegal (lon, lat) coordinates: (%s, %s)'
% (lon, lat))
self.lat = math.radians(lat)
Expand All @@ -75,8 +75,8 @@ def _update_lonlat(self):

def __ne__(self, other):
"""Check inequality."""
if(abs(self.lat - other.lat) < EPSILON and
abs(self.lon - other.lon) < EPSILON):
if (abs(self.lat - other.lat) < EPSILON and
abs(self.lon - other.lon) < EPSILON):
return 0
else:
return 1
Expand Down Expand Up @@ -286,8 +286,8 @@ def intersection(self, other_arc):
ab_ = a__.distance(b__)
cd_ = c__.distance(d__)

if(abs(a__.distance(i) + b__.distance(i) - ab_) < EPSILON and
abs(c__.distance(i) + d__.distance(i) - cd_) < EPSILON):
if (abs(a__.distance(i) + b__.distance(i) - ab_) < EPSILON and
abs(c__.distance(i) + d__.distance(i) - cd_) < EPSILON):
return i
return None

Expand Down
43 changes: 43 additions & 0 deletions pyresample/test/test_bilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,22 @@ def setUp(self):
[-5326849.0625, -5326849.0625,
5326849.0625, 5326849.0625])

# Area that partially overlaps the source data
self.target_def_partial = geometry.AreaDefinition('area_partial_overlap',
'Europe (3km, HRV, VTC)',
'areaD',
{'a': '6378144.0',
'b': '6356759.0',
'lat_0': '50.00',
'lat_ts': '50.00',
'lon_0': '8.00',
'proj': 'stere'},
4, 4,
[59559.320999999996,
-909968.64000000001,
2920503.401,
1490031.3600000001])

# Input data around the target pixel at 0.63388324, 55.08234642,
in_shape = (100, 100)
self.data1 = DataArray(da.ones((in_shape[0], in_shape[1])), dims=('y', 'x'))
Expand Down Expand Up @@ -1199,6 +1215,33 @@ def test_save_and_load_bil_info(self):
finally:
shutil.rmtree(tempdir, ignore_errors=True)

def test_get_sample_from_cached_bil_info(self):
"""Test getting data using pre-calculated resampling info."""
import os
import shutil
from tempfile import mkdtemp

from pyresample.bilinear import XArrayBilinearResampler

resampler = XArrayBilinearResampler(self.source_def, self.target_def_partial,
self.radius)
resampler.get_bil_info()

try:
tempdir = mkdtemp()
filename = os.path.join(tempdir, "test.zarr")

resampler.save_resampling_info(filename)

assert os.path.exists(filename)

new_resampler = XArrayBilinearResampler(self.source_def, self.target_def_partial,
self.radius)
new_resampler.load_resampling_info(filename)
_ = new_resampler.get_sample_from_bil_info(self.data1)
finally:
shutil.rmtree(tempdir, ignore_errors=True)


def test_check_fill_value():
"""Test that fill_value replacement/adjustment works."""
Expand Down
2 changes: 1 addition & 1 deletion pyresample/test/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,7 +1596,7 @@ def assert_np_dict_allclose(dict1, dict2):
try:
np.testing.assert_allclose(val, dict2[key])
except TypeError:
assert(val == dict2[key])
assert val == dict2[key]


class TestSwathDefinition(unittest.TestCase):
Expand Down

0 comments on commit 1bc2995

Please sign in to comment.