Skip to content

Commit

Permalink
BeninSmallHolderCashews: return geospatial metadata in sample (#377)
Browse files Browse the repository at this point in the history
* exposing geoinformation

* restoring white spaces

* Returns tile_transform and crs
instead of storing it in self.
Adapt docstring

* adding types

* adjust line length

* fixing formatting

* remove auto-format from another auto-pep8 tool

* adjust order of import

* radjust autopep8

* passes pre-commit hook

* Add Oxford comma

* tile_transform -> transform

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
recursix and adamjstewart committed Feb 24, 2022
1 parent 0164104 commit 26b6917
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions torchgeo/datasets/benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import rasterio
import rasterio.features
import torch
from rasterio.crs import CRS
from torch import Tensor

from .geo import VisionDataset
Expand Down Expand Up @@ -233,12 +234,12 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
index: index to return
Returns:
image, mask, and metadata at that index
a dict containing image, mask, transform, crs, and metadata at index.
"""
y, x = self.chips_metadata[index]

img = self._load_all_imagery(self.bands)
labels = self._load_mask()
img, transform, crs = self._load_all_imagery(self.bands)
labels = self._load_mask(transform)

img = img[:, :, y : y + self.chip_size, x : x + self.chip_size]
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
Expand All @@ -248,6 +249,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
"mask": labels,
"x": torch.tensor(x), # type: ignore[attr-defined]
"y": torch.tensor(y), # type: ignore[attr-defined]
"transform": transform,
"crs": crs,
}

if self.transforms is not None:
Expand Down Expand Up @@ -279,17 +282,21 @@ def _validate_bands(self, bands: Tuple[str, ...]) -> None:
raise ValueError(f"'{band}' is an invalid band name.")

@lru_cache(maxsize=128)
def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor:
def _load_all_imagery(
self, bands: Tuple[str, ...] = ALL_BANDS
) -> Tuple[Tensor, rasterio.Affine, CRS]:
"""Load all the imagery (across time) for the dataset.
Optionally allows for subsetting of the bands that are loaded.
Args:
bands: tuple of bands to load
Returns
imagery of shape (70, number of bands, 1186, 1122) where 70 is the number of
points in time, 1186 is the tile height, and 1122 is the tile width
Returns:
imagery of shape (70, number of bands, 1186, 1122) where 70 is the number
of points in time, 1186 is the tile height, and 1122 is the tile width
rasterio affine transform, mapping pixel coordinates to geo coordinates
coordinate reference system of transform
"""
if self.verbose:
print("Loading all imagery")
Expand All @@ -303,12 +310,15 @@ def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor:
)

for date_index, date in enumerate(self.dates):
img[date_index] = self._load_single_scene(date, self.bands)
single_scene, transform, crs = self._load_single_scene(date, self.bands)
img[date_index] = single_scene

return img
return img, transform, crs

@lru_cache(maxsize=128)
def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor:
def _load_single_scene(
self, date: str, bands: Tuple[str, ...]
) -> Tuple[Tensor, rasterio.Affine, CRS]:
"""Load the imagery for a single date.
Optionally allows for subsetting of the bands that are loaded.
Expand All @@ -318,7 +328,9 @@ def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor:
bands: bands to load
Returns:
tensor containing a single image tile
Tensor containing a single image tile, rasterio affine transform,
mapping pixel coordinates to geo coordinates, and coordinate
reference system of transform.
Raises:
AssertionError: if ``date`` is invalid
Expand All @@ -342,14 +354,15 @@ def _load_single_scene(self, date: str, bands: Tuple[str, ...]) -> Tensor:
f"{band_name}.tif",
)
with rasterio.open(filepath) as src:
self.tile_transform = src.transform
transform = src.transform # same transform for every bands
crs = src.crs
array = src.read().astype(np.float32)
img[band_index] = torch.from_numpy(array) # type: ignore[attr-defined]

return img
return img, transform, crs

@lru_cache()
def _load_mask(self) -> Tensor:
def _load_mask(self, transform: rasterio.Affine) -> Tensor:
"""Rasterizes the dataset's labels (in geojson format)."""
# Create a mask layer out of the geojson
mask_geojson_fn = os.path.join(
Expand All @@ -367,7 +380,7 @@ def _load_mask(self) -> Tensor:
labels,
out_shape=(self.tile_height, self.tile_width),
fill=0, # nodata value
transform=self.tile_transform,
transform=transform,
all_touched=False,
dtype=np.uint8,
)
Expand Down

0 comments on commit 26b6917

Please sign in to comment.