diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index d93445bb..e5a40732 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -977,7 +977,7 @@ def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Numb self._data[:, ind] = assign # type: ignore return None - def raster_equal(self, other: object) -> bool: + def raster_equal(self, other: RasterType) -> bool: """ Check if two rasters are equal. @@ -986,12 +986,23 @@ def raster_equal(self, other: object) -> bool: - The raster's transform, crs and nodata values. """ + # If the mask is just "False", it is equivalent to being equal to an array of False + if isinstance(self.data.mask, np.bool_): + self_mask = np.zeros(np.shape(self.data), dtype=bool) + else: + self_mask = self.data.mask + + if isinstance(other.data.mask, np.bool_): + other_mask = np.zeros(np.shape(other.data), dtype=bool) + else: + other_mask = other.data.mask + if not isinstance(other, Raster): # TODO: Possibly add equals to SatelliteImage? raise NotImplementedError("Equality with other object than Raster not supported by raster_equal.") return all( [ np.array_equal(self.data.data, other.data.data, equal_nan=True), - np.array_equal(self.data.mask, other.data.mask), + np.array_equal(self_mask, other_mask), self.data.fill_value == other.data.fill_value, self.data.dtype == other.data.dtype, self.transform == other.transform, @@ -3585,14 +3596,14 @@ def __init__( ) self._data = self.data[0, :, :] - # Convert masked array to boolean - self._data = self.data.astype(bool) # type: ignore + # Force dtypes + self._dtypes = (bool,) # Fix nodata to None self._nodata = None - # Define in dtypes - self._dtypes = (bool,) + # Convert masked array to boolean + self._data = self.data.astype(bool) # type: ignore def __repr__(self) -> str: """Convert mask to string representation.""" diff --git a/tests/test_raster.py b/tests/test_raster.py index e92b76bd..dbb43f5c 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -2967,12 +2967,15 @@ def test_reproject(self, mask: gu.Mask) -> None: # Test 1: with a classic resampling (bilinear) # Reproject mask - resample to 100 x 100 grid + mask_orig = mask.copy() mask_reproj = mask.reproject(grid_size=(100, 100), force_source_nodata=2) # Check instance is respected assert isinstance(mask_reproj, gu.Mask) # Check the dtype of the original mask was properly reconverted assert mask.data.dtype == bool + # Check the original mask was not modified during reprojection + assert mask_orig.raster_equal(mask) # Check inplace behaviour works mask_tmp = mask.copy() @@ -2998,6 +3001,8 @@ def test_reproject(self, mask: gu.Mask) -> None: @pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore def test_crop(self, mask: gu.Mask) -> None: # Test with same bounds -> should be the same # + + mask_orig = mask.copy() crop_geom = mask.bounds mask_cropped = mask.crop(crop_geom) assert mask_cropped.raster_equal(mask) @@ -3006,6 +3011,8 @@ def test_crop(self, mask: gu.Mask) -> None: assert isinstance(mask_cropped, gu.Mask) # Check the dtype of the original mask was properly reconverted assert mask.data.dtype == bool + # Check the original mask was not modified during cropping + assert mask_orig.raster_equal(mask) # Check inplace behaviour works mask_tmp = mask.copy() @@ -3061,10 +3068,14 @@ def test_crop(self, mask: gu.Mask) -> None: @pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore def test_polygonize(self, mask: gu.Mask) -> None: + + mask_orig = mask.copy() # Run default vect = mask.polygonize() # Check the dtype of the original mask was properly reconverted assert mask.data.dtype == bool + # Check the original mask was not modified during polygonizing + assert mask_orig.raster_equal(mask) # Check the output is cast into a vector assert isinstance(vect, gu.Vector) @@ -3079,10 +3090,14 @@ def test_polygonize(self, mask: gu.Mask) -> None: @pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore def test_proximity(self, mask: gu.Mask) -> None: + + mask_orig = mask.copy() # Run default rast = mask.proximity() # Check the dtype of the original mask was properly reconverted assert mask.data.dtype == bool + # Check the original mask was not modified during reprojection + assert mask_orig.raster_equal(mask) # Check that output is cast back into a raster assert isinstance(rast, gu.Raster) @@ -3101,17 +3116,12 @@ def test_save(self, mask: gu.Mask) -> None: mask.save(temp_file) saved = gu.Mask(temp_file) - # TODO: Generalize raster_equal for masks? + # A raster (or mask) in-memory has more information than on disk, we need to update it before checking equality + # The values in its .data.data that are masked in .data.mask are not necessarily equal to the nodata value + mask.data.data[mask.data.mask] = True # The default nodata 255 is converted to boolean True on masked values # Check all attributes are equal - assert all( - [ - np.ma.allequal(saved.data, mask.data), - saved.transform == mask.transform, - saved.crs == mask.crs, - saved.nodata == mask.nodata, - ] - ) + assert mask.raster_equal(saved) # Clean up temporary folder - fails on Windows try: