Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BeninSmallHolderCashews: return geospatial metadata in sample #377

Merged
merged 13 commits into from
Feb 24, 2022
45 changes: 30 additions & 15 deletions torchgeo/datasets/benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import rasterio
import rasterio.features
import torch
from rasterio.crs import CRS
from torch import Tensor

from .geo import VisionDataset
Expand Down Expand Up @@ -233,12 +234,12 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
index: index to return

Returns:
image, mask, and metadata at that index
a dict containing image, mask, tile_transform, crs, and metadata at index.
"""
y, x = self.chips_metadata[index]

img = self._load_all_imagery(self.bands)
labels = self._load_mask()
img, tile_transform, crs = self._load_all_imagery(self.bands)
labels = self._load_mask(tile_transform)

img = img[:, :, y : y + self.chip_size, x : x + self.chip_size]
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
Expand All @@ -248,6 +249,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
"mask": labels,
"x": torch.tensor(x), # type: ignore[attr-defined]
"y": torch.tensor(y), # type: ignore[attr-defined]
"tile_transform": tile_transform,
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
"crs": crs,
}

if self.transforms is not None:
Expand Down Expand Up @@ -279,17 +282,21 @@ def _validate_bands(self, bands: Tuple[str, ...]) -> None:
raise ValueError(f"'{band}' is an invalid band name.")

@lru_cache(maxsize=128)
def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor:
def _load_all_imagery(
self, bands: Tuple[str, ...] = ALL_BANDS
) -> Tuple[Tensor, rasterio.Affine, CRS]:
"""Load all the imagery (across time) for the dataset.

Optionally allows for subsetting of the bands that are loaded.

Args:
bands: tuple of bands to load

Returns
imagery of shape (70, number of bands, 1186, 1122) where 70 is the number of
points in time, 1186 is the tile height, and 1122 is the tile width
Returns:
imagery of shape (70, number of bands, 1186, 1122) where 70 is the number
of points in time, 1186 is the tile height, and 1122 is the tile width
rasterio affine transform, mapping pixel coordinates to geo coordinates
coordinate reference system of tile_transform
"""
if self.verbose:
print("Loading all imagery")
Expand All @@ -303,12 +310,17 @@ def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor:
)

for date_index, date in enumerate(self.dates):
img[date_index] = self._load_single_scene(date, self.bands)
single_scene, tile_transform, crs = self._load_single_scene(
date, self.bands
)
img[date_index] = single_scene

return img
return img, tile_transform, crs

@lru_cache(maxsize=128)
def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor:
def _load_single_scene(
self, date: str, bands: Tuple[str, ...]
) -> Tuple[Tensor, rasterio.Affine, CRS]:
"""Load the imagery for a single date.

Optionally allows for subsetting of the bands that are loaded.
Expand All @@ -318,7 +330,9 @@ def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor:
bands: bands to load

Returns:
tensor containing a single image tile
Tensor containing a single image tile, rasterio affine transform,
mapping pixel coordinates to geo coordinates, and coordinate
reference system of tile_transform.

Raises:
AssertionError: if ``date`` is invalid
Expand All @@ -342,14 +356,15 @@ def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor:
f"{band_name}.tif",
)
with rasterio.open(filepath) as src:
self.tile_transform = src.transform
tile_transform = src.transform # same transform for every bands
crs = src.crs
array = src.read().astype(np.float32)
img[band_index] = torch.from_numpy(array) # type: ignore[attr-defined]

return img
return img, tile_transform, crs

@lru_cache()
def _load_mask(self) -> Tensor:
def _load_mask(self, tile_transform: rasterio.Affine) -> Tensor:
"""Rasterizes the dataset's labels (in geojson format)."""
# Create a mask layer out of the geojson
mask_geojson_fn = os.path.join(
Expand All @@ -367,7 +382,7 @@ def _load_mask(self) -> Tensor:
labels,
out_shape=(self.tile_height, self.tile_width),
fill=0, # nodata value
transform=self.tile_transform,
transform=tile_transform,
all_touched=False,
dtype=np.uint8,
)
Expand Down