diff --git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py index 881c3bab875..292018aa8eb 100644 --- a/tests/datasets/test_cdl.py +++ b/tests/datasets/test_cdl.py @@ -31,6 +31,20 @@ def dataset( monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.cdl, "download_url", download_url ) + + cmap = { + 0: (0, 0, 0, 0), + 1: (255, 211, 0, 255), + 2: (255, 38, 38, 255), + 3: (0, 168, 228, 255), + 4: (255, 158, 11, 255), + 5: (38, 112, 0, 255), + 6: (255, 255, 0, 255), + 7: (0, 0, 0, 255), + 8: (0, 0, 0, 255), + } + monkeypatch.setattr(CDL, "cmap", cmap) # type: ignore[attr-defined] + md5s = [ (2021, "0693f0bb10deb79c69bcafe4aa1635b7"), (2020, "7695292902a8672d16ac034d4d560d84"), diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 86d53536d23..1f75f8e3afd 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -10,7 +10,6 @@ import matplotlib.pyplot as plt import numpy as np from rasterio.crs import CRS -from torch import Tensor from .geo import RasterDataset from .utils import download_url, extract_archive @@ -401,7 +400,7 @@ def _extract(self) -> None: def plot( # type: ignore[override] self, - sample: Dict[str, Tensor], + sample: Dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: