diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 0265aa58883..9f772a4e379 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -12,11 +12,16 @@ Geospatial Datasets :class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`IntersectionDataset` and :class:`UnionDataset`. +Aboveground Live Woody Biomass Density +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: AbovegroundLiveWoodyBiomassDensity + Aster Global Digital Evaluation Model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: AsterGDEM - + Canadian Building Footprints ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/data/agb_live_woody_density/00N_000E.tif b/tests/data/agb_live_woody_density/00N_000E.tif new file mode 100644 index 00000000000..ec36d3169b1 Binary files /dev/null and b/tests/data/agb_live_woody_density/00N_000E.tif differ diff --git a/tests/data/agb_live_woody_density/Aboveground_Live_Woody_Biomass_Density.geojson b/tests/data/agb_live_woody_density/Aboveground_Live_Woody_Biomass_Density.geojson new file mode 100644 index 00000000000..169191641e1 --- /dev/null +++ b/tests/data/agb_live_woody_density/Aboveground_Live_Woody_Biomass_Density.geojson @@ -0,0 +1 @@ +{"type": "FeatureCollection", "name": "Aboveground_Live_Woody_Biomass_Density", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {"tile_id": "00N_000E", "download": "tests/data/agb_live_woody_density/00N_000E.tif", "ObjectId": 1, "Shape__Area": 1245542622548.87, "Shape__Length": 4464169.76558139}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [10.0, 0.0], [10.0, -10.0], [0.0, -10.0], [0.0, 0.0]]]}}]} \ No newline at end of file diff --git a/tests/data/agb_live_woody_density/data.py b/tests/data/agb_live_woody_density/data.py new file mode 100644 index 00000000000..54ebba3e92c --- /dev/null +++ b/tests/data/agb_live_woody_density/data.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import json +import os +import random + +import numpy as np +import rasterio + +SIZE = 32 + +np.random.seed(0) +random.seed(0) + + +base_file = { + "type": "FeatureCollection", + "name": "Aboveground_Live_Woody_Biomass_Density", + "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, + "features": [ + { + "type": "Feature", + "properties": { + "tile_id": "00N_000E", + "download": os.path.join( + "tests", "data", "agb_live_woody_density", "00N_000E.tif" + ), + "ObjectId": 1, + "Shape__Area": 1245542622548.8701, + "Shape__Length": 4464169.7655813899, + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[0.0, 0.0], [10.0, 0.0], [10.0, -10.0], [0.0, -10.0], [0.0, 0.0]] + ], + }, + } + ], +} + + +def create_file(path: str, dtype: str, num_channels: int) -> None: + profile = {} + profile["driver"] = "GTiff" + profile["dtype"] = dtype + profile["count"] = num_channels + profile["crs"] = "epsg:4326" + profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile["height"] = SIZE + profile["width"] = SIZE + profile["compress"] = "lzw" + profile["predictor"] = 2 + + if "float" in profile["dtype"]: + Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + else: + Z = np.random.randint( + np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"] + ) + + src = rasterio.open(path, "w", **profile) + for i in range(1, profile["count"] + 1): + src.write(Z, i) + + +if __name__ == "__main__": + base_file_name = "Aboveground_Live_Woody_Biomass_Density.geojson" + if os.path.exists(base_file_name): + os.remove(base_file_name) + + with open(base_file_name, "w") as f: + json.dump(base_file, f) + + for i in base_file["features"]: + filepath = os.path.basename(i["properties"]["download"]) + create_file(path=filepath, dtype="int32", num_channels=1) diff --git a/tests/datasets/test_agb_live_woody_density.py b/tests/datasets/test_agb_live_woody_density.py new file mode 100644 index 00000000000..1a145096ca4 --- /dev/null +++ b/tests/datasets/test_agb_live_woody_density.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path +from typing import Generator + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.monkeypatch import MonkeyPatch +from rasterio.crs import CRS + +import torchgeo +from torchgeo.datasets import ( + AbovegroundLiveWoodyBiomassDensity, + IntersectionDataset, + UnionDataset, +) + + +def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: + shutil.copy(url, root) + + +class TestAbovegroundLiveWoodyBiomassDensity: + @pytest.fixture + def dataset( + self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path + ) -> AbovegroundLiveWoodyBiomassDensity: + + transforms = nn.Identity() # type: ignore[attr-defined] + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.agb_live_woody_density, "download_url", download_url + ) + url = os.path.join( + "tests", + "data", + "agb_live_woody_density", + "Aboveground_Live_Woody_Biomass_Density.geojson", + ) + monkeypatch.setattr( # type: ignore[attr-defined] + AbovegroundLiveWoodyBiomassDensity, "url", url + ) + + root = str(tmp_path) + return AbovegroundLiveWoodyBiomassDensity( + root, transforms=transforms, download=True + ) + + def test_getitem(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x["crs"], CRS) + assert isinstance(x["mask"], torch.Tensor) + + def test_no_dataset(self) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in."): + AbovegroundLiveWoodyBiomassDensity(root="/test") + + def test_already_downloaded( + self, dataset: AbovegroundLiveWoodyBiomassDensity + ) -> None: + AbovegroundLiveWoodyBiomassDensity(dataset.root) + + def test_and(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: + ds = dataset & dataset + assert isinstance(ds, IntersectionDataset) + + def test_or(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: + ds = dataset | dataset + assert isinstance(ds, UnionDataset) + + def test_plot(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: + query = dataset.bounds + x = dataset[query] + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: + query = dataset.bounds + x = dataset[query] + x["prediction"] = x["mask"].clone() + dataset.plot(x, suptitle="Prediction") + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 531c0a781b9..8b16835464a 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -4,6 +4,7 @@ """TorchGeo datasets.""" from .advance import ADVANCE +from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity from .astergdem import AsterGDEM from .benin_cashews import BeninSmallHolderCashews from .bigearthnet import BigEarthNet @@ -86,6 +87,7 @@ __all__ = ( # GeoDataset + "AbovegroundLiveWoodyBiomassDensity", "AsterGDEM", "CanadianBuildingFootprints", "CDL", diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py new file mode 100644 index 00000000000..79f2601e5f4 --- /dev/null +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -0,0 +1,170 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Aboveground Live Woody Biomass Density dataset.""" + +import glob +import json +import os +from typing import Any, Callable, Dict, Optional + +import matplotlib.pyplot as plt +from rasterio.crs import CRS + +from .geo import RasterDataset +from .utils import download_url + + +class AbovegroundLiveWoodyBiomassDensity(RasterDataset): + """Aboveground Live Woody Biomass Density dataset. + + The `Aboveground Live Woody Biomass Density dataset + `_ + is a global-scale, wall-to-wall map of aboveground biomass at ~30m resolution + for the year 2000. + + Dataset features: + + * Masks with per pixel live woody biomass density estimates in megagrams + biomass per hectare at ~30m resolution (~40,000x40,0000 px) + + Dataset format: + + * geojson file that contains download links to tif files + * single-channel geotiffs with the pixel values representing biomass density + + If you use this dataset in your research, please give credit to: + + * `Global Forest Watch `_ + + .. versionadded:: 0.3 + """ + + is_image = False + + url = ( + "https://opendata.arcgis.com/api/v3/datasets/3e8736c8866b458687" + "e00d40c9f00bce_0/downloads/data?format=geojson&spatialRefId=4326" + ) + + base_filename = "Aboveground_Live_Woody_Biomass_Density.geojson" + + filename_glob = "*N_*E.*" + filename_regex = r"""^ + (?P[0-9][0-9][A-Z])_ + (?P[0-9][0-9][0-9][A-Z])* + """ + + def __init__( + self, + root: str = "data", + crs: Optional[CRS] = None, + res: Optional[float] = None, + transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + download: bool = False, + cache: bool = True, + ) -> None: + """Initialize a new Dataset instance. + + Args: + root: root directory where dataset can be found + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + res: resolution of the dataset in units of CRS + (defaults to the resolution of the first file found) + transforms: a function/transform that takes an input sample + and returns a transformed version + download: if True, download dataset and store it in the root directory + cache: if True, cache file handle to speed up repeated sampling + + Raises: + FileNotFoundError: if no files are found in ``root`` + """ + self.root = root + self.download = download + + self._verify() + + super().__init__(root, crs, res, transforms, cache) + + def _verify(self) -> None: + """Verify the integrity of the dataset. + + Raises: + RuntimeError: if dataset is missing + """ + # Check if the extracted files already exist + pathname = os.path.join(self.root, self.filename_glob) + if glob.glob(pathname): + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + f"Dataset not found in `root={self.root}` and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automaticaly download the dataset." + ) + + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + download_url(self.url, self.root, self.base_filename) + + with open(os.path.join(self.root, self.base_filename), "r") as f: + content = json.load(f) + + for item in content["features"]: + download_url( + item["properties"]["download"], + self.root, + item["properties"]["tile_id"] + ".tif", + ) + + def plot( # type: ignore[override] + self, + sample: Dict[str, Any], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`RasterDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + mask = sample["mask"].squeeze() + ncols = 1 + + showing_predictions = "prediction" in sample + if showing_predictions: + pred = sample["prediction"].squeeze() + ncols = 2 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) + + if showing_predictions: + axs[0].imshow(mask) + axs[0].axis("off") + axs[1].imshow(pred) + axs[1].axis("off") + if show_titles: + axs[0].set_title("Mask") + axs[1].set_title("Prediction") + else: + axs.imshow(mask) + axs.axis("off") + if show_titles: + axs.set_title("Mask") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig