Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Unexpected shape error with errors_as_nodata #256

Merged
merged 1 commit into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions stackstac/nodata_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
State = Tuple[np.dtype, Union[int, float]]


# NOTE: this really should be a `ThreadsafeRioDataset` in `rio_reader.py`,
# not a `Reader` (it's never used as one).
class NodataReader:
"Reader that returns a constant (nodata) value for all reads"

scale_offset = (1.0, 0.0)

def __init__(
Expand Down
24 changes: 14 additions & 10 deletions stackstac/rio_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _curthread():
open=dict(
GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
# ^ stop GDAL from requesting `.aux` and `.msk` files from the bucket (speeds up `open` time a lot)
VSI_CACHE=True
VSI_CACHE=True,
# ^ cache HTTP requests for opening datasets. This is critical for `ThreadLocalRioDataset`,
# which re-opens the same URL many times---having the request cached makes subsequent `open`s
# in different threads snappy.
Expand Down Expand Up @@ -70,11 +70,9 @@ def _curthread():
class ThreadsafeRioDataset(Protocol):
scale_offset: Tuple[Union[int, float], Union[int, float]]

def read(self, window: Window, **kwargs) -> np.ndarray:
...
def read(self, window: Window, **kwargs) -> np.ndarray: ...

def close(self) -> None:
...
def close(self) -> None: ...


class SingleThreadedRioDataset:
Expand Down Expand Up @@ -408,8 +406,14 @@ def read(self, window: Window, **kwargs) -> np.ndarray:
result = np.ma.masked_array(result[0], mask=result[1] == 0)
elif result.shape[0] == 1:
result = result[0]
else:
raise RuntimeError(f"Unexpected shape {result.shape}, expected exactly 1 band.")
elif result.ndim != 2:
# We should only be getting `result.ndim == 2` in the case when `_open` produced a `NodataReader`.
# `Reader`s always return 2D arrays, whereas `rasterio.read` returns 3D. Pedantically, `NodataReader`
# shouldn't be a `Reader`, but a `ThreadsafeRioDataset`, and it should return a 3D array,
# just to be more consistent.
raise RuntimeError(
f"Unexpected shape {result.shape}, expected exactly 1 band."
)

scale, offset = self.scale_offset

Expand All @@ -419,9 +423,9 @@ def read(self, window: Window, **kwargs) -> np.ndarray:
result += offset

result = np.ma.filled(result, fill_value=self.fill_value)
assert np.issubdtype(result.dtype, self.dtype), (
f"Expected result array with dtype {self.dtype!r}, got {result.dtype!r}"
)
assert np.issubdtype(
result.dtype, self.dtype
), f"Expected result array with dtype {self.dtype!r}, got {result.dtype!r}"
return result

def close(self) -> None:
Expand Down