diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 3d968de12db..abd8224a8e8 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -13,6 +13,7 @@ import rasterio import rasterio.features import torch +from rasterio.crs import CRS from torch import Tensor from .geo import VisionDataset @@ -233,12 +234,12 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: index: index to return Returns: - image, mask, and metadata at that index + a dict containing image, mask, transform, crs, and metadata at index. """ y, x = self.chips_metadata[index] - img = self._load_all_imagery(self.bands) - labels = self._load_mask() + img, transform, crs = self._load_all_imagery(self.bands) + labels = self._load_mask(transform) img = img[:, :, y : y + self.chip_size, x : x + self.chip_size] labels = labels[y : y + self.chip_size, x : x + self.chip_size] @@ -248,6 +249,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: "mask": labels, "x": torch.tensor(x), # type: ignore[attr-defined] "y": torch.tensor(y), # type: ignore[attr-defined] + "transform": transform, + "crs": crs, } if self.transforms is not None: @@ -279,7 +282,9 @@ def _validate_bands(self, bands: Tuple[str, ...]) -> None: raise ValueError(f"'{band}' is an invalid band name.") @lru_cache(maxsize=128) - def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor: + def _load_all_imagery( + self, bands: Tuple[str, ...] = ALL_BANDS + ) -> Tuple[Tensor, rasterio.Affine, CRS]: """Load all the imagery (across time) for the dataset. Optionally allows for subsetting of the bands that are loaded. @@ -287,9 +292,11 @@ def _load_all_imagery(self, bands: Tuple[str, ...] 
= ALL_BANDS) -> Tensor: Args: bands: tuple of bands to load - Returns - imagery of shape (70, number of bands, 1186, 1122) where 70 is the number of - points in time, 1186 is the tile height, and 1122 is the tile width + Returns: + imagery of shape (70, number of bands, 1186, 1122) where 70 is the number + of points in time, 1186 is the tile height, and 1122 is the tile width + rasterio affine transform, mapping pixel coordinates to geo coordinates + coordinate reference system of transform """ if self.verbose: print("Loading all imagery") @@ -303,12 +310,15 @@ def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor: ) for date_index, date in enumerate(self.dates): - img[date_index] = self._load_single_scene(date, self.bands) + single_scene, transform, crs = self._load_single_scene(date, self.bands) + img[date_index] = single_scene - return img + return img, transform, crs @lru_cache(maxsize=128) - def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor: + def _load_single_scene( + self, date: str, bands: Tuple[str, ...] + ) -> Tuple[Tensor, rasterio.Affine, CRS]: """Load the imagery for a single date. Optionally allows for subsetting of the bands that are loaded. @@ -318,7 +328,9 @@ def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor: bands: bands to load Returns: - tensor containing a single image tile + Tensor containing a single image tile, rasterio affine transform, + mapping pixel coordinates to geo coordinates, and coordinate + reference system of transform. 
Raises: AssertionError: if ``date`` is invalid @@ -342,14 +354,15 @@ def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor: f"{band_name}.tif", ) with rasterio.open(filepath) as src: - self.tile_transform = src.transform + transform = src.transform # same transform for every band + crs = src.crs array = src.read().astype(np.float32) img[band_index] = torch.from_numpy(array) # type: ignore[attr-defined] - return img + return img, transform, crs @lru_cache() - def _load_mask(self) -> Tensor: + def _load_mask(self, transform: rasterio.Affine) -> Tensor: """Rasterizes the dataset's labels (in geojson format).""" # Create a mask layer out of the geojson mask_geojson_fn = os.path.join( @@ -367,7 +380,7 @@ def _load_mask(self) -> Tensor: labels, out_shape=(self.tile_height, self.tile_width), fill=0, # nodata value - transform=self.tile_transform, + transform=transform, all_touched=False, dtype=np.uint8, )