Skip to content

Commit

Permalink
Add plot method to CDL dataset (#415)
Browse files Browse the repository at this point in the history
* requested changes

* add colormap to file
  • Loading branch information
nilsleh authored Mar 1, 2022
1 parent c933694 commit 17792df
Show file tree
Hide file tree
Showing 5 changed files with 339 additions and 20 deletions.
Binary file modified tests/data/cdl/2020_30m_cdls.zip
Binary file not shown.
Binary file modified tests/data/cdl/2021_30m_cdls.zip
Binary file not shown.
19 changes: 16 additions & 3 deletions tests/data/cdl/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,34 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
profile["width"] = SIZE
profile["compress"] = "lzw"
profile["predictor"] = 2
cmap = {
0: (0, 0, 0, 0),
1: (255, 211, 0, 255),
2: (255, 38, 38, 255),
3: (0, 168, 228, 255),
4: (255, 158, 11, 255),
5: (38, 112, 0, 255),
6: (255, 255, 0, 255),
7: (0, 0, 0, 255),
8: (0, 0, 0, 255),
}

Z = np.random.randint(size=(SIZE, SIZE), low=0, high=8)

src = rasterio.open(path, "w", **profile)
for i in range(1, profile["count"] + 1):
src.write(Z, i)

src.write_colormap(1, cmap)


directories = ["2020_30m_cdls", "2021_30m_cdls"]
raster_extensions = [[".tif", ".tif.ovr"], [".tif", ".tif.ovr"]]
raster_extensions = [".tif", ".tif.ovr"]


if __name__ == "__main__":

for dir, extensions in zip(directories, raster_extensions):
for dir in directories:
filename = dir + ".zip"

# Remove old data
Expand All @@ -51,7 +64,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:

os.makedirs(os.path.join(os.getcwd(), dir))

for e in extensions:
for e in raster_extensions:
create_file(
os.path.join(dir, filename.replace(".zip", e)),
dtype="int8",
Expand Down
26 changes: 11 additions & 15 deletions tests/datasets/test_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,10 @@ def dataset(
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.cdl, "download_url", download_url
)
cmap = {
0: (0, 0, 0, 0),
1: (255, 211, 0, 255),
2: (255, 38, 38, 255),
3: (0, 168, 228, 255),
4: (255, 158, 11, 255),
5: (38, 112, 0, 255),
6: (255, 255, 0, 255),
7: (0, 0, 0, 255),
8: (0, 0, 0, 255),
}
monkeypatch.setattr(CDL, "cmap", cmap) # type: ignore[attr-defined]

md5s = [
(2021, "4618f054004110ea11b19541b4b9f734"),
(2020, "593a86e62e3dd44438d536dc2442c082"),
(2021, "e929beb9c8e59fa1d7b7f82e64edaae1"),
(2020, "e95c2d40ce0c261ed6ee0bd00b49e4b6"),
]
monkeypatch.setattr(CDL, "md5s", md5s) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
Expand Down Expand Up @@ -90,7 +79,14 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
def test_plot(self, dataset: CDL) -> None:
query = dataset.bounds
x = dataset[query]
dataset.plot(x["mask"])
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: CDL) -> None:
query = dataset.bounds
x = dataset[query]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()

def test_not_downloaded(self, tmp_path: Path) -> None:
Expand Down
Loading

0 comments on commit 17792df

Please sign in to comment.