Skip to content

Commit

Permalink
Merge pull request #414 from pnuu/bugfix-gradient-single-channel
Browse files Browse the repository at this point in the history
Fix gradient search for single band data
  • Loading branch information
mraspaud authored Feb 4, 2022
2 parents 1904b05 + 3929c75 commit 1928289
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
5 changes: 4 additions & 1 deletion pyresample/gradient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ def compute(self, data, fill_value=None, **kwargs):

if fill_value is not None:
res = da.where(np.isnan(res), fill_value, res)
if res.ndim > len(data_dims):
res = res.squeeze()

res = xr.DataArray(res, dims=data_dims, coords=coords)

return res
Expand Down Expand Up @@ -400,6 +403,6 @@ def _concatenate_chunks(chunks):
prev_y = y
res.append(da.concatenate(col, axis=1))

res = da.concatenate(res, axis=2).squeeze()
res = da.concatenate(res, axis=2)

return res
23 changes: 16 additions & 7 deletions pyresample/test/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ def test_resample_area_to_area_3d(self):
assert np.allclose(res[1, :, :], 2.0)
assert np.allclose(res[2, :, :], 3.0)

def test_resample_area_to_area_3d_single_channel(self):
"""Resample area to area, 3d with only a single band."""
data = xr.DataArray(da.ones((1, ) + self.src_area.shape,
dtype=np.float64),
dims=['bands', 'y', 'x'])
res = self.resampler.compute(
data, method='bil').compute(scheduler='single-threaded')
assert res.shape == (1, ) + self.dst_area.shape
assert np.allclose(res[0, :, :], 1.0)

def test_resample_swath_to_area_2d(self):
"""Resample swath to area, 2d."""
data = xr.DataArray(da.ones(self.src_swath.shape, dtype=np.float64),
Expand Down Expand Up @@ -462,11 +472,11 @@ def test_concatenate_chunks():
(1, 1): [np.full((1, 3, 2), 0.5)],
(0, 1): [np.full((1, 3, 4), -1)]}
res = _concatenate_chunks(chunks).compute(scheduler='single-threaded')
assert np.all(res[:5, :4] == 1.0)
assert np.all(res[:5, 4:] == 0.0)
assert np.all(res[5:, :4] == -1.0)
assert np.all(res[5:, 4:] == 0.5)
assert res.shape == (8, 6)
assert np.all(res[0, :5, :4] == 1.0)
assert np.all(res[0, :5, 4:] == 0.0)
assert np.all(res[0, 5:, :4] == -1.0)
assert np.all(res[0, 5:, 4:] == 0.5)
assert res.shape == (1, 8, 6)

# 3-band image
chunks = {(0, 0): [np.ones((3, 5, 4)), np.zeros((3, 5, 4))],
Expand All @@ -493,5 +503,4 @@ def test_concatenate_chunks_stack_calls(dask_da):
_ = _concatenate_chunks(chunks)
dask_da.stack.assert_called_once_with(chunks[(0, 0)], axis=-1)
dask_da.nanmax.assert_called_once()
assert 'axis=2' in str(dask_da.concatenate.mock_calls[-2])
assert 'squeeze' in str(dask_da.concatenate.mock_calls[-1])
assert 'axis=2' in str(dask_da.concatenate.mock_calls[-1])

0 comments on commit 1928289

Please sign in to comment.