WIP: Retries #232

Draft · wants to merge 6 commits into main
51 changes: 24 additions & 27 deletions pdm.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -65,6 +65,7 @@ style = [
 test = [
     "hypothesis<7.0.0,>=6.35.0",
     "pytest<7.0.0,>=6.2.5",
+    "sat-stac>=0.4.1",
 ]
 util = [
     "py-spy",
5 changes: 3 additions & 2 deletions stackstac/__init__.py
@@ -1,6 +1,6 @@
 from .rio_env import LayeredEnv
 from .rio_reader import DEFAULT_GDAL_ENV, MULTITHREADED_DRIVER_ALLOWLIST
-from .stack import stack
+from .stack import stack, DEFAULT_RETRY_ERRORS, DEFAULT_ERRORS_AS_NODATA
 from .ops import mosaic
 from .geom_utils import reproject_array, array_bounds, array_epsg, xyztile_of_array
 
@@ -13,7 +13,6 @@
     msg = _traceback.format_exc()
 
     def _missing_imports(*args, **kwargs):
-
         raise ImportError(
             "Optional dependencies for map visualization are missing.\n"
             "Please re-install stackstac with the `viz` extra:\n"
@@ -34,6 +33,8 @@ def _missing_imports(*args, **kwargs):
 __all__ = [
     "LayeredEnv",
     "DEFAULT_GDAL_ENV",
+    "DEFAULT_RETRY_ERRORS",
+    "DEFAULT_ERRORS_AS_NODATA",
     "MULTITHREADED_DRIVER_ALLOWLIST",
     "stack",
     "show",
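
The new exports hint at the intended API. A hypothetical usage sketch follows (the `stackstac/stack.py` diff is collapsed on this page, so the parameter names `retry_errors` and `errors_as_nodata` are assumptions inferred from the exported defaults):

```python
# Hypothetical sketch -- the actual stack() signature is not shown in this diff.
import stackstac

arr = stackstac.stack(
    items,  # assumed: a sequence of STAC items defined elsewhere
    retry_errors=stackstac.DEFAULT_RETRY_ERRORS,  # assumed parameter name
    errors_as_nodata=stackstac.DEFAULT_ERRORS_AS_NODATA,  # assumed parameter name
)
```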
65 changes: 0 additions & 65 deletions stackstac/nodata_reader.py

This file was deleted.
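
This deleted module held the error-to-nodata machinery the PR moves out of the readers. Judging from the docstring removed from `reader_protocol.py` and the call sites removed from `rio_reader.py` below, its core helpers behaved roughly like this (a reconstruction sketch, not the verbatim deleted code):

```python
# Reconstruction sketch of the deleted helpers, inferred from their call sites.
# A "pattern" is an exception instance whose str() is a regex to match against
# str() of the raised exception.
import re
from typing import Tuple, Union

import numpy as np
from rasterio.windows import Window


def exception_matches(e: Exception, patterns: Tuple[Exception, ...]) -> bool:
    """Whether ``e`` matches a pattern: same exception type, message matches the regex."""
    for pattern in patterns:
        if isinstance(e, type(pattern)) and re.match(str(pattern), str(e)):
            return True
    return False


def nodata_for_window(
    window: Window, fill_value: Union[int, float], dtype: np.dtype
) -> np.ndarray:
    """An array of ``fill_value`` matching the window's shape."""
    return np.full((int(window.height), int(window.width)), fill_value, dtype)
```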

11 changes: 2 additions & 9 deletions stackstac/reader_protocol.py
@@ -25,6 +25,7 @@ class Reader(Pickleable, Protocol):
     """
     Protocol for a thread-safe, lazily-loaded object for reading data from a single-band STAC asset.
     """
+    url: str
 
     def __init__(
         self,
@@ -36,7 +37,6 @@ def __init__(
         fill_value: Union[int, float],
         scale_offset: Tuple[Union[int, float], Union[int, float]],
         gdal_env: Optional[LayeredEnv],
-        errors_as_nodata: Tuple[Exception, ...] = (),
     ) -> None:
         """
         Construct the Dataset *without* fetching any data.
@@ -58,14 +58,6 @@ def __init__(
         gdal_env:
             A `~.LayeredEnv` of GDAL configuration options to use while opening
             and reading datasets. If None (default), `~.DEFAULT_GDAL_ENV` is used.
-        errors_as_nodata:
-            Exception patterns to ignore when opening datasets or reading data.
-            Exceptions matching the pattern will be logged as warnings, and just
-            produce nodata (``fill_value``).
-
-            The exception patterns should be instances of an Exception type to catch,
-            where ``str(exception_pattern)`` is a regex pattern to match against
-            ``str(raised_exception)``.
         """
         # TODO colormaps?
 
@@ -113,6 +105,7 @@ class FakeReader:
 
     def __init__(self, *, dtype: np.dtype, **kwargs) -> None:
         self.dtype = dtype
+        self.url = "fake"
 
     def read(self, window: Window, **kwargs) -> np.ndarray:
         return np.random.random((window.height, window.width)).astype(self.dtype)
56 changes: 17 additions & 39 deletions stackstac/rio_reader.py
@@ -2,7 +2,6 @@
 
 import logging
 import threading
-import warnings
 from typing import TYPE_CHECKING, Optional, Protocol, Tuple, Type, TypedDict, Union
 
 import numpy as np
@@ -13,7 +12,6 @@
 from .timer import time
 from .reader_protocol import Reader
 from .raster_spec import RasterSpec
-from .nodata_reader import NodataReader, exception_matches, nodata_for_window
 
 if TYPE_CHECKING:
     from rasterio.enums import Resampling
@@ -42,7 +40,7 @@ def _curthread():
     open=dict(
         GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
         # ^ stop GDAL from requesting `.aux` and `.msk` files from the bucket (speeds up `open` time a lot)
-        VSI_CACHE=True
+        VSI_CACHE=True,
         # ^ cache HTTP requests for opening datasets. This is critical for `ThreadLocalRioDataset`,
         # which re-opens the same URL many times---having the request cached makes subsequent `open`s
         # in different threads snappy.
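
For reference, the effect of these two options can be reproduced with a plain `rasterio.Env` (an illustrative sketch; the URL is a placeholder):

```python
# Illustrative only: the same GDAL config options, applied directly.
# VSI_CACHE caches the HTTP requests made while opening, so re-opening the
# same URL in another thread (as ThreadLocalRioDataset does) is fast.
import rasterio

with rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR", VSI_CACHE=True):
    with rasterio.open("https://example.com/cog.tif") as ds:  # placeholder URL
        print(ds.profile)
```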
@@ -283,7 +281,6 @@ class PickleState(TypedDict):
     fill_value: Union[int, float]
     scale_offset: Tuple[Union[int, float], Union[int, float]]
     gdal_env: Optional[LayeredEnv]
-    errors_as_nodata: Tuple[Exception, ...]
 
 
 class AutoParallelRioReader:
@@ -306,7 +303,6 @@ def __init__(
         fill_value: Union[int, float],
         scale_offset: Tuple[Union[int, float], Union[int, float]],
         gdal_env: Optional[LayeredEnv] = None,
-        errors_as_nodata: Tuple[Exception, ...] = (),
     ) -> None:
         self.url = url
         self.spec = spec
@@ -315,25 +311,14 @@ def __init__(
         self.fill_value = fill_value
         self.scale_offset = scale_offset
         self.gdal_env = gdal_env or DEFAULT_GDAL_ENV
-        self.errors_as_nodata = errors_as_nodata
 
         self._dataset: Optional[ThreadsafeRioDataset] = None
         self._dataset_lock = threading.Lock()
 
     def _open(self) -> ThreadsafeRioDataset:
         with self.gdal_env.open:
             with time(f"Initial read for {self.url!r} on {_curthread()}: {{t}}"):
-                try:
-                    ds = SelfCleaningDatasetReader(self.url, sharing=False)
-                except Exception as e:
-                    msg = f"Error opening {self.url!r}: {e!r}"
-                    if exception_matches(e, self.errors_as_nodata):
-                        warnings.warn(msg)
-                        return NodataReader(
-                            dtype=self.dtype, fill_value=self.fill_value
-                        )
-
-                    raise RuntimeError(msg) from e
+                ds = SelfCleaningDatasetReader(self.url, sharing=False)
                 if ds.count != 1:
                     ds.close()
                     raise RuntimeError(
@@ -375,30 +360,22 @@ def _open(self) -> ThreadsafeRioDataset:
         return SingleThreadedRioDataset(self.gdal_env, ds, vrt=vrt)
 
     @property
-    def dataset(self):
+    def dataset(self) -> ThreadsafeRioDataset:
         with self._dataset_lock:
             if self._dataset is None:
                 self._dataset = self._open()
             return self._dataset
 
     def read(self, window: Window, **kwargs) -> np.ndarray:
         reader = self.dataset
-        try:
-            result = reader.read(
-                window=window,
-                out_dtype=self.dtype,
-                masked=True,
-                # ^ NOTE: we always do a masked array, so we can safely apply scales and offsets
-                # without potentially altering pixels that should have been the ``fill_value``
-                **kwargs,
-            )
-        except Exception as e:
-            msg = f"Error reading {window} from {self.url!r}: {e!r}"
-            if exception_matches(e, self.errors_as_nodata):
-                warnings.warn(msg)
-                return nodata_for_window(window, self.fill_value, self.dtype)
-
-            raise RuntimeError(msg) from e
+        result = reader.read(
+            window=window,
+            out_dtype=self.dtype,
+            masked=True,
+            # ^ NOTE: we always do a masked array, so we can safely apply scales and offsets
+            # without potentially altering pixels that should have been the ``fill_value``
+            **kwargs,
+        )
 
         # When the GeoTIFF doesn't have a nodata value, and we're using a VRT, pixels
         # outside the dataset don't get properly masked (they're just 0). Using `add_alpha`
@@ -409,7 +386,9 @@ def read(self, window: Window, **kwargs) -> np.ndarray:
         elif result.shape[0] == 1:
             result = result[0]
         else:
-            raise RuntimeError(f"Unexpected shape {result.shape}, expected exactly 1 band.")
+            raise RuntimeError(
+                f"Unexpected shape {result.shape}, expected exactly 1 band."
+            )
 
         scale, offset = self.scale_offset
 
@@ -419,9 +398,9 @@ def read(self, window: Window, **kwargs) -> np.ndarray:
             result += offset
 
         result = np.ma.filled(result, fill_value=self.fill_value)
-        assert np.issubdtype(result.dtype, self.dtype), (
-            f"Expected result array with dtype {self.dtype!r}, got {result.dtype!r}"
-        )
+        assert np.issubdtype(
+            result.dtype, self.dtype
+        ), f"Expected result array with dtype {self.dtype!r}, got {result.dtype!r}"
         return result
 
     def close(self) -> None:
@@ -451,7 +430,6 @@ def __getstate__(
             "fill_value": self.fill_value,
             "scale_offset": self.scale_offset,
             "gdal_env": self.gdal_env,
-            "errors_as_nodata": self.errors_as_nodata,
         }
 
     def __setstate__(
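
With the try/except blocks removed above, open and read errors now propagate out of `AutoParallelRioReader`, so retrying and error-to-nodata handling must happen in the caller. A minimal caller-side sketch (an assumption: the real logic presumably lives in the collapsed `stackstac/stack.py` changes and may differ):

```python
# Minimal retry sketch (assumed design; not the PR's actual implementation).
# Retries errors with exponential backoff, then falls back to returning
# nodata, mirroring the behavior deleted from the reader.
import time
import warnings

import numpy as np


def read_with_retries(reader, window, retries=3, backoff=1.0):
    for attempt in range(retries):
        try:
            return reader.read(window)
        except Exception as e:
            # A real implementation would match against DEFAULT_RETRY_ERRORS
            # and DEFAULT_ERRORS_AS_NODATA rather than catching everything.
            if attempt == retries - 1:
                warnings.warn(f"Error reading {window} from {reader.url!r}: {e!r}")
                return np.full(
                    (int(window.height), int(window.width)),
                    reader.fill_value,
                    reader.dtype,
                )
            time.sleep(backoff * 2**attempt)
```

Note how the `url: str` attribute added to the `Reader` protocol (and the `"fake"` url on `FakeReader`) gives caller-side code like this something to report in error messages.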