diff --git a/tests/data/esri2020/data.py b/tests/data/esri2020/data.py new file mode 100644 index 00000000000..9c3b110ad2d --- /dev/null +++ b/tests/data/esri2020/data.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import random +import shutil + +import numpy as np +import rasterio + +np.random.seed(0) +random.seed(0) + +SIZE = 64 + + +files = [{"image": "N00E020_agb.tif"}, {"image": "N00E020_agb_err.tif"}] + + +def create_file(path: str, dtype: str, num_channels: int) -> None: + profile = {} + profile["driver"] = "GTiff" + profile["dtype"] = dtype + profile["count"] = num_channels + profile["crs"] = "epsg:4326" + profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile["height"] = SIZE + profile["width"] = SIZE + profile["compress"] = "lzw" + profile["predictor"] = 2 + + Z = np.random.randint( + np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"] + ) + src = rasterio.open(path, "w", **profile) + src.write(Z) + + +if __name__ == "__main__": + dir = "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01" + tif_name = "00A_20200101-20210101.tif" + + if os.path.exists(dir): + shutil.rmtree(dir) + + os.makedirs(dir) + + # Create mask file + create_file(os.path.join(dir, tif_name), dtype="int8", num_channels=1) + + shutil.make_archive(dir, "zip", base_dir=dir) + + # Compute checksums + zipfilename = dir + ".zip" + with open(zipfilename, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{zipfilename}: {md5}") diff --git a/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip b/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip index 217161b7a15..910c3c23d23 100644 Binary files a/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip and b/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip differ diff --git a/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01/00A_20200101-20210101.tif b/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01/00A_20200101-20210101.tif index 669e3f47342..293dcd47e44 100644 Binary files a/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01/00A_20200101-20210101.tif and b/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01/00A_20200101-20210101.tif differ diff --git a/tests/datasets/test_esri2020.py b/tests/datasets/test_esri2020.py index 13a58e021e4..42eed144a95 100644 --- a/tests/datasets/test_esri2020.py +++ b/tests/datasets/test_esri2020.py @@ -32,7 +32,7 @@ def dataset( zipfile = "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip" monkeypatch.setattr(Esri2020, "zipfile", zipfile) # type: ignore[attr-defined] - md5 = "4932855fcd00735a34b74b1f87db3df0" + md5 = "34aec55538694171c7b605b0cc0d0138" monkeypatch.setattr(Esri2020, "md5", md5) # type: ignore[attr-defined] url = os.path.join( "tests", @@ -45,17 +45,6 @@ def dataset( transforms = nn.Identity() # type: ignore[attr-defined] return Esri2020(root, transforms=transforms, download=True, checksum=True) - def test_already_downloaded(self, tmp_path: Path) -> None: - url = os.path.join( - "tests", - "data", - "esri2020", - "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip", - ) - root = str(tmp_path) - shutil.copy(url, root) - Esri2020(root) - def test_getitem(self, dataset: Esri2020) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) @@ -65,6 +54,16 @@ def test_getitem(self, dataset: Esri2020) -> None: def test_already_extracted(self, dataset: Esri2020) -> None: Esri2020(root=dataset.root, download=True) + def test_not_extracted(self, tmp_path: Path) -> None: + url = os.path.join( + "tests", + "data", + "esri2020", + "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip", + ) + shutil.copy(url, tmp_path) + Esri2020(root=str(tmp_path)) + def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found"): Esri2020(str(tmp_path), checksum=True) @@ -80,7 +79,14 @@ def test_or(self, dataset: Esri2020) -> None: def test_plot(self, dataset: Esri2020) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x["mask"]) + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: Esri2020) -> None: + query = dataset.bounds + x = dataset[query] + x["prediction"] = x["mask"].clone() + dataset.plot(x, suptitle="Prediction") plt.close() def test_url(self) -> None: diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 4d8c651a6a0..78ed789a795 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -3,18 +3,18 @@ """Esri 2020 Land Cover Dataset.""" -import abc import glob import os from typing import Any, Callable, Dict, Optional +import matplotlib.pyplot as plt from rasterio.crs import CRS from .geo import RasterDataset from .utils import download_url, extract_archive -class Esri2020(RasterDataset, abc.ABC): +class Esri2020(RasterDataset): """Esri 2020 Land Cover Dataset. The `Esri 2020 Land Cover dataset @@ -136,3 +136,48 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" extract_archive(os.path.join(self.root, self.zipfile)) + + def plot( # type: ignore[override] + self, + sample: Dict[str, Any], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`RasterDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + mask = sample["mask"].squeeze() + ncols = 1 + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = sample["prediction"].squeeze() + ncols = 2 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4)) + + if showing_predictions: + axs[0].imshow(mask) + axs[0].axis("off") + axs[1].imshow(prediction) + axs[1].axis("off") + if show_titles: + axs[0].set_title("Mask") + axs[1].set_title("Prediction") + else: + axs.imshow(mask) + axs.axis("off") + if show_titles: + axs.set_title("Mask") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig