Skip to content

Commit

Permalink
Fix #451 and improve tests for downsampling (#452)
Browse files Browse the repository at this point in the history
  • Loading branch information
adehecq authored Jan 26, 2024
1 parent 3b5ddc8 commit c9c9a84
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 26 deletions.
10 changes: 8 additions & 2 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,10 @@ def count(self) -> int:
def height(self) -> int:
"""Height of the raster in pixels."""
if not self.is_loaded:
return self._disk_shape[1] # type: ignore
if self._out_shape is not None:
return self._out_shape[0]
else:
return self._disk_shape[1] # type: ignore
else:
# If the raster is single-band
if len(self.data.shape) == 2:
Expand All @@ -562,7 +565,10 @@ def height(self) -> int:
def width(self) -> int:
"""Width of the raster in pixels."""
if not self.is_loaded:
return self._disk_shape[2] # type: ignore
if self._out_shape is not None:
return self._out_shape[1]
else:
return self._disk_shape[2] # type: ignore
else:
# If the raster is single-band
if len(self.data.shape) == 2:
Expand Down
87 changes: 63 additions & 24 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,33 +611,72 @@ def test_get_nanarray(self, example: str) -> None:
mask = ~mask
assert rst.raster_equal(rst_copy)

def test_downsampling(self) -> None:
@pytest.mark.parametrize("example", [aster_dem_path, landsat_b4_path, landsat_rgb_path]) # type: ignore
def test_downsampling(self, example: str) -> None:
"""
Check that self.data is correct when using downsampling
Check that self metadata are correctly updated when using downsampling
"""
# Test single band
r = gu.Raster(self.landsat_b4_path, downsample=4)
assert r.data.shape == (164, 200)
assert r.height == 164
assert r.width == 200

# Test multiple band
r = gu.Raster(self.landsat_rgb_path, downsample=2)
assert r.data.shape == (3, 328, 400)

# Test that xy2ij are consistent with new image
# Upper left
assert r.xy2ij(r.bounds.left, r.bounds.top) == (0, 0)
# Upper right
assert r.xy2ij(r.bounds.right + r.res[0], r.bounds.top) == (0, r.width + 1)
# Bottom right
assert r.xy2ij(r.bounds.right + r.res[0], r.bounds.bottom) == (r.height, r.width + 1)
# One pixel right and down
assert r.xy2ij(r.bounds.left + r.res[0], r.bounds.top - r.res[1]) == (1, 1)

# Check that error is raised when downsampling value is not valid
# Load raster at full resolution
rst_orig = gu.Raster(example)
bounds_orig = rst_orig.bounds

# -- Tries various downsampling factors to ensure rounding is needed in at least one case --
for down_fact in np.arange(2, 8):
rst_down = gu.Raster(example, downsample=int(down_fact))

# - Check that output resolution is as intended -
assert rst_down.res[0] == rst_orig.res[0] * down_fact
assert rst_down.res[1] == rst_orig.res[1] * down_fact

# - Check that downsampled width and height are as intended -
# Due to rounding, width/height can be up to 1 pixel larger than unrounded value
assert abs(rst_down.width - rst_orig.width / down_fact) < 1
assert abs(rst_down.height - rst_orig.height / down_fact) < 1
assert rst_down.shape == (rst_down.height, rst_down.width)

# - Check that bounds are updated accordingly -
# left/top bounds should be the same, right/bottom should be rounded to nearest pixel
bounds_down = rst_down.bounds
assert bounds_down.left == bounds_orig.left
assert bounds_down.top == bounds_orig.top
assert abs(bounds_down.right - bounds_orig.right) < rst_down.res[0]
assert abs(bounds_down.bottom - bounds_orig.bottom) < rst_down.res[1]

# - Check that metadata are consistent, with/out loading data -
assert not rst_down.is_loaded
width_unload = rst_down.width
height_unload = rst_down.height
bounds_unload = rst_down.bounds
rst_down.load()
width_load = rst_down.width
height_load = rst_down.height
bounds_load = rst_down.bounds
assert width_load == width_unload
assert height_load == height_unload
assert bounds_load == bounds_unload

# - Test that xy2ij are consistent with new image -
# Upper left
assert rst_down.xy2ij(rst_down.bounds.left, rst_down.bounds.top) == (0, 0)
# Upper right
assert rst_down.xy2ij(rst_down.bounds.right + rst_down.res[0], rst_down.bounds.top) == (
0,
rst_down.width + 1,
)
# Bottom right
assert rst_down.xy2ij(rst_down.bounds.right + rst_down.res[0], rst_down.bounds.bottom) == (
rst_down.height,
rst_down.width + 1,
)
# One pixel right and down
assert rst_down.xy2ij(rst_down.bounds.left + rst_down.res[0], rst_down.bounds.top - rst_down.res[1]) == (
1,
1,
)

# -- Check that error is raised when downsampling value is not valid --
with pytest.raises(TypeError, match="downsample must be of type int or float."):
gu.Raster(self.landsat_b4_path, downsample=[1, 1]) # type: ignore
gu.Raster(example, downsample=[1, 1]) # type: ignore

def test_add_sub(self) -> None:
"""
Expand Down

0 comments on commit c9c9a84

Please sign in to comment.