From 336f89c9b57d027a384d0639e71f4780974f55f9 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 9 Jul 2022 16:58:40 -0700 Subject: [PATCH 1/3] Landsat: add plot method --- tests/datasets/test_landsat.py | 15 ++++++++++ torchgeo/datasets/cdl.py | 4 +++ torchgeo/datasets/chesapeake.py | 7 +++-- torchgeo/datasets/landsat.py | 52 +++++++++++++++++++++++++++++++++ torchgeo/datasets/naip.py | 4 +-- torchgeo/datasets/sentinel.py | 7 +++-- 6 files changed, 81 insertions(+), 8 deletions(-) diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index 59b6a0f61a4..bc1ff2c8aea 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -4,6 +4,7 @@ import os from pathlib import Path +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -38,6 +39,20 @@ def test_or(self, dataset: Landsat8) -> None: ds = dataset | dataset assert isinstance(ds, UnionDataset) + def test_plot(self, dataset: Landsat8) -> None: + x = dataset[dataset.bounds] + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_wrong_bands(self, dataset: Landsat8) -> None: + bands = ("SR_B1",) + ds = Landsat8(root=dataset.root, bands=bands) + x = dataset[dataset.bounds] + with pytest.raises( + ValueError, match="Dataset doesn't contain some of the RGB bands" + ): + ds.plot(x) + def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(FileNotFoundError, match="No Landsat8 data was found in "): Landsat8(str(tmp_path)) diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 9d9c5242564..de4023fc50b 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -413,6 +413,10 @@ def plot( Returns: a matplotlib Figure with the rendered sample + + .. versionchanged:: 0.3 + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. """ mask = sample["mask"].squeeze().numpy() ncols = 1 diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 7557136d277..59528f9c0fe 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -19,7 +19,6 @@ import torch from matplotlib.colors import ListedColormap from rasterio.crs import CRS -from torch import Tensor from .geo import GeoDataset, RasterDataset from .utils import BoundingBox, download_url, extract_archive @@ -178,7 +177,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: Dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: @@ -192,7 +191,9 @@ def plot( Returns: a matplotlib Figure with the rendered sample - .. versionadded:: 0.3 + .. versionchanged:: 0.3 + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. """ mask = sample["mask"].squeeze(0) ncols = 1 diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 0f62f0201ad..724248c1410 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -6,6 +6,7 @@ import abc from typing import Any, Callable, Dict, Optional, Sequence +import matplotlib.pyplot as plt from rasterio.crs import CRS from .geo import RasterDataset @@ -78,6 +79,57 @@ def __init__( super().__init__(root, crs, res, transforms, cache) + def plot( + self, + sample: Dict[str, Any], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`RasterDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + Raises: + ValueError: if the RGB bands are not included in ``self.bands`` + + .. versionchanged:: 0.3 + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. + """ + rgb_indices = [] + for band in self.rgb_bands: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + image = sample["image"][rgb_indices].permute(1, 2, 0) + + # Stretch to the range of 2nd to 98th percentile + per02 = image.quantile(0.02) + per98 = image.quantile(0.98) + image = (image - per02) / (per98 - per02) + image = image.clamp(min=0, max=1) + + fig, ax = plt.subplots(1, 1, figsize=(4, 4)) + + ax.imshow(image) + ax.axis("off") + + if show_titles: + ax.set_title("Image") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig + class Landsat1(Landsat): """Landsat 1 Multispectral Scanner (MSS).""" diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index b3d054a492f..9274b830129 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -64,8 +64,8 @@ def plot( a matplotlib Figure with the rendered sample .. versionchanged:: 0.3 - Method now takes a sample dict, not a Tensor. Additionally, possible to - show subplot titles and/or use a custom suptitle. + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. """ image = sample["image"][0:3, :, :].permute(1, 2, 0) diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 1f347bd1495..cc8817a8c85 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -8,7 +8,6 @@ import matplotlib.pyplot as plt import torch from rasterio.crs import CRS -from torch import Tensor from .geo import RasterDataset @@ -104,7 +103,7 @@ def __init__( def plot( self, - sample: Dict[str, Tensor], + sample: Dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: @@ -121,7 +120,9 @@ def plot( Raises: ValueError: if the RGB bands are not included in ``self.bands`` - .. versionadded:: 0.3 + .. versionchanged:: 0.3 + Method now takes a sample dict, not a Tensor. Additionally, possible to + show subplot titles and/or use a custom suptitle. """ rgb_indices = [] for band in self.RGB_BANDS: From 8369bb7e82ff4e2d1eb320dad3f3967a07e2535d Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 9 Jul 2022 17:09:45 -0700 Subject: [PATCH 2/3] torch.quantile requires float tensors --- torchgeo/datasets/landsat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 724248c1410..f0fa6bc62f5 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -109,7 +109,7 @@ def plot( else: raise ValueError("Dataset doesn't contain some of the RGB bands") - image = sample["image"][rgb_indices].permute(1, 2, 0) + image = sample["image"][rgb_indices].permute(1, 2, 0).float() # Stretch to the range of 2nd to 98th percentile per02 = image.quantile(0.02) From 2da8be4b3ce8f17173348f85574c7b383baaf756 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 9 Jul 2022 18:19:47 -0700 Subject: [PATCH 3/3] Use full stretch, quantile is limited --- docs/api/geo_datasets.csv | 2 +- torchgeo/datasets/landsat.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index f420f70cfea..9bee384cd95 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -12,7 +12,7 @@ Dataset,Type,Source,Size (px),Resolution (m) `GBIF`_,Points,Citizen Scientists,-,- `GlobBiomass`_,Masks,Landsat,"45,000x45,000",100 `iNaturalist`_,Points,Citizen Scientists,-,- -`Landsat`_,Imagery,Landsat,-,30 +`Landsat`_,Imagery,Landsat,"8,900x8,900",30 `NAIP`_,Imagery,Aerial,"6,100x7,600",1 `Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,- `Sentinel`_,Imagery,Sentinel,"10,000x10,000",10 diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index f0fa6bc62f5..e3cc15d6faa 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -111,11 +111,8 @@ def plot( image = sample["image"][rgb_indices].permute(1, 2, 0).float() - # Stretch to the range of 2nd to 98th percentile - per02 = image.quantile(0.02) - per98 = image.quantile(0.98) - image = (image - per02) / (per98 - per02) - image = image.clamp(min=0, max=1) + # Stretch to the full range + image = (image - image.min()) / (image.max() - image.min()) fig, ax = plt.subplots(1, 1, figsize=(4, 4))