Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AGB live woody density dataset #425

Merged
merged 13 commits into from
Feb 27, 2022
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Geospatial Datasets

:class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`IntersectionDataset` and :class:`UnionDataset`.

Aboveground Live Woody Biomass Density
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved

.. autoclass:: AbovegroundLiveWoodyBiomassDensity

Canadian Building Footprints
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
Binary file added tests/data/agb_live_woody_density/00N_000E.tif
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type": "FeatureCollection", "name": "Aboveground_Live_Woody_Biomass_Density", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {"tile_id": "00N_000E", "download": "tests/data/agb_live_woody_density/00N_000E.tif", "ObjectId": 1, "Shape__Area": 1245542622548.87, "Shape__Length": 4464169.76558139}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [10.0, 0.0], [10.0, -10.0], [0.0, -10.0], [0.0, 0.0]]]}}]}
80 changes: 80 additions & 0 deletions tests/data/agb_live_woody_density/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import json
import os
import random

import numpy as np
import rasterio

SIZE = 32

np.random.seed(0)
random.seed(0)


base_file = {
"type": "FeatureCollection",
"name": "Aboveground_Live_Woody_Biomass_Density",
"crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}},
"features": [
{
"type": "Feature",
"properties": {
"tile_id": "00N_000E",
"download": os.path.join(
"tests", "data", "agb_live_woody_density", "00N_000E.tif"
),
"ObjectId": 1,
"Shape__Area": 1245542622548.8701,
"Shape__Length": 4464169.7655813899,
},
"geometry": {
"type": "Polygon",
"coordinates": [
[[0.0, 0.0], [10.0, 0.0], [10.0, -10.0], [0.0, -10.0], [0.0, 0.0]]
],
},
}
],
}


def create_file(path: str, dtype: str, num_channels: int) -> None:
profile = {}
profile["driver"] = "GTiff"
profile["dtype"] = dtype
profile["count"] = num_channels
profile["crs"] = "epsg:4326"
profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
profile["height"] = SIZE
profile["width"] = SIZE
profile["compress"] = "lzw"
profile["predictor"] = 2

if "float" in profile["dtype"]:
Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"])
else:
Z = np.random.randint(
np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"]
)

src = rasterio.open(path, "w", **profile)
for i in range(1, profile["count"] + 1):
src.write(Z, i)


if __name__ == "__main__":
base_file_name = "Aboveground_Live_Woody_Biomass_Density.geojson"
if os.path.exists(base_file_name):
os.remove(base_file_name)

with open(base_file_name, "w") as f:
json.dump(base_file, f)

for i in base_file["features"]:
filepath = os.path.basename(i["properties"]["download"])
create_file(path=filepath, dtype="int32", num_channels=1)
100 changes: 100 additions & 0 deletions tests/datasets/test_agb_live_woody_density.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

import torchgeo
from torchgeo.datasets import (
AbovegroundLiveWoodyBiomassDensity,
IntersectionDataset,
UnionDataset,
)


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)


class TestAbovegroundLiveWoodyBiomassDensity:
@pytest.fixture
def dataset(
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
) -> AbovegroundLiveWoodyBiomassDensity:

transforms = nn.Identity() # type: ignore[attr-defined]
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.agb_live_woody_density, "download_url", download_url
)
url = os.path.join(
"tests",
"data",
"agb_live_woody_density",
"Aboveground_Live_Woody_Biomass_Density.geojson",
)
monkeypatch.setattr( # type: ignore[attr-defined]
AbovegroundLiveWoodyBiomassDensity, "url", url
)

root = str(tmp_path)
return AbovegroundLiveWoodyBiomassDensity(
root, transforms=transforms, download=True
)

def test_getitem(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)

def test_no_dataset(self) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in."):
AbovegroundLiveWoodyBiomassDensity(root="/test")

def test_no_basefile(
self, dataset: AbovegroundLiveWoodyBiomassDensity, tmp_path: Path
) -> None:
os.remove(os.path.join(str(tmp_path), "00N_000E.tif"))
AbovegroundLiveWoodyBiomassDensity(dataset.root)

def test_already_downloaded(self, tmp_path: Path) -> None:
base_file_path = os.path.join(
"tests",
"data",
"agb_live_woody_density",
"Aboveground_Live_Woody_Biomass_Density.geojson",
)
tif_pathname = os.path.join(
"tests", "data", "agb_live_woody_density", "00N_000E.tif"
)
root = str(tmp_path)
shutil.copy(base_file_path, root)
shutil.copy(tif_pathname, root)
AbovegroundLiveWoodyBiomassDensity(root)

