Skip to content

Commit

Permalink
Make raster_equal accept False mask values of masked arrays (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet authored Feb 4, 2024
1 parent cde8ea6 commit 4411f1b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
23 changes: 17 additions & 6 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Numb
self._data[:, ind] = assign # type: ignore
return None

def raster_equal(self, other: object) -> bool:
def raster_equal(self, other: RasterType) -> bool:
"""
Check if two rasters are equal.
Expand All @@ -986,12 +986,23 @@ def raster_equal(self, other: object) -> bool:
- The raster's transform, crs and nodata values.
"""

# If the mask is just "False", it is equivalent to being equal to an array of False
if isinstance(self.data.mask, np.bool_):
self_mask = np.zeros(np.shape(self.data), dtype=bool)
else:
self_mask = self.data.mask

if isinstance(other.data.mask, np.bool_):
other_mask = np.zeros(np.shape(other.data), dtype=bool)
else:
other_mask = other.data.mask

if not isinstance(other, Raster): # TODO: Possibly add equals to SatelliteImage?
raise NotImplementedError("Equality with other object than Raster not supported by raster_equal.")
return all(
[
np.array_equal(self.data.data, other.data.data, equal_nan=True),
np.array_equal(self.data.mask, other.data.mask),
np.array_equal(self_mask, other_mask),
self.data.fill_value == other.data.fill_value,
self.data.dtype == other.data.dtype,
self.transform == other.transform,
Expand Down Expand Up @@ -3585,14 +3596,14 @@ def __init__(
)
self._data = self.data[0, :, :]

# Convert masked array to boolean
self._data = self.data.astype(bool) # type: ignore
# Force dtypes
self._dtypes = (bool,)

# Fix nodata to None
self._nodata = None

# Define in dtypes
self._dtypes = (bool,)
# Convert masked array to boolean
self._data = self.data.astype(bool) # type: ignore

def __repr__(self) -> str:
"""Convert mask to string representation."""
Expand Down
28 changes: 19 additions & 9 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2967,12 +2967,15 @@ def test_reproject(self, mask: gu.Mask) -> None:
# Test 1: with a classic resampling (bilinear)

# Reproject mask - resample to 100 x 100 grid
mask_orig = mask.copy()
mask_reproj = mask.reproject(grid_size=(100, 100), force_source_nodata=2)

# Check instance is respected
assert isinstance(mask_reproj, gu.Mask)
# Check the dtype of the original mask was properly reconverted
assert mask.data.dtype == bool
# Check the original mask was not modified during reprojection
assert mask_orig.raster_equal(mask)

# Check inplace behaviour works
mask_tmp = mask.copy()
Expand All @@ -2998,6 +3001,8 @@ def test_reproject(self, mask: gu.Mask) -> None:
@pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore
def test_crop(self, mask: gu.Mask) -> None:
# Test with same bounds -> should be the same #

mask_orig = mask.copy()
crop_geom = mask.bounds
mask_cropped = mask.crop(crop_geom)
assert mask_cropped.raster_equal(mask)
Expand All @@ -3006,6 +3011,8 @@ def test_crop(self, mask: gu.Mask) -> None:
assert isinstance(mask_cropped, gu.Mask)
# Check the dtype of the original mask was properly reconverted
assert mask.data.dtype == bool
# Check the original mask was not modified during cropping
assert mask_orig.raster_equal(mask)

# Check inplace behaviour works
mask_tmp = mask.copy()
Expand Down Expand Up @@ -3061,10 +3068,14 @@ def test_crop(self, mask: gu.Mask) -> None:

@pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore
def test_polygonize(self, mask: gu.Mask) -> None:

mask_orig = mask.copy()
# Run default
vect = mask.polygonize()
# Check the dtype of the original mask was properly reconverted
assert mask.data.dtype == bool
# Check the original mask was not modified during polygonizing
assert mask_orig.raster_equal(mask)

# Check the output is cast into a vector
assert isinstance(vect, gu.Vector)
Expand All @@ -3079,10 +3090,14 @@ def test_polygonize(self, mask: gu.Mask) -> None:

@pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore
def test_proximity(self, mask: gu.Mask) -> None:

mask_orig = mask.copy()
# Run default
rast = mask.proximity()
# Check the dtype of the original mask was properly reconverted
assert mask.data.dtype == bool
# Check the original mask was not modified during reprojection
assert mask_orig.raster_equal(mask)

# Check that output is cast back into a raster
assert isinstance(rast, gu.Raster)