Skip to content

Commit

Permalink
Add plot method and data.py to Sentinel Dataset (#416)
Browse files Browse the repository at this point in the history
* add plot method and data.py

* Adding normalization for plot

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
  • Loading branch information
nilsleh and calebrob6 authored Feb 21, 2022
1 parent 644fb05 commit 6bc3779
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 1 deletion.
61 changes: 61 additions & 0 deletions tests/data/sentinel2/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import random

import numpy as np
import rasterio

SIZE = 32

np.random.seed(0)
random.seed(0)

filenames = [
"T41XNE_20200829T083611_B01_60m.tif",
"T41XNE_20200829T083611_B02_10m.tif",
"T41XNE_20200829T083611_B03_10m.tif",
"T41XNE_20200829T083611_B04_10m.tif",
"T41XNE_20200829T083611_B05_20m.tif",
"T41XNE_20200829T083611_B06_20m.tif",
"T41XNE_20200829T083611_B07_20m.tif",
"T41XNE_20200829T083611_B08_10m.tif",
"T41XNE_20200829T083611_B8A_20m.tif",
"T41XNE_20200829T083611_B09_60m.tif",
"T41XNE_20200829T083611_B11_20m.tif",
]


def create_file(path: str, dtype: str, num_channels: int) -> None:
profile = {}
profile["driver"] = "GTiff"
profile["dtype"] = dtype
profile["count"] = num_channels
profile["crs"] = "epsg:4326"
profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
profile["height"] = SIZE
profile["width"] = SIZE
profile["compress"] = "lzw"
profile["predictor"] = 2

if "float" in profile["dtype"]:
Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"])
else:
Z = np.random.randint(
np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"]
)

src = rasterio.open(path, "w", **profile)
for i in range(1, profile["count"] + 1):
src.write(Z, i)


if __name__ == "__main__":
for f in filenames:
if os.path.exists(f):
os.remove(f)

create_file(path=f, dtype="int32", num_channels=1)
13 changes: 13 additions & 0 deletions tests/datasets/test_sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No Sentinel2 data was found in "):
Sentinel2(str(tmp_path))

def test_plot(self, dataset: Sentinel2) -> None:
x = dataset[dataset.bounds]
dataset.plot(x, suptitle="Test")

def test_plot_wrong_bands(self, dataset: Sentinel2) -> None:
bands = ("B01",)
ds = Sentinel2(root=dataset.root, bands=bands)
x = dataset[dataset.bounds]
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"
):
ds.plot(x)

def test_invalid_query(self, dataset: Sentinel2) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
Expand Down
49 changes: 48 additions & 1 deletion torchgeo/datasets/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

from typing import Any, Callable, Dict, Optional, Sequence

import matplotlib.pyplot as plt
import torch
from rasterio.crs import CRS
from torch import Tensor

from .geo import RasterDataset

Expand Down Expand Up @@ -66,7 +69,7 @@ class Sentinel2(Sentinel):
"B11",
"B12",
]
rgb_bands = ["B04", "B03", "B02"]
RGB_BANDS = ["B04", "B03", "B02"]

separate_files = True

Expand Down Expand Up @@ -98,3 +101,47 @@ def __init__(
self.bands = bands if bands else self.all_bands

super().__init__(root, crs, res, transforms, cache)

def plot( # type: ignore[override]
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`RasterDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
Raises:
ValueError: if the RGB bands are not included in ``self.bands``
.. versionadded:: 0.3
"""
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")

image = sample["image"][rgb_indices].permute(1, 2, 0)
image = torch.clamp(image / 3000, min=0, max=1) # type: ignore[attr-defined]

fig, ax = plt.subplots(1, 1, figsize=(4, 4))

ax.imshow(image)
ax.axis("off")

if show_titles:
ax.set_title("Image")

if suptitle is not None:
plt.suptitle(suptitle)

return fig

0 comments on commit 6bc3779

Please sign in to comment.