Allow RasterDataset to accept list of files (#1442)
* Make RasterDataset accept list of files

* Fix check if str

* Use isdir and isfile

* Rename root to paths and update type hint

* Update children of RasterDataset methods using root

* Fix check to cast str to list

* Update conf files for RasterDatasets

* Add initial suggested test

* Add workaround for lists in LandCoverAIBase

* Add method handle_nonlocal_path for users to override

* Raise RuntimeError to support existing tests

* Remove redundant cast to set

* Remove required os.exists for paths

* Revert "Remove required os.exists for paths"

This reverts commit 84bf62b.

* Use arg as positional argument, not kwarg

* Improve comments and logs about arg paths

* Remove misleading comment

* Change type hint of 'paths' to Iterable

* Change type hint of 'paths' to Iterable

* Remove premature handling of non-local paths

* Replace root with paths in docstrings

* Add versionadded to list_files docstring

* Add versionchanged to docstrings

* Update type of paths in children of RasterDataset

* Replace docstring for paths in all raster datasets

* Swap root with paths for conf files for raster

* Add newline before versionchanged

* Revert name to root in conf for ChesapeakeCVPR

* Simplify EUDEM tests

* paths must be a string if you want autodownload support

* Convert list_files to a property

* Fix type hints

* Test with a real empty directory

* More diverse tests

* LandCoverAI: don't yet support list of paths

* Black

* isort

---------

Co-authored-by: Adrian Tofting <adriantofting@mobmob14994.hq.k.grp>
Co-authored-by: Adrian Tofting <adrian@vake.ai>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
4 people authored Sep 29, 2023
1 parent 51ffb69 commit 3cef4fb
Showing 36 changed files with 323 additions and 222 deletions.
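In short, this commit replaces the `root: str` argument of RasterDataset and its subclasses with a first positional argument `paths`, typed `Union[str, Iterable[str]]`. A minimal usage sketch (not part of the commit itself; the NAIP paths are the ones exercised by the new tests below):

from torchgeo.datasets import NAIP

# A single root directory still works, now under the name `paths`
ds = NAIP("tests/data/naip")

# Any iterable mixing directories and individual files is also accepted
ds = NAIP(
    [
        "tests/data/naip",
        "tests/data/naip/m_3807511_ne_18_060_20181104.tif",
    ]
)
print(ds.files)  # the resolved list of matching files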
2 changes: 1 addition & 1 deletion conf/l7irish.yaml
@@ -20,4 +20,4 @@ data:
     patch_size: 224
     num_workers: 16
   dict_kwargs:
-    root: "data/l7irish"
+    paths: "data/l7irish"
2 changes: 1 addition & 1 deletion conf/l8biome.yaml
@@ -20,4 +20,4 @@ data:
     patch_size: 224
     num_workers: 16
   dict_kwargs:
-    root: "data/l8biome"
+    paths: "data/l8biome"
4 changes: 2 additions & 2 deletions conf/naipchesapeake.yaml
@@ -21,5 +21,5 @@ data:
     num_workers: 4
     patch_size: 32
   dict_kwargs:
-    naip_root: "data/naip"
-    chesapeake_root: "data/chesapeake/BAYWIDE"
+    naip_paths: "data/naip"
+    chesapeake_paths: "data/chesapeake/BAYWIDE"
2 changes: 1 addition & 1 deletion tests/conf/l7irish.yaml
@@ -15,5 +15,5 @@ data:
     patch_size: 32
     length: 5
   dict_kwargs:
-    root: "tests/data/l7irish"
+    paths: "tests/data/l7irish"
     download: true
2 changes: 1 addition & 1 deletion tests/conf/l8biome.yaml
@@ -15,5 +15,5 @@ data:
     patch_size: 32
     length: 5
   dict_kwargs:
-    root: "tests/data/l8biome"
+    paths: "tests/data/l8biome"
     download: true
4 changes: 2 additions & 2 deletions tests/conf/naipchesapeake.yaml
@@ -14,6 +14,6 @@ data:
     batch_size: 2
     patch_size: 32
   dict_kwargs:
-    naip_root: "tests/data/naip"
-    chesapeake_root: "tests/data/chesapeake/BAYWIDE"
+    naip_paths: "tests/data/naip"
+    chesapeake_paths: "tests/data/chesapeake/BAYWIDE"
     chesapeake_download: true
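These configuration edits track the constructor change: entries under dict_kwargs are forwarded to the dataset constructor as keyword arguments, so the keys must match the new parameter names. Roughly equivalent Python, as a sketch (the datamodule plumbing is simplified away):

from torchgeo.datasets import L7Irish

dict_kwargs = {"paths": "tests/data/l7irish", "download": True}  # parsed from the YAML above
ds = L7Irish(**dict_kwargs)  # same as L7Irish(paths="tests/data/l7irish", download=True)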
8 changes: 4 additions & 4 deletions tests/datasets/test_agb_live_woody_density.py
@@ -52,14 +52,14 @@ def test_getitem(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
         assert isinstance(x["crs"], CRS)
         assert isinstance(x["mask"], torch.Tensor)

-    def test_no_dataset(self) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in."):
-            AbovegroundLiveWoodyBiomassDensity(root="/test")
+    def test_no_dataset(self, tmp_path: Path) -> None:
+        with pytest.raises(RuntimeError, match="Dataset not found"):
+            AbovegroundLiveWoodyBiomassDensity(str(tmp_path))

     def test_already_downloaded(
         self, dataset: AbovegroundLiveWoodyBiomassDensity
     ) -> None:
-        AbovegroundLiveWoodyBiomassDensity(dataset.root)
+        AbovegroundLiveWoodyBiomassDensity(dataset.paths)

     def test_and(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
         ds = dataset & dataset
2 changes: 1 addition & 1 deletion tests/datasets/test_astergdem.py
@@ -27,7 +27,7 @@ def test_datasetmissing(self, tmp_path: Path) -> None:
         shutil.rmtree(tmp_path)
         os.makedirs(tmp_path)
         with pytest.raises(RuntimeError, match="Dataset not found in"):
-            AsterGDEM(root=str(tmp_path))
+            AsterGDEM(str(tmp_path))

     def test_getitem(self, dataset: AsterGDEM) -> None:
         x = dataset[dataset.bounds]
2 changes: 1 addition & 1 deletion tests/datasets/test_cdl.py
@@ -74,7 +74,7 @@ def test_full_year(self, dataset: CDL) -> None:
             next(dataset.index.intersection(tuple(query)))

     def test_already_extracted(self, dataset: CDL) -> None:
-        CDL(root=dataset.root, years=[2020, 2021])
+        CDL(dataset.paths, years=[2020, 2021])

     def test_already_downloaded(self, tmp_path: Path) -> None:
         pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip")
2 changes: 1 addition & 1 deletion tests/datasets/test_chesapeake.py
@@ -59,7 +59,7 @@ def test_or(self, dataset: Chesapeake13) -> None:
         assert isinstance(ds, UnionDataset)

     def test_already_extracted(self, dataset: Chesapeake13) -> None:
-        Chesapeake13(root=dataset.root, download=True)
+        Chesapeake13(dataset.paths, download=True)

     def test_already_downloaded(self, tmp_path: Path) -> None:
         url = os.path.join(
8 changes: 4 additions & 4 deletions tests/datasets/test_cms_mangrove_canopy.py
@@ -44,9 +44,9 @@ def test_getitem(self, dataset: CMSGlobalMangroveCanopy) -> None:
         assert isinstance(x["crs"], CRS)
         assert isinstance(x["mask"], torch.Tensor)

-    def test_no_dataset(self) -> None:
-        with pytest.raises(RuntimeError, match="Dataset not found in."):
-            CMSGlobalMangroveCanopy(root="/test")
+    def test_no_dataset(self, tmp_path: Path) -> None:
+        with pytest.raises(RuntimeError, match="Dataset not found"):
+            CMSGlobalMangroveCanopy(str(tmp_path))

     def test_already_downloaded(self, tmp_path: Path) -> None:
         pathname = os.path.join(
@@ -65,7 +65,7 @@ def test_corrupted(self, tmp_path: Path) -> None:
         ) as f:
             f.write("bad")
         with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
-            CMSGlobalMangroveCanopy(root=str(tmp_path), country="Angola", checksum=True)
+            CMSGlobalMangroveCanopy(str(tmp_path), country="Angola", checksum=True)

     def test_invalid_country(self) -> None:
         with pytest.raises(AssertionError):
4 changes: 2 additions & 2 deletions tests/datasets/test_esri2020.py
@@ -47,7 +47,7 @@ def test_getitem(self, dataset: Esri2020) -> None:
         assert isinstance(x["mask"], torch.Tensor)

     def test_already_extracted(self, dataset: Esri2020) -> None:
-        Esri2020(root=dataset.root, download=True)
+        Esri2020(dataset.paths, download=True)

     def test_not_extracted(self, tmp_path: Path) -> None:
         url = os.path.join(
@@ -57,7 +57,7 @@ def test_not_extracted(self, tmp_path: Path) -> None:
             "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip",
         )
         shutil.copy(url, tmp_path)
-        Esri2020(root=str(tmp_path))
+        Esri2020(str(tmp_path))

     def test_not_downloaded(self, tmp_path: Path) -> None:
         with pytest.raises(RuntimeError, match="Dataset not found"):
11 changes: 6 additions & 5 deletions tests/datasets/test_eudem.py
@@ -33,21 +33,22 @@ def test_getitem(self, dataset: EUDEM) -> None:
         assert isinstance(x["mask"], torch.Tensor)

     def test_extracted_already(self, dataset: EUDEM) -> None:
-        zipfile = os.path.join(dataset.root, "eu_dem_v11_E30N10.zip")
-        shutil.unpack_archive(zipfile, dataset.root, "zip")
-        EUDEM(dataset.root)
+        assert isinstance(dataset.paths, str)
+        zipfile = os.path.join(dataset.paths, "eu_dem_v11_E30N10.zip")
+        shutil.unpack_archive(zipfile, dataset.paths, "zip")
+        EUDEM(dataset.paths)

     def test_no_dataset(self, tmp_path: Path) -> None:
         shutil.rmtree(tmp_path)
         os.makedirs(tmp_path)
         with pytest.raises(RuntimeError, match="Dataset not found in"):
-            EUDEM(root=str(tmp_path))
+            EUDEM(str(tmp_path))

     def test_corrupted(self, tmp_path: Path) -> None:
         with open(os.path.join(tmp_path, "eu_dem_v11_E30N10.zip"), "w") as f:
             f.write("bad")
         with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
-            EUDEM(root=str(tmp_path), checksum=True)
+            EUDEM(str(tmp_path), checksum=True)

     def test_and(self, dataset: EUDEM) -> None:
         ds = dataset & dataset
36 changes: 35 additions & 1 deletion tests/datasets/test_geo.py
@@ -1,9 +1,10 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.

 import os
 import pickle
+from collections.abc import Iterable
 from pathlib import Path
+from typing import Union

 import pytest
 import torch
@@ -178,6 +179,39 @@ def sentinel(self, request: SubRequest) -> Sentinel2:
         cache = request.param[1]
         return Sentinel2(root, bands=bands, transforms=transforms, cache=cache)

+    @pytest.mark.parametrize(
+        "paths",
+        [
+            # Single directory
+            os.path.join("tests", "data", "naip"),
+            # Multiple directories
+            [
+                os.path.join("tests", "data", "naip"),
+                os.path.join("tests", "data", "naip"),
+            ],
+            # Single file
+            os.path.join("tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"),
+            # Multiple files
+            (
+                os.path.join(
+                    "tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"
+                ),
+                os.path.join(
+                    "tests", "data", "naip", "m_3807511_ne_18_060_20190605.tif"
+                ),
+            ),
+            # Combination
+            {
+                os.path.join("tests", "data", "naip"),
+                os.path.join(
+                    "tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"
+                ),
+            },
+        ],
+    )
+    def test_files(self, paths: Union[str, Iterable[str]]) -> None:
+        assert 1 <= len(NAIP(paths).files) <= 2
+
     def test_getitem_single_file(self, naip: NAIP) -> None:
         x = naip[naip.bounds]
         assert isinstance(x, dict)
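The parametrized test above hands NAIP a bare string, a list, a tuple, and a set, mixing directories with individual files, and asserts that one or two files are resolved. Judging by the commit messages ("Use isdir and isfile", "Convert list_files to a property"), the resolution logic is along these lines; treat this as a sketch under those assumptions, not the verbatim implementation:

import glob
import os
from collections.abc import Iterable
from typing import Union


def resolve_files(paths: Union[str, Iterable[str]], filename_glob: str) -> list[str]:
    """Expand `paths` into a sorted, de-duplicated list of matching files."""
    # Wrap a bare string so both cases iterate the same way
    if isinstance(paths, str):
        paths = [paths]

    files: set[str] = set()  # set semantics drop duplicate inputs
    for path in paths:
        if os.path.isdir(path):
            # Search the directory recursively for files matching the glob
            pathname = os.path.join(path, "**", filename_glob)
            files |= set(glob.iglob(pathname, recursive=True))
        elif os.path.isfile(path):
            # An explicit file is taken as-is
            files.add(path)
    return sorted(files)

Passing the same directory twice, as the "Multiple directories" case does, therefore still yields each file once, which is why the test can assert 1 <= len(files) <= 2.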
4 changes: 2 additions & 2 deletions tests/datasets/test_globbiomass.py
@@ -47,7 +47,7 @@ def test_getitem(self, dataset: GlobBiomass) -> None:
         assert isinstance(x["mask"], torch.Tensor)

     def test_already_extracted(self, dataset: GlobBiomass) -> None:
-        GlobBiomass(root=dataset.root)
+        GlobBiomass(dataset.paths)

     def test_not_downloaded(self, tmp_path: Path) -> None:
         with pytest.raises(RuntimeError, match="Dataset not found"):
@@ -57,7 +57,7 @@ def test_corrupted(self, tmp_path: Path) -> None:
         with open(os.path.join(tmp_path, "N00E020_agb.zip"), "w") as f:
             f.write("bad")
         with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
-            GlobBiomass(root=str(tmp_path), checksum=True)
+            GlobBiomass(str(tmp_path), checksum=True)

     def test_and(self, dataset: GlobBiomass) -> None:
         ds = dataset & dataset
4 changes: 2 additions & 2 deletions tests/datasets/test_l7irish.py
@@ -58,7 +58,7 @@ def test_plot(self, dataset: L7Irish) -> None:
         plt.close()

     def test_already_extracted(self, dataset: L7Irish) -> None:
-        L7Irish(root=dataset.root, download=True)
+        L7Irish(dataset.paths, download=True)

     def test_already_downloaded(self, tmp_path: Path) -> None:
         pathname = os.path.join("tests", "data", "l7irish", "*.tar.gz")
@@ -88,7 +88,7 @@ def test_rgb_bands_absent_plot(self, dataset: L7Irish) -> None:
         with pytest.raises(
             ValueError, match="Dataset doesn't contain some of the RGB bands"
         ):
-            ds = L7Irish(root=dataset.root, bands=["B10", "B20", "B50"])
+            ds = L7Irish(dataset.paths, bands=["B10", "B20", "B50"])
             x = ds[ds.bounds]
             ds.plot(x, suptitle="Test")
             plt.close()
4 changes: 2 additions & 2 deletions tests/datasets/test_l8biome.py
@@ -58,7 +58,7 @@ def test_plot(self, dataset: L8Biome) -> None:
         plt.close()

     def test_already_extracted(self, dataset: L8Biome) -> None:
-        L8Biome(root=dataset.root, download=True)
+        L8Biome(dataset.paths, download=True)

     def test_already_downloaded(self, tmp_path: Path) -> None:
         pathname = os.path.join("tests", "data", "l8biome", "*.tar.gz")
@@ -88,7 +88,7 @@ def test_rgb_bands_absent_plot(self, dataset: L8Biome) -> None:
         with pytest.raises(
             ValueError, match="Dataset doesn't contain some of the RGB bands"
         ):
-            ds = L8Biome(root=dataset.root, bands=["B1", "B2", "B5"])
+            ds = L8Biome(dataset.paths, bands=["B1", "B2", "B5"])
             x = ds[ds.bounds]
             ds.plot(x, suptitle="Test")
             plt.close()
2 changes: 1 addition & 1 deletion tests/datasets/test_landcoverai.py
@@ -40,7 +40,7 @@ def test_getitem(self, dataset: LandCoverAIGeo) -> None:
         assert isinstance(x["mask"], torch.Tensor)

     def test_already_extracted(self, dataset: LandCoverAIGeo) -> None:
-        LandCoverAIGeo(root=dataset.root, download=True)
+        LandCoverAIGeo(dataset.root, download=True)

     def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
         url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip")
2 changes: 1 addition & 1 deletion tests/datasets/test_landsat.py
@@ -52,7 +52,7 @@ def test_plot(self, dataset: Landsat8) -> None:

     def test_plot_wrong_bands(self, dataset: Landsat8) -> None:
         bands = ("SR_B1",)
-        ds = Landsat8(root=dataset.root, bands=bands)
+        ds = Landsat8(dataset.paths, bands=bands)
         x = dataset[dataset.bounds]
         with pytest.raises(
             ValueError, match="Dataset doesn't contain some of the RGB bands"
2 changes: 1 addition & 1 deletion tests/datasets/test_nlcd.py
@@ -69,7 +69,7 @@ def test_or(self, dataset: NLCD) -> None:
         assert isinstance(ds, UnionDataset)

     def test_already_extracted(self, dataset: NLCD) -> None:
-        NLCD(root=dataset.root, download=True, years=[2019])
+        NLCD(dataset.paths, download=True, years=[2019])

     def test_already_downloaded(self, tmp_path: Path) -> None:
         pathname = os.path.join(
2 changes: 1 addition & 1 deletion tests/datasets/test_sentinel.py
@@ -133,7 +133,7 @@ def test_plot(self, dataset: Sentinel2) -> None:

     def test_plot_wrong_bands(self, dataset: Sentinel2) -> None:
         bands = ["B02"]
-        ds = Sentinel2(root=dataset.root, res=dataset.res, bands=bands)
+        ds = Sentinel2(dataset.paths, res=dataset.res, bands=bands)
         x = dataset[dataset.bounds]
         with pytest.raises(
             ValueError, match="Dataset doesn't contain some of the RGB bands"
29 changes: 16 additions & 13 deletions torchgeo/datasets/agb_live_woody_density.py
@@ -3,10 +3,10 @@

 """Aboveground Live Woody Biomass Density dataset."""

-import glob
 import json
 import os
-from typing import Any, Callable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union

 import matplotlib.pyplot as plt
 from matplotlib.figure import Figure
@@ -59,7 +59,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):

     def __init__(
         self,
-        root: str = "data",
+        paths: Union[str, Iterable[str]] = "data",
         crs: Optional[CRS] = None,
         res: Optional[float] = None,
         transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -69,7 +69,7 @@ def __init__(
         """Initialize a new Dataset instance.

         Args:
-            root: root directory where dataset can be found
+            paths: one or more root directories to search or files to load
             crs: :term:`coordinate reference system (CRS)` to warp to
                 (defaults to the CRS of the first file found)
             res: resolution of the dataset in units of CRS
@@ -80,14 +80,17 @@ def __init__(
             cache: if True, cache file handle to speed up repeated sampling

         Raises:
-            FileNotFoundError: if no files are found in ``root``
+            FileNotFoundError: if no files are found in ``paths``
+
+        .. versionchanged:: 0.5
+           *root* was renamed to *paths*.
         """
-        self.root = root
+        self.paths = paths
         self.download = download

         self._verify()

-        super().__init__(root, crs, res, transforms=transforms, cache=cache)
+        super().__init__(paths, crs, res, transforms=transforms, cache=cache)

     def _verify(self) -> None:
         """Verify the integrity of the dataset.
@@ -96,14 +99,13 @@ def _verify(self) -> None:
             RuntimeError: if dataset is missing
         """
         # Check if the extracted files already exist
-        pathname = os.path.join(self.root, self.filename_glob)
-        if glob.glob(pathname):
+        if self.files:
             return

         # Check if the user requested to download the dataset
         if not self.download:
             raise RuntimeError(
-                f"Dataset not found in `root={self.root}` and `download=False`, "
+                f"Dataset not found in `root={self.paths}` and `download=False`, "
                 "either specify a different `root` directory or use `download=True` "
                 "to automatically download the dataset."
             )
@@ -113,15 +115,16 @@ def _verify(self) -> None:

     def _download(self) -> None:
         """Download the dataset."""
-        download_url(self.url, self.root, self.base_filename)
+        assert isinstance(self.paths, str)
+        download_url(self.url, self.paths, self.base_filename)

-        with open(os.path.join(self.root, self.base_filename)) as f:
+        with open(os.path.join(self.paths, self.base_filename)) as f:
             content = json.load(f)

         for item in content["features"]:
             download_url(
                 item["properties"]["download"],
-                self.root,
+                self.paths,
                 item["properties"]["tile_id"] + ".tif",
             )
Expand Down
