forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* dataset file no test * add test and data.py * Update eudem.py * requested changes * Update torchgeo/datasets/eudem.py * Apply suggestions from code review * rST fix Co-authored-by: Caleb Robinson <calebrob6@gmail.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
- Loading branch information
1 parent
28575bf
commit e27f8fe
Showing
6 changed files
with
337 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import hashlib | ||
import os | ||
import random | ||
import zipfile | ||
|
||
import numpy as np | ||
import rasterio | ||
|
||
np.random.seed(0) | ||
random.seed(0) | ||
|
||
SIZE = 64 | ||
|
||
files = [{"image": "eu_dem_v11_E30N10.TIF"}, {"image": "eu_dem_v11_E30N10.TIF.ovr"}] | ||
|
||
|
||
def create_file(path: str, dtype: str, num_channels: int) -> None: | ||
profile = {} | ||
profile["driver"] = "GTiff" | ||
profile["dtype"] = dtype | ||
profile["count"] = num_channels | ||
profile["crs"] = "epsg:4326" | ||
profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) | ||
profile["height"] = SIZE | ||
profile["width"] = SIZE | ||
profile["compress"] = "lzw" | ||
profile["predictor"] = 2 | ||
|
||
Z = np.random.randint( | ||
np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"] | ||
) | ||
src = rasterio.open(path, "w", **profile) | ||
src.write(Z) | ||
|
||
|
||
if __name__ == "__main__": | ||
zipfilename = "eu_dem_v11_E30N10.zip" | ||
files_to_zip = [] | ||
|
||
for file_dict in files: | ||
path = file_dict["image"] | ||
# remove old data | ||
if os.path.exists(path): | ||
os.remove(path) | ||
# Create mask file | ||
create_file(path, dtype="int32", num_channels=1) | ||
files_to_zip.append(path) | ||
|
||
# Compress data | ||
with zipfile.ZipFile(zipfilename, "w") as zip: | ||
for file in files_to_zip: | ||
zip.write(file, arcname=file) | ||
|
||
# Compute checksums | ||
with open(zipfilename, "rb") as f: | ||
md5 = hashlib.md5(f.read()).hexdigest() | ||
print(f"{zipfilename}: {md5}") | ||
|
||
# remove TIF files | ||
for file_dict in files: | ||
os.remove(file_dict["image"]) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
import shutil | ||
from pathlib import Path | ||
from typing import Generator | ||
|
||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from _pytest.monkeypatch import MonkeyPatch | ||
from rasterio.crs import CRS | ||
|
||
from torchgeo.datasets import EUDEM, BoundingBox, IntersectionDataset, UnionDataset | ||
|
||
|
||
class TestEUDEM: | ||
@pytest.fixture | ||
def dataset( | ||
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path | ||
) -> EUDEM: | ||
|
||
md5s = {"eu_dem_v11_E30N10.zip": "ef148466c02197a08be169eaad186591"} | ||
monkeypatch.setattr(EUDEM, "md5s", md5s) # type: ignore[attr-defined] | ||
zipfile = os.path.join("tests", "data", "eudem", "eu_dem_v11_E30N10.zip") | ||
shutil.copy(zipfile, tmp_path) | ||
root = str(tmp_path) | ||
transforms = nn.Identity() # type: ignore[attr-defined] | ||
return EUDEM(root, transforms=transforms) | ||
|
||
def test_getitem(self, dataset: EUDEM) -> None: | ||
x = dataset[dataset.bounds] | ||
assert isinstance(x, dict) | ||
assert isinstance(x["crs"], CRS) | ||
assert isinstance(x["mask"], torch.Tensor) | ||
|
||
def test_extracted_already(self, dataset: EUDEM) -> None: | ||
zipfile = os.path.join(dataset.root, "eu_dem_v11_E30N10.zip") | ||
shutil.unpack_archive(zipfile, dataset.root, "zip") | ||
EUDEM(dataset.root) | ||
|
||
def test_no_dataset(self, tmp_path: Path) -> None: | ||
shutil.rmtree(tmp_path) | ||
os.makedirs(tmp_path) | ||
with pytest.raises(RuntimeError, match="Dataset not found in"): | ||
EUDEM(root=str(tmp_path)) | ||
|
||
def test_corrupted(self, tmp_path: Path) -> None: | ||
with open(os.path.join(tmp_path, "eu_dem_v11_E30N10.zip"), "w") as f: | ||
f.write("bad") | ||
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): | ||
EUDEM(root=str(tmp_path), checksum=True) | ||
|
||
def test_and(self, dataset: EUDEM) -> None: | ||
ds = dataset & dataset | ||
assert isinstance(ds, IntersectionDataset) | ||
|
||
def test_or(self, dataset: EUDEM) -> None: | ||
ds = dataset | dataset | ||
assert isinstance(ds, UnionDataset) | ||
|
||
def test_plot(self, dataset: EUDEM) -> None: | ||
query = dataset.bounds | ||
x = dataset[query] | ||
dataset.plot(x, suptitle="Test") | ||
|
||
def test_plot_prediction(self, dataset: EUDEM) -> None: | ||
query = dataset.bounds | ||
x = dataset[query] | ||
x["prediction"] = x["mask"].clone() | ||
dataset.plot(x, suptitle="Prediction") | ||
|
||
def test_invalid_query(self, dataset: EUDEM) -> None: | ||
query = BoundingBox(100, 100, 100, 100, 0, 0) | ||
with pytest.raises( | ||
IndexError, match="query: .* not found in index with bounds:" | ||
): | ||
dataset[query] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""European Digital Elevation Model (EU-DEM) dataset.""" | ||
|
||
import glob | ||
import os | ||
from typing import Any, Callable, Dict, Optional | ||
|
||
import matplotlib.pyplot as plt | ||
from rasterio.crs import CRS | ||
|
||
from .geo import RasterDataset | ||
from .utils import check_integrity, extract_archive | ||
|
||
|
||
class EUDEM(RasterDataset): | ||
"""European Digital Elevation Model (EU-DEM) Dataset. | ||
The `EU-DEM | ||
<https://land.copernicus.eu/imagery-in-situ/eu-dem/eu-dem-v1.1?tab=mapview>`_ | ||
dataset is a Digital Elevation Model of reference for the entire European region. | ||
The dataset can be downloaded from this `website | ||
<https://land.copernicus.eu/imagery-in-situ/eu-dem/eu-dem-v1.1?tab=mapview>`_ | ||
after making an account. A dataset factsheet is available | ||
`here <https://land.copernicus.eu/user-corner/publications/eu-dem-flyer/view>`__. | ||
Dataset features: | ||
* DEMs at 25 m per pixel spatial resolution (~40,000x40,0000 px) | ||
* vertical accuracy of +/- 7 m RMSE | ||
* data fused from `ASTER GDEM | ||
<https://lpdaac.usgs.gov/news/nasa-and-meti-release-aster-global-dem-version-3/>`_, | ||
`SRTM <https://www2.jpl.nasa.gov/srtm/>`_ and Russian topomaps | ||
Dataset format: | ||
* DEMs are single-channel tif files | ||
If you use this dataset in your research, please give credit to: | ||
* `Copernicus <https://land.copernicus.eu/imagery-in-situ/eu-dem/eu-dem-v1.1>`_ | ||
.. versionadded:: 0.3 | ||
""" | ||
|
||
is_image = False | ||
filename_glob = "eu_dem_v11_*.TIF" | ||
zipfile_glob = "eu_dem_v11_*[A-Z0-9].zip" | ||
filename_regex = "(?P<name>[eudem_v11]{10})_(?P<id>[A-Z0-9]{6})" | ||
|
||
md5s = { | ||
"eu_dem_v11_E00N20.zip": "96edc7e11bc299b994e848050d6be591", | ||
"eu_dem_v11_E10N00.zip": "e14be147ac83eddf655f4833d55c1571", | ||
"eu_dem_v11_E10N10.zip": "2eb5187e4d827245b33768404529c709", | ||
"eu_dem_v11_E10N20.zip": "1afc162eb131841aed0d00b692b870a8", | ||
"eu_dem_v11_E20N10.zip": "77b040791b9fb7de271b3f47130b4e0c", | ||
"eu_dem_v11_E20N20.zip": "89b965abdcb1dbd479c61117f55230c8", | ||
"eu_dem_v11_E20N30.zip": "f5cb1b05813ae8ffc9e70f0ad56cc372", | ||
"eu_dem_v11_E20N40.zip": "81be551ff646802d7d820385de7476e9", | ||
"eu_dem_v11_E20N50.zip": "bbc351713ea3eb7e9eb6794acb9e4bc8", | ||
"eu_dem_v11_E30N10.zip": "68fb95aac33a025c4f35571f32f237ff", | ||
"eu_dem_v11_E30N20.zip": "da8ad029f9cc1ec9234ea3e7629fe18d", | ||
"eu_dem_v11_E30N30.zip": "de27c78d0176e45aec5c9e462a95749c", | ||
"eu_dem_v11_E30N40.zip": "4c00e58b624adfc4a5748c922e77ee40", | ||
"eu_dem_v11_E30N50.zip": "4a21a88f4d2047b8995d1101df0b3a77", | ||
"eu_dem_v11_E40N10.zip": "32fdf4572581eddc305a21c5d2f4bc81", | ||
"eu_dem_v11_E40N20.zip": "71b027f29258493dd751cfd63f08578f", | ||
"eu_dem_v11_E40N30.zip": "c6c21289882c1f74fc4649d255302c64", | ||
"eu_dem_v11_E40N40.zip": "9f26e6e47f4160ef8ea5200e8cf90a45", | ||
"eu_dem_v11_E40N50.zip": "a8c3c1c026cdd1537b8a3822c15834d9", | ||
"eu_dem_v11_E50N10.zip": "9584273c7708b8e935f2bac3e30c19c6", | ||
"eu_dem_v11_E50N20.zip": "8efdea43e7b6819861935d5a768a55f2", | ||
"eu_dem_v11_E50N30.zip": "e39e58df1c13ac35eb0b29fb651f313c", | ||
"eu_dem_v11_E50N40.zip": "d84395ab52ad254d930db17398fffc50", | ||
"eu_dem_v11_E50N50.zip": "6abe852f4a20962db0e355ffc0d695a4", | ||
"eu_dem_v11_E60N10.zip": "b6a3b8a39a4efc01c7e2cd8418672559", | ||
"eu_dem_v11_E60N20.zip": "71dc3c55ab5c90628ce2149dbd60f090", | ||
"eu_dem_v11_E70N20.zip": "5342465ad60cf7d28a586c9585179c35", | ||
} | ||
|
||
def __init__( | ||
self, | ||
root: str = "data", | ||
crs: Optional[CRS] = None, | ||
res: Optional[float] = None, | ||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, | ||
cache: bool = True, | ||
checksum: bool = False, | ||
) -> None: | ||
"""Initialize a new Dataset instance. | ||
Args: | ||
root: root directory where dataset can be found, here the collection of | ||
individual zip files for each tile should be found | ||
crs: :term:`coordinate reference system (CRS)` to warp to | ||
(defaults to the CRS of the first file found) | ||
res: resolution of the dataset in units of CRS | ||
(defaults to the resolution of the first file found) | ||
transforms: a function/transform that takes an input sample | ||
and returns a transformed version | ||
cache: if True, cache file handle to speed up repeated sampling | ||
checksum: if True, check the MD5 of the downloaded files (may be slow) | ||
Raises: | ||
FileNotFoundError: if no files are found in ``root`` | ||
""" | ||
self.root = root | ||
self.checksum = checksum | ||
|
||
self._verify() | ||
|
||
super().__init__(root, crs, res, transforms, cache) | ||
|
||
def _verify(self) -> None: | ||
"""Verify the integrity of the dataset. | ||
Raises: | ||
RuntimeError: if dataset is missing or checksum fails | ||
""" | ||
# Check if the extracted file already exists | ||
pathname = os.path.join(self.root, self.filename_glob) | ||
if glob.glob(pathname): | ||
return | ||
|
||
# Check if the zip files have already been downloaded | ||
pathname = os.path.join(self.root, self.zipfile_glob) | ||
if glob.glob(pathname): | ||
for zipfile in glob.iglob(pathname): | ||
filename = os.path.basename(zipfile) | ||
if self.checksum and not check_integrity(zipfile, self.md5s[filename]): | ||
raise RuntimeError("Dataset found, but corrupted.") | ||
extract_archive(zipfile) | ||
return | ||
|
||
raise RuntimeError( | ||
f"Dataset not found in `root={self.root}` " | ||
"either specify a different `root` directory or make sure you " | ||
"have manually downloaded the dataset as suggested in the documentation." | ||
) | ||
|
||
def plot( # type: ignore[override] | ||
self, | ||
sample: Dict[str, Any], | ||
show_titles: bool = True, | ||
suptitle: Optional[str] = None, | ||
) -> plt.Figure: | ||
"""Plot a sample from the dataset. | ||
Args: | ||
sample: a sample returned by :meth:`RasterDataset.__getitem__` | ||
show_titles: flag indicating whether to show titles above each panel | ||
suptitle: optional string to use as a suptitle | ||
Returns: | ||
a matplotlib Figure with the rendered sample | ||
""" | ||
mask = sample["mask"].squeeze() | ||
ncols = 1 | ||
|
||
showing_predictions = "prediction" in sample | ||
if showing_predictions: | ||
pred = sample["prediction"].squeeze() | ||
ncols = 2 | ||
|
||
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) | ||
|
||
if showing_predictions: | ||
axs[0].imshow(mask) | ||
axs[0].axis("off") | ||
axs[1].imshow(pred) | ||
axs[1].axis("off") | ||
if show_titles: | ||
axs[0].set_title("Mask") | ||
axs[1].set_title("Prediction") | ||
else: | ||
axs.imshow(mask) | ||
axs.axis("off") | ||
if show_titles: | ||
axs.set_title("Mask") | ||
|
||
if suptitle is not None: | ||
plt.suptitle(suptitle) | ||
|
||
return fig |