def test_and(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_plot(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
query = dataset.bounds
x = dataset[query]
dataset.plot(x, suptitle="Test")

def test_plot_prediction(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
query = dataset.bounds
x = dataset[query]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""TorchGeo datasets."""

from .advance import ADVANCE
from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
from .cbf import CanadianBuildingFootprints
Expand Down Expand Up @@ -85,6 +86,7 @@

__all__ = (
# GeoDataset
"AbovegroundLiveWoodyBiomassDensity",
"CanadianBuildingFootprints",
"CDL",
"Chesapeake",
Expand Down
178 changes: 178 additions & 0 deletions torchgeo/datasets/agb_live_woody_density.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Aboveground Live Woody Biomass Density dataset."""

import glob
import json
import os
from typing import Any, Callable, Dict, Optional

import matplotlib.pyplot as plt
from rasterio.crs import CRS
from torch import Tensor

from .geo import RasterDataset
from .utils import download_url


class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
"""Aboveground Live Woody Biomass Density dataset.

The `Aboveground Live Woody Biomass Density dataset
<https://data.globalforestwatch.org/datasets/gfw::aboveground-live-woody
-biomass-density/about>`_
is a global-scale, wall-to-wall map of aboveground biomass at ~30m resolution
for the year 2000.

Dataset features:

* Masks with per pixel live woody biomass density estimates in megagrams
biomass per hectare at ~30m resolution (~40,000x40,0000 px)

Dataset format:

* geojson file that contains download links to tif files
* single-channel geotiffs with the pixel values representing biomass density

If you use this dataset in your research, please give credit to:

* `Global Forest Watch <https://data.globalforestwatch.org/>`_

.. versionadded:: 0.3
"""

is_image = False

url = (
"https://opendata.arcgis.com/api/v3/datasets/3e8736c8866b458687"
"e00d40c9f00bce_0/downloads/data?format=geojson&spatialRefId=4326"
)

base_filename = "Aboveground_Live_Woody_Biomass_Density.geojson"

filename_glob = "*N_*E.*"
filename_regex = (
r"""^(?P<latitude>[0-9][0-9][A-Z])_(?P<longitude>[0-9][0-9][0-9][A-Z])*"""
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
)

def __init__(
self,
root: str = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
cache: bool = True,
) -> None:
"""Initialize a new Dataset instance.

Args:
root: root directory where dataset can be found
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
transforms: a function/transform that takes an input sample
and returns a transformed version
download: if True, download dataset and store it in the root directory
cache: if True, cache file handle to speed up repeated sampling

Raises:
FileNotFoundError: if no files are found in ``root``
"""
self.root = root
self.download = download

self._verify()

super().__init__(root, crs, res, transforms, cache)

def _verify(self) -> None:
"""Verify the integrity of the dataset.

Raises:
RuntimeError: if dataset is missing
"""
# Check if the extracted files already exist
pathname = os.path.join(self.root, self.filename_glob)
if glob.glob(pathname):
return

# Check if the downloaded base file already exist
pathname = os.path.join(self.root, self.base_filename)
if glob.glob(pathname):
self._extract()
return

# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)

# Download the dataset
self._download()
self._extract()

def _download(self) -> None:
"""Download the dataset."""
download_url(self.url, self.root, self.base_filename)

def _extract(self) -> None:
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
with open(os.path.join(self.root, self.base_filename), "r") as f:
content = json.load(f)

for item in content["features"]:
download_url(
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
item["properties"]["download"],
self.root,
item["properties"]["tile_id"] + ".tif",
)

def plot( # type: ignore[override]
self,
sample: Dict[str, Tensor],
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`RasterDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle

Returns:
a matplotlib Figure with the rendered sample
"""
mask = sample["mask"].squeeze()
ncols = 1

showing_predictions = "prediction" in sample
if showing_predictions:
pred = sample["prediction"].squeeze()
ncols = 2

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4))

if showing_predictions:
axs[0].imshow(mask)
axs[0].axis("off")
axs[1].imshow(pred)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Mask")
axs[1].set_title("Prediction")
else:
axs.imshow(mask)
axs.axis("off")
if show_titles:
axs.set_title("Mask")

if suptitle is not None:
plt.suptitle(suptitle)

return fig