diff --git a/pyresample/gradient/__init__.py b/pyresample/gradient/__init__.py index 5b72c768b..2e3b005ba 100644 --- a/pyresample/gradient/__init__.py +++ b/pyresample/gradient/__init__.py @@ -271,6 +271,9 @@ def compute(self, data, fill_value=None, **kwargs): if fill_value is not None: res = da.where(np.isnan(res), fill_value, res) + if res.ndim > len(data_dims): + res = res.squeeze() + res = xr.DataArray(res, dims=data_dims, coords=coords) return res @@ -400,6 +403,6 @@ def _concatenate_chunks(chunks): prev_y = y res.append(da.concatenate(col, axis=1)) - res = da.concatenate(res, axis=2).squeeze() + res = da.concatenate(res, axis=2) return res diff --git a/pyresample/test/test_gradient.py b/pyresample/test/test_gradient.py index cbe788965..27b866eb9 100644 --- a/pyresample/test/test_gradient.py +++ b/pyresample/test/test_gradient.py @@ -228,6 +228,16 @@ def test_resample_area_to_area_3d(self): assert np.allclose(res[1, :, :], 2.0) assert np.allclose(res[2, :, :], 3.0) + def test_resample_area_to_area_3d_single_channel(self): + """Resample area to area, 3d with only a single band.""" + data = xr.DataArray(da.ones((1, ) + self.src_area.shape, + dtype=np.float64), + dims=['bands', 'y', 'x']) + res = self.resampler.compute( + data, method='bil').compute(scheduler='single-threaded') + assert res.shape == (1, ) + self.dst_area.shape + assert np.allclose(res[0, :, :], 1.0) + def test_resample_swath_to_area_2d(self): """Resample swath to area, 2d.""" data = xr.DataArray(da.ones(self.src_swath.shape, dtype=np.float64), @@ -462,11 +472,11 @@ def test_concatenate_chunks(): (1, 1): [np.full((1, 3, 2), 0.5)], (0, 1): [np.full((1, 3, 4), -1)]} res = _concatenate_chunks(chunks).compute(scheduler='single-threaded') - assert np.all(res[:5, :4] == 1.0) - assert np.all(res[:5, 4:] == 0.0) - assert np.all(res[5:, :4] == -1.0) - assert np.all(res[5:, 4:] == 0.5) - assert res.shape == (8, 6) + assert np.all(res[0, :5, :4] == 1.0) + assert np.all(res[0, :5, 4:] == 0.0) + assert np.all(res[0, 5:, :4] == -1.0) + assert np.all(res[0, 5:, 4:] == 0.5) + assert res.shape == (1, 8, 6) # 3-band image chunks = {(0, 0): [np.ones((3, 5, 4)), np.zeros((3, 5, 4))], @@ -493,5 +503,4 @@ def test_concatenate_chunks_stack_calls(dask_da): _ = _concatenate_chunks(chunks) dask_da.stack.assert_called_once_with(chunks[(0, 0)], axis=-1) dask_da.nanmax.assert_called_once() - assert 'axis=2' in str(dask_da.concatenate.mock_calls[-2]) - assert 'squeeze' in str(dask_da.concatenate.mock_calls[-1]) + assert 'axis=2' in str(dask_da.concatenate.mock_calls[-1])