From 3cef4fb21dd1ac55fdfa531b4459c8249c099782 Mon Sep 17 00:00:00 2001
From: Adrian Tofting
Date: Fri, 29 Sep 2023 16:28:07 +0200
Subject: [PATCH] Allow RasterDataset to accept list of files (#1442)

* Make RasterDataset accept list of files

* Fix check if str

* Use isdir and isfile

* Rename root to paths and update type hint

* Update children of RasterDataset methods using root

* Fix check to cast str to list

* Update conf files for RasterDatasets

* Add initial suggested test

* Add workaround for lists LandCoverAIBase

* Add method handle_nonlocal_path for users to override

* Raise RuntimeError to support existing tests

* Remove redundant cast to set

* Remove required os.exists for paths

* Revert "Remove required os.exists for paths"

This reverts commit 84bf62b944326c33d5ba8efdcab615c65b124792.

* Use arg as positional argument not kwarg

* Improve comments and logs about arg paths

* Remove misleading comment

* Change type hint of 'paths' to Iterable

* Change type hint of 'paths' to Iterable

* Remove premature handling of non-local paths

* Replace root with paths in docstrings

* Add versionadded to list_files docstring

* Add versionchanged to docstrings

* Update type of paths in children of Raster

* Replace docstring for paths in all raster

* Swap root with paths for conf files for raster

* Add newline before versionchanged

* Revert name to root in conf for ChesapeakeCVPR

* Simplify EUDEM tests

* paths must be a string if you want autodownload support

* Convert list_files to a property

* Fix type hints

* Test with a real empty directory

* More diverse tests

* LandCoverAI: don't yet support list of paths

* Black

* isort

---------

Co-authored-by: Adrian Tofting
Co-authored-by: Adrian Tofting
Co-authored-by: Adam J. Stewart
---
 conf/l7irish.yaml                             |  2 +-
 conf/l8biome.yaml                             |  2 +-
 conf/naipchesapeake.yaml                      |  4 +-
 tests/conf/l7irish.yaml                       |  2 +-
 tests/conf/l8biome.yaml                       |  2 +-
 tests/conf/naipchesapeake.yaml                |  4 +-
 tests/datasets/test_agb_live_woody_density.py |  8 +--
 tests/datasets/test_astergdem.py              |  2 +-
 tests/datasets/test_cdl.py                    |  2 +-
 tests/datasets/test_chesapeake.py             |  2 +-
 tests/datasets/test_cms_mangrove_canopy.py    |  8 +--
 tests/datasets/test_esri2020.py               |  4 +-
 tests/datasets/test_eudem.py                  | 11 +++--
 tests/datasets/test_geo.py                    | 36 +++++++++++++-
 tests/datasets/test_globbiomass.py            |  4 +-
 tests/datasets/test_l7irish.py                |  4 +-
 tests/datasets/test_l8biome.py                |  4 +-
 tests/datasets/test_landcoverai.py            |  2 +-
 tests/datasets/test_landsat.py                |  2 +-
 tests/datasets/test_nlcd.py                   |  2 +-
 tests/datasets/test_sentinel.py               |  2 +-
 torchgeo/datasets/agb_live_woody_density.py   | 29 ++++++-----
 torchgeo/datasets/astergdem.py                | 24 ++++-----
 torchgeo/datasets/cdl.py                      | 39 +++++++--------
 torchgeo/datasets/chesapeake.py               | 29 ++++++-----
 torchgeo/datasets/cms_mangrove_canopy.py      | 27 +++++-----
 torchgeo/datasets/esri2020.py                 | 29 ++++++-----
 torchgeo/datasets/eudem.py                    | 26 +++++-----
 torchgeo/datasets/geo.py                      | 49 +++++++++++++++----
 torchgeo/datasets/globbiomass.py              | 24 +++++----
 torchgeo/datasets/l7irish.py                  | 28 ++++++-----
 torchgeo/datasets/l8biome.py                  | 28 ++++++-----
 torchgeo/datasets/landcoverai.py              | 24 ++++-----
 torchgeo/datasets/landsat.py                  | 15 +++---
 torchgeo/datasets/nlcd.py                     | 39 +++++++--------
 torchgeo/datasets/sentinel.py                 | 26 ++++++----
 36 files changed, 323 insertions(+), 222 deletions(-)

diff --git a/conf/l7irish.yaml b/conf/l7irish.yaml
index 6681379cf11..91b1cbea15d 100644
--- a/conf/l7irish.yaml
+++ b/conf/l7irish.yaml
@@ -20,4 +20,4 @@ data:
     patch_size: 224
     num_workers: 16
   dict_kwargs:
-    root: 
"data/l7irish" + paths: "data/l7irish" diff --git a/conf/l8biome.yaml b/conf/l8biome.yaml index 122a3c46073..728073a56fa 100644 --- a/conf/l8biome.yaml +++ b/conf/l8biome.yaml @@ -20,4 +20,4 @@ data: patch_size: 224 num_workers: 16 dict_kwargs: - root: "data/l8biome" + paths: "data/l8biome" diff --git a/conf/naipchesapeake.yaml b/conf/naipchesapeake.yaml index a767e6348e6..f03c759b226 100644 --- a/conf/naipchesapeake.yaml +++ b/conf/naipchesapeake.yaml @@ -21,5 +21,5 @@ data: num_workers: 4 patch_size: 32 dict_kwargs: - naip_root: "data/naip" - chesapeake_root: "data/chesapeake/BAYWIDE" + naip_paths: "data/naip" + chesapeake_paths: "data/chesapeake/BAYWIDE" diff --git a/tests/conf/l7irish.yaml b/tests/conf/l7irish.yaml index 26a47518b4b..fc67fb8e1cc 100644 --- a/tests/conf/l7irish.yaml +++ b/tests/conf/l7irish.yaml @@ -15,5 +15,5 @@ data: patch_size: 32 length: 5 dict_kwargs: - root: "tests/data/l7irish" + paths: "tests/data/l7irish" download: true diff --git a/tests/conf/l8biome.yaml b/tests/conf/l8biome.yaml index e52eabbc1e3..f33b4b36464 100644 --- a/tests/conf/l8biome.yaml +++ b/tests/conf/l8biome.yaml @@ -15,5 +15,5 @@ data: patch_size: 32 length: 5 dict_kwargs: - root: "tests/data/l8biome" + paths: "tests/data/l8biome" download: true diff --git a/tests/conf/naipchesapeake.yaml b/tests/conf/naipchesapeake.yaml index 99e59f30e5d..4b13865f1bd 100644 --- a/tests/conf/naipchesapeake.yaml +++ b/tests/conf/naipchesapeake.yaml @@ -14,6 +14,6 @@ data: batch_size: 2 patch_size: 32 dict_kwargs: - naip_root: "tests/data/naip" - chesapeake_root: "tests/data/chesapeake/BAYWIDE" + naip_paths: "tests/data/naip" + chesapeake_paths: "tests/data/chesapeake/BAYWIDE" chesapeake_download: true diff --git a/tests/datasets/test_agb_live_woody_density.py b/tests/datasets/test_agb_live_woody_density.py index 7800aecc1b1..ae775a1cabf 100644 --- a/tests/datasets/test_agb_live_woody_density.py +++ b/tests/datasets/test_agb_live_woody_density.py @@ -52,14 +52,14 @@ def test_getitem(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: assert isinstance(x["crs"], CRS) assert isinstance(x["mask"], torch.Tensor) - def test_no_dataset(self) -> None: - with pytest.raises(RuntimeError, match="Dataset not found in."): - AbovegroundLiveWoodyBiomassDensity(root="/test") + def test_no_dataset(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + AbovegroundLiveWoodyBiomassDensity(str(tmp_path)) def test_already_downloaded( self, dataset: AbovegroundLiveWoodyBiomassDensity ) -> None: - AbovegroundLiveWoodyBiomassDensity(dataset.root) + AbovegroundLiveWoodyBiomassDensity(dataset.paths) def test_and(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_astergdem.py b/tests/datasets/test_astergdem.py index 25d0940b30d..0a1d8fc263a 100644 --- a/tests/datasets/test_astergdem.py +++ b/tests/datasets/test_astergdem.py @@ -27,7 +27,7 @@ def test_datasetmissing(self, tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) with pytest.raises(RuntimeError, match="Dataset not found in"): - AsterGDEM(root=str(tmp_path)) + AsterGDEM(str(tmp_path)) def test_getitem(self, dataset: AsterGDEM) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py index e5badeeb0fc..50babc4c175 100644 --- a/tests/datasets/test_cdl.py +++ b/tests/datasets/test_cdl.py @@ -74,7 +74,7 @@ def test_full_year(self, dataset: CDL) -> None: next(dataset.index.intersection(tuple(query))) def 
test_already_extracted(self, dataset: CDL) -> None: - CDL(root=dataset.root, years=[2020, 2021]) + CDL(dataset.paths, years=[2020, 2021]) def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip") diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 3348c1f6c95..ba33de7c465 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -59,7 +59,7 @@ def test_or(self, dataset: Chesapeake13) -> None: assert isinstance(ds, UnionDataset) def test_already_extracted(self, dataset: Chesapeake13) -> None: - Chesapeake13(root=dataset.root, download=True) + Chesapeake13(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: url = os.path.join( diff --git a/tests/datasets/test_cms_mangrove_canopy.py b/tests/datasets/test_cms_mangrove_canopy.py index 9a30a2a0283..3c9ea05e65a 100644 --- a/tests/datasets/test_cms_mangrove_canopy.py +++ b/tests/datasets/test_cms_mangrove_canopy.py @@ -44,9 +44,9 @@ def test_getitem(self, dataset: CMSGlobalMangroveCanopy) -> None: assert isinstance(x["crs"], CRS) assert isinstance(x["mask"], torch.Tensor) - def test_no_dataset(self) -> None: - with pytest.raises(RuntimeError, match="Dataset not found in."): - CMSGlobalMangroveCanopy(root="/test") + def test_no_dataset(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + CMSGlobalMangroveCanopy(str(tmp_path)) def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( @@ -65,7 +65,7 @@ def test_corrupted(self, tmp_path: Path) -> None: ) as f: f.write("bad") with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): - CMSGlobalMangroveCanopy(root=str(tmp_path), country="Angola", checksum=True) + CMSGlobalMangroveCanopy(str(tmp_path), country="Angola", checksum=True) def test_invalid_country(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_esri2020.py b/tests/datasets/test_esri2020.py index 74f9200cb5a..60c963139a1 100644 --- a/tests/datasets/test_esri2020.py +++ b/tests/datasets/test_esri2020.py @@ -47,7 +47,7 @@ def test_getitem(self, dataset: Esri2020) -> None: assert isinstance(x["mask"], torch.Tensor) def test_already_extracted(self, dataset: Esri2020) -> None: - Esri2020(root=dataset.root, download=True) + Esri2020(dataset.paths, download=True) def test_not_extracted(self, tmp_path: Path) -> None: url = os.path.join( @@ -57,7 +57,7 @@ def test_not_extracted(self, tmp_path: Path) -> None: "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip", ) shutil.copy(url, tmp_path) - Esri2020(root=str(tmp_path)) + Esri2020(str(tmp_path)) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found"): diff --git a/tests/datasets/test_eudem.py b/tests/datasets/test_eudem.py index fd04f541e9a..e3a5efdbe25 100644 --- a/tests/datasets/test_eudem.py +++ b/tests/datasets/test_eudem.py @@ -33,21 +33,22 @@ def test_getitem(self, dataset: EUDEM) -> None: assert isinstance(x["mask"], torch.Tensor) def test_extracted_already(self, dataset: EUDEM) -> None: - zipfile = os.path.join(dataset.root, "eu_dem_v11_E30N10.zip") - shutil.unpack_archive(zipfile, dataset.root, "zip") - EUDEM(dataset.root) + assert isinstance(dataset.paths, str) + zipfile = os.path.join(dataset.paths, "eu_dem_v11_E30N10.zip") + shutil.unpack_archive(zipfile, dataset.paths, "zip") + EUDEM(dataset.paths) def test_no_dataset(self, 
tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) with pytest.raises(RuntimeError, match="Dataset not found in"): - EUDEM(root=str(tmp_path)) + EUDEM(str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, "eu_dem_v11_E30N10.zip"), "w") as f: f.write("bad") with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): - EUDEM(root=str(tmp_path), checksum=True) + EUDEM(str(tmp_path), checksum=True) def test_and(self, dataset: EUDEM) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index cf4ef25d880..6b73bb0e8a4 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - import os import pickle +from collections.abc import Iterable from pathlib import Path +from typing import Union import pytest import torch @@ -178,6 +179,39 @@ def sentinel(self, request: SubRequest) -> Sentinel2: cache = request.param[1] return Sentinel2(root, bands=bands, transforms=transforms, cache=cache) + @pytest.mark.parametrize( + "paths", + [ + # Single directory + os.path.join("tests", "data", "naip"), + # Multiple directories + [ + os.path.join("tests", "data", "naip"), + os.path.join("tests", "data", "naip"), + ], + # Single file + os.path.join("tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"), + # Multiple files + ( + os.path.join( + "tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif" + ), + os.path.join( + "tests", "data", "naip", "m_3807511_ne_18_060_20190605.tif" + ), + ), + # Combination + { + os.path.join("tests", "data", "naip"), + os.path.join( + "tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif" + ), + }, + ], + ) + def test_files(self, paths: Union[str, Iterable[str]]) -> None: + assert 1 <= len(NAIP(paths).files) <= 2 + def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] assert isinstance(x, dict) diff --git a/tests/datasets/test_globbiomass.py b/tests/datasets/test_globbiomass.py index 2bc94bdc00c..c73675def91 100644 --- a/tests/datasets/test_globbiomass.py +++ b/tests/datasets/test_globbiomass.py @@ -47,7 +47,7 @@ def test_getitem(self, dataset: GlobBiomass) -> None: assert isinstance(x["mask"], torch.Tensor) def test_already_extracted(self, dataset: GlobBiomass) -> None: - GlobBiomass(root=dataset.root) + GlobBiomass(dataset.paths) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found"): @@ -57,7 +57,7 @@ def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, "N00E020_agb.zip"), "w") as f: f.write("bad") with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): - GlobBiomass(root=str(tmp_path), checksum=True) + GlobBiomass(str(tmp_path), checksum=True) def test_and(self, dataset: GlobBiomass) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_l7irish.py b/tests/datasets/test_l7irish.py index 59610d78f62..43795d57a1e 100644 --- a/tests/datasets/test_l7irish.py +++ b/tests/datasets/test_l7irish.py @@ -58,7 +58,7 @@ def test_plot(self, dataset: L7Irish) -> None: plt.close() def test_already_extracted(self, dataset: L7Irish) -> None: - L7Irish(root=dataset.root, download=True) + L7Irish(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join("tests", "data", "l7irish", "*.tar.gz") @@ -88,7 +88,7 @@ def 
test_rgb_bands_absent_plot(self, dataset: L7Irish) -> None: with pytest.raises( ValueError, match="Dataset doesn't contain some of the RGB bands" ): - ds = L7Irish(root=dataset.root, bands=["B10", "B20", "B50"]) + ds = L7Irish(dataset.paths, bands=["B10", "B20", "B50"]) x = ds[ds.bounds] ds.plot(x, suptitle="Test") plt.close() diff --git a/tests/datasets/test_l8biome.py b/tests/datasets/test_l8biome.py index 2b653c87061..337e17c1b64 100644 --- a/tests/datasets/test_l8biome.py +++ b/tests/datasets/test_l8biome.py @@ -58,7 +58,7 @@ def test_plot(self, dataset: L8Biome) -> None: plt.close() def test_already_extracted(self, dataset: L8Biome) -> None: - L8Biome(root=dataset.root, download=True) + L8Biome(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join("tests", "data", "l8biome", "*.tar.gz") @@ -88,7 +88,7 @@ def test_rgb_bands_absent_plot(self, dataset: L8Biome) -> None: with pytest.raises( ValueError, match="Dataset doesn't contain some of the RGB bands" ): - ds = L8Biome(root=dataset.root, bands=["B1", "B2", "B5"]) + ds = L8Biome(dataset.paths, bands=["B1", "B2", "B5"]) x = ds[ds.bounds] ds.plot(x, suptitle="Test") plt.close() diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 4b371b70049..3e0b6d2434a 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -40,7 +40,7 @@ def test_getitem(self, dataset: LandCoverAIGeo) -> None: assert isinstance(x["mask"], torch.Tensor) def test_already_extracted(self, dataset: LandCoverAIGeo) -> None: - LandCoverAIGeo(root=dataset.root, download=True) + LandCoverAIGeo(dataset.root, download=True) def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip") diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index f4ce259f6ee..950d33fcb00 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -52,7 +52,7 @@ def test_plot(self, dataset: Landsat8) -> None: def test_plot_wrong_bands(self, dataset: Landsat8) -> None: bands = ("SR_B1",) - ds = Landsat8(root=dataset.root, bands=bands) + ds = Landsat8(dataset.paths, bands=bands) x = dataset[dataset.bounds] with pytest.raises( ValueError, match="Dataset doesn't contain some of the RGB bands" diff --git a/tests/datasets/test_nlcd.py b/tests/datasets/test_nlcd.py index 0f0f134e384..ceee8097634 100644 --- a/tests/datasets/test_nlcd.py +++ b/tests/datasets/test_nlcd.py @@ -69,7 +69,7 @@ def test_or(self, dataset: NLCD) -> None: assert isinstance(ds, UnionDataset) def test_already_extracted(self, dataset: NLCD) -> None: - NLCD(root=dataset.root, download=True, years=[2019]) + NLCD(dataset.paths, download=True, years=[2019]) def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( diff --git a/tests/datasets/test_sentinel.py b/tests/datasets/test_sentinel.py index 2d6c42aa89e..fccb4e32032 100644 --- a/tests/datasets/test_sentinel.py +++ b/tests/datasets/test_sentinel.py @@ -133,7 +133,7 @@ def test_plot(self, dataset: Sentinel2) -> None: def test_plot_wrong_bands(self, dataset: Sentinel2) -> None: bands = ["B02"] - ds = Sentinel2(root=dataset.root, res=dataset.res, bands=bands) + ds = Sentinel2(dataset.paths, res=dataset.res, bands=bands) x = dataset[dataset.bounds] with pytest.raises( ValueError, match="Dataset doesn't contain some of the RGB bands" diff --git a/torchgeo/datasets/agb_live_woody_density.py 
b/torchgeo/datasets/agb_live_woody_density.py index e8804b2e384..ec9ae90bd67 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -3,10 +3,10 @@ """Aboveground Live Woody Biomass Density dataset.""" -import glob import json import os -from typing import Any, Callable, Optional +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -59,7 +59,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -69,7 +69,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -80,14 +80,17 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self.download = download self._verify() - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: """Verify the integrity of the dataset. @@ -96,14 +99,13 @@ def _verify(self) -> None: RuntimeError: if dataset is missing """ # Check if the extracted files already exist - pathname = os.path.join(self.root, self.filename_glob) - if glob.glob(pathname): + if self.files: return # Check if the user requested to download the dataset if not self.download: raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " + f"Dataset not found in `root={self.paths}` and `download=False`, " "either specify a different `root` directory or use `download=True` " "to automatically download the dataset." 
) @@ -113,15 +115,16 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - download_url(self.url, self.root, self.base_filename) + assert isinstance(self.paths, str) + download_url(self.url, self.paths, self.base_filename) - with open(os.path.join(self.root, self.base_filename)) as f: + with open(os.path.join(self.paths, self.base_filename)) as f: content = json.load(f) for item in content["features"]: download_url( item["properties"]["download"], - self.root, + self.paths, item["properties"]["tile_id"] + ".tif", ) diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py index 7d2d2442688..305c6bf873c 100644 --- a/torchgeo/datasets/astergdem.py +++ b/torchgeo/datasets/astergdem.py @@ -3,9 +3,7 @@ """Aster Global Digital Elevation Model dataset.""" -import glob -import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -47,7 +45,7 @@ class AsterGDEM(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, list[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -56,8 +54,8 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found, here the collection of - individual zip files for each tile should be found + paths: one or more root directories to search or files to load, here + the collection of individual zip files for each tile should be found crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -67,14 +65,17 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` RuntimeError: if dataset is missing + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self._verify() - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: """Verify the integrity of the dataset. @@ -83,12 +84,11 @@ def _verify(self) -> None: RuntimeError: if dataset is missing """ # Check if the extracted files already exists - pathname = os.path.join(self.root, self.filename_glob) - if glob.glob(pathname): + if self.files: return raise RuntimeError( - f"Dataset not found in `root={self.root}` " + f"Dataset not found in `root={self.paths}` " "either specify a different `root` directory or make sure you " "have manually downloaded dataset tiles as suggested in the documentation." ) diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 590a7c515db..39a7a105325 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -3,9 +3,9 @@ """CDL dataset.""" -import glob import os -from typing import Any, Callable, Optional +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt import torch @@ -205,7 +205,7 @@ class CDL(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, years: list[int] = [2022], @@ -218,7 +218,7 @@ def __init__( """Initialize a new Dataset instance. 
Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -234,11 +234,14 @@ def __init__( Raises: AssertionError: if ``years`` or ``classes`` are invalid - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` RuntimeError: if ``download=False`` but dataset is missing or checksum fails .. versionadded:: 0.5 The *years* and *classes* parameters. + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ assert set(years) <= self.md5s.keys(), ( "CDL data product only exists for the following years: " @@ -249,7 +252,7 @@ def __init__( ), f"Only the following classes are valid: {list(self.cmap.keys())}." assert 0 in classes, "Classes must include the background class: 0" - self.root = root + self.paths = paths self.years = years self.classes = classes self.download = download @@ -259,7 +262,7 @@ def __init__( self._verify() - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) # Map chosen classes to ordinal numbers, all others mapped to background class for v, k in enumerate(self.classes): @@ -289,22 +292,15 @@ def _verify(self) -> None: RuntimeError: if ``download=False`` but dataset is missing or checksum fails """ # Check if the extracted files already exist - exists = [] - for year in self.years: - filename_year = self.filename_glob.replace("*", str(year)) - pathname = os.path.join(self.root, "**", filename_year) - for fname in glob.iglob(pathname, recursive=True): - if not fname.endswith(".zip"): - exists.append(True) - - if len(exists) == len(self.years): + if self.files: return # Check if the zip files have already been downloaded exists = [] + assert isinstance(self.paths, str) for year in self.years: pathname = os.path.join( - self.root, self.zipfile_glob.replace("*", str(year)) + self.paths, self.zipfile_glob.replace("*", str(year)) ) if os.path.exists(pathname): exists.append(True) @@ -318,7 +314,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " + f"Dataset not found in `root={self.paths}` and `download=False`, " "either specify a different `root` directory or use `download=True` " "to automatically download the dataset." 
) @@ -332,16 +328,17 @@ def _download(self) -> None: for year in self.years: download_url( self.url.format(year), - self.root, + self.paths, md5=self.md5s[year] if self.checksum else None, ) def _extract(self) -> None: """Extract the dataset.""" + assert isinstance(self.paths, str) for year in self.years: zipfile_name = self.zipfile_glob.replace("*", str(year)) - pathname = os.path.join(self.root, zipfile_name) - extract_archive(pathname, self.root) + pathname = os.path.join(self.paths, zipfile_name) + extract_archive(pathname, self.paths) def plot( self, diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 7ca6125a246..1c17fe84bdd 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -6,8 +6,8 @@ import abc import os import sys -from collections.abc import Sequence -from typing import Any, Callable, Optional, cast +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, Union, cast import fiona import matplotlib.pyplot as plt @@ -89,7 +89,7 @@ def url(self) -> str: def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -100,7 +100,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -112,10 +112,13 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` RuntimeError: if ``download=False`` but dataset is missing or checksum fails + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self.download = download self.checksum = checksum @@ -132,7 +135,7 @@ def __init__( ) self._cmap = ListedColormap(colors) - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: """Verify the integrity of the dataset. @@ -141,18 +144,19 @@ def _verify(self) -> None: RuntimeError: if ``download=False`` but dataset is missing or checksum fails """ # Check if the extracted file already exists - if os.path.exists(os.path.join(self.root, self.filename)): + if self.files: return # Check if the zip file has already been downloaded - if os.path.exists(os.path.join(self.root, self.zipfile)): + assert isinstance(self.paths, str) + if os.path.exists(os.path.join(self.paths, self.zipfile)): self._extract() return # Check if the user requested to download the dataset if not self.download: raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " + f"Dataset not found in `root={self.paths}` and `download=False`, " "either specify a different `root` directory or use `download=True` " "to automatically download the dataset." 
) @@ -163,11 +167,12 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - download_url(self.url, self.root, filename=self.zipfile, md5=self.md5) + download_url(self.url, self.paths, filename=self.zipfile, md5=self.md5) def _extract(self) -> None: """Extract the dataset.""" - extract_archive(os.path.join(self.root, self.zipfile)) + assert isinstance(self.paths, str) + extract_archive(os.path.join(self.paths, self.zipfile)) def plot( self, diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index cc27b6496d9..df02764705f 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -3,9 +3,8 @@ """CMS Global Mangrove Canopy dataset.""" -import glob import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -168,7 +167,7 @@ class CMSGlobalMangroveCanopy(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, list[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, measurement: str = "agb", @@ -180,7 +179,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -193,11 +192,14 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` RuntimeError: if dataset is missing or checksum fails AssertionError: if country or measurement arg are not str or invalid + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self.checksum = checksum assert isinstance(country, str), "Country argument must be a str." @@ -220,7 +222,7 @@ def __init__( self._verify() - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: """Verify the integrity of the dataset. @@ -229,12 +231,12 @@ def _verify(self) -> None: RuntimeError: if dataset is missing or checksum fails """ # Check if the extracted files already exist - pathname = os.path.join(self.root, "**", self.filename_glob) - if glob.glob(pathname): + if self.files: return # Check if the zip file has already been downloaded - pathname = os.path.join(self.root, self.zipfile) + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, self.zipfile) if os.path.exists(pathname): if self.checksum and not check_integrity(pathname, self.md5): raise RuntimeError("Dataset found, but corrupted.") @@ -242,14 +244,15 @@ def _verify(self) -> None: return raise RuntimeError( - f"Dataset not found in `root={self.root}` " + f"Dataset not found in `root={self.paths}` " "either specify a different `root` directory or make sure you " "have manually downloaded the dataset as instructed in the documentation." 
) def _extract(self) -> None: """Extract the dataset.""" - pathname = os.path.join(self.root, self.zipfile) + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, self.zipfile) extract_archive(pathname) def plot( diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 5d26afa8ee0..6b875b8e040 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -5,7 +5,8 @@ import glob import os -from typing import Any, Callable, Optional +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -67,7 +68,7 @@ class Esri2020(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -78,7 +79,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -90,16 +91,19 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` RuntimeError: if ``download=False`` but dataset is missing or checksum fails + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self.download = download self.checksum = checksum self._verify() - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: """Verify the integrity of the dataset. @@ -108,12 +112,12 @@ def _verify(self) -> None: RuntimeError: if ``download=False`` but dataset is missing or checksum fails """ # Check if the extracted file already exists - pathname = os.path.join(self.root, "**", self.filename_glob) - if glob.glob(pathname): + if self.files: return # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.zipfile) + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, self.zipfile) if glob.glob(pathname): self._extract() return @@ -121,7 +125,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " + f"Dataset not found in `root={self.paths}` and `download=False`, " "either specify a different `root` directory or use `download=True` " "to automatically download the dataset." 
) @@ -132,11 +136,12 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - download_url(self.url, self.root, filename=self.zipfile, md5=self.md5) + download_url(self.url, self.paths, filename=self.zipfile, md5=self.md5) def _extract(self) -> None: """Extract the dataset.""" - extract_archive(os.path.join(self.root, self.zipfile)) + assert isinstance(self.paths, str) + extract_archive(os.path.join(self.paths, self.zipfile)) def plot( self, diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index dacca0d6d63..35313dae075 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -5,7 +5,8 @@ import glob import os -from typing import Any, Callable, Optional +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -82,7 +83,7 @@ class EUDEM(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -92,8 +93,8 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found, here the collection of - individual zip files for each tile should be found + paths: one or more root directories to search or files to load, here + the collection of individual zip files for each tile should be found crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -104,14 +105,17 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self.checksum = checksum self._verify() - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: """Verify the integrity of the dataset. @@ -120,12 +124,12 @@ def _verify(self) -> None: RuntimeError: if dataset is missing or checksum fails """ # Check if the extracted file already exists - pathname = os.path.join(self.root, self.filename_glob) - if glob.glob(pathname): + if self.files: return # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.zipfile_glob) + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, self.zipfile_glob) if glob.glob(pathname): for zipfile in glob.iglob(pathname): filename = os.path.basename(zipfile) @@ -135,7 +139,7 @@ def _verify(self) -> None: return raise RuntimeError( - f"Dataset not found in `root={self.root}` " + f"Dataset not found in `root={self.paths}` " "either specify a different `root` directory or make sure you " "have manually downloaded the dataset as suggested in the documentation." 
            )

diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py
index dc2b5fa1388..09564dbaba1 100644
--- a/torchgeo/datasets/geo.py
+++ b/torchgeo/datasets/geo.py
@@ -9,8 +9,8 @@
 import os
 import re
 import sys
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, cast
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, Union, cast
 
 import fiona
 import fiona.transform
@@ -329,7 +329,7 @@ def dtype(self) -> torch.dtype:
 
     def __init__(
         self,
-        root: str = "data",
+        paths: Union[str, Iterable[str]] = "data",
         crs: Optional[CRS] = None,
         res: Optional[float] = None,
         bands: Optional[Sequence[str]] = None,
@@ -339,7 +339,7 @@ def __init__(
         """Initialize a new Dataset instance.
 
         Args:
-            root: root directory where dataset can be found
+            paths: one or more root directories to search or files to load
             crs: :term:`coordinate reference system (CRS)` to warp to
                 (defaults to the CRS of the first file found)
             res: resolution of the dataset in units of CRS
@@ -350,19 +350,21 @@ def __init__(
             cache: if True, cache file handle to speed up repeated sampling
 
         Raises:
-            FileNotFoundError: if no files are found in ``root``
+            FileNotFoundError: if no files are found in ``paths``
+
+        .. versionchanged:: 0.5
+           *root* was renamed to *paths*.
         """
         super().__init__(transforms)
 
-        self.root = root
+        self.paths = paths
         self.bands = bands or self.all_bands
         self.cache = cache
 
         # Populate the dataset index
         i = 0
-        pathname = os.path.join(root, "**", self.filename_glob)
         filename_regex = re.compile(self.filename_regex, re.VERBOSE)
-        for filepath in glob.iglob(pathname, recursive=True):
+        for filepath in self.files:
             match = re.match(filename_regex, os.path.basename(filepath))
             if match is not None:
                 try:
@@ -396,7 +398,10 @@ def __init__(
             i += 1
 
         if i == 0:
-            msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"
+            msg = (
+                f"No {self.__class__.__name__} data was found "
+                f"in `paths={self.paths!r}`"
+            )
             if self.bands:
                 msg += f" with `bands={self.bands}`"
             raise FileNotFoundError(msg)
@@ -418,6 +423,32 @@ def __init__(
         self._crs = cast(CRS, crs)
         self._res = cast(float, res)
 
+    @property
+    def files(self) -> set[str]:
+        """A set of all files in the dataset.
+
+        Returns:
+            All files in the dataset.
+
+        .. versionadded:: 0.5
+        """
+        # Make iterable
+        if isinstance(self.paths, str):
+            paths: Iterable[str] = [self.paths]
+        else:
+            paths = self.paths
+
+        # Using set to remove any duplicates if directories are overlapping
+        files: set[str] = set()
+        for path in paths:
+            if os.path.isdir(path):
+                pathname = os.path.join(path, "**", self.filename_glob)
+                files |= set(glob.iglob(pathname, recursive=True))
+            else:
+                files.add(path)
+
+        return files
+
     def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
         """Retrieve image/mask and metadata indexed by query.
 
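
The new `files` property above is what makes mixed inputs work: a string
`paths` is wrapped in a single-element list, each directory is globbed
recursively against `filename_glob`, anything else is taken as a literal
file path, and the set drops duplicates from overlapping directories. A
minimal sketch of the resulting call patterns, reusing the NAIP fixtures
from tests/data/naip referenced in tests/datasets/test_geo.py (only the
variable name `ds` is invented here):

    from torchgeo.datasets import NAIP

    # One root directory: searched recursively for files matching filename_glob.
    ds = NAIP("tests/data/naip")

    # A directory plus an explicit file: the glob already finds the file,
    # and RasterDataset.files deduplicates because it returns a set.
    ds = NAIP(
        [
            "tests/data/naip",
            "tests/data/naip/m_3807511_ne_18_060_20181104.tif",
        ]
    )
    print(sorted(ds.files))
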
diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index 1954685ceb0..7fe6428c8ec 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -5,7 +5,8 @@ import glob import os -from typing import Any, Callable, Optional, cast +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union, cast import matplotlib.pyplot as plt import torch @@ -118,7 +119,7 @@ class GlobBiomass(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, measurement: str = "agb", @@ -129,7 +130,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -141,11 +142,14 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` RuntimeError: if dataset is missing or checksum fails AssertionError: if measurement argument is invalid, or not a str + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self.checksum = checksum assert isinstance(measurement, str), "Measurement argument must be a str." @@ -161,7 +165,7 @@ def __init__( self._verify() - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. @@ -206,12 +210,12 @@ def _verify(self) -> None: RuntimeError: if dataset is missing or checksum fails """ # Check if the extracted file already exists - pathname = os.path.join(self.root, self.filename_glob) - if glob.glob(pathname): + if self.files: return # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.zipfile_glob) + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, self.zipfile_glob) if glob.glob(pathname): for zipfile in glob.iglob(pathname): filename = os.path.basename(zipfile) @@ -221,7 +225,7 @@ def _verify(self) -> None: return raise RuntimeError( - f"Dataset not found in `root={self.root}` " + f"Dataset not found in `root={self.paths}` " "either specify a different `root` directory or make sure you " "have manually downloaded the dataset as suggested in the documentation." ) diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index 9de48f9a340..8fa36daad1d 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -5,8 +5,8 @@ import glob import os -from collections.abc import Sequence -from typing import Any, Callable, Optional, cast +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, Union, cast import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -91,7 +91,7 @@ class L7Irish(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = CRS.from_epsg(3857), res: Optional[float] = None, bands: Sequence[str] = all_bands, @@ -103,7 +103,7 @@ def __init__( """Initialize a new L7Irish instance. 
         Args:
-            root: root directory where dataset can be found
+            paths: one or more root directories to search or files to load
             crs: :term:`coordinate reference system (CRS)` to warp to
                 (defaults to EPSG:3857)
             res: resolution of the dataset in units of CRS
@@ -118,15 +118,18 @@ def __init__(
         Raises:
             RuntimeError: if ``download=False`` and data is not found, or checksums
                 don't match
+
+        .. versionchanged:: 0.5
+           *root* was renamed to *paths*.
         """
-        self.root = root
+        self.paths = paths
         self.download = download
         self.checksum = checksum
 
         self._verify()
 
         super().__init__(
-            root, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
+            paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
         )
 
     def _verify(self) -> None:
@@ -136,12 +139,12 @@ def _verify(self) -> None:
             RuntimeError: if ``download=False`` but dataset is missing or checksum fails
         """
         # Check if the extracted files already exist
-        pathname = os.path.join(self.root, "**", self.filename_glob)
-        for fname in glob.iglob(pathname, recursive=True):
+        if self.files:
             return
 
         # Check if the tar.gz files have already been downloaded
-        pathname = os.path.join(self.root, "*.tar.gz")
+        assert isinstance(self.paths, str)
+        pathname = os.path.join(self.paths, "*.tar.gz")
         if glob.glob(pathname):
             self._extract()
             return
@@ -149,7 +152,7 @@ def _verify(self) -> None:
         # Check if the user requested to download the dataset
         if not self.download:
             raise RuntimeError(
-                f"Dataset not found in `root={self.root}` and `download=False`, "
+                f"Dataset not found in `root={self.paths}` and `download=False`, "
                 "either specify a different `root` directory or use `download=True` "
                 "to automatically download the dataset."
             )
@@ -162,12 +165,13 @@ def _download(self) -> None:
         """Download the dataset."""
         for biome, md5 in self.md5s.items():
             download_url(
-                self.url.format(biome), self.root, md5=md5 if self.checksum else None
+                self.url.format(biome), self.paths, md5=md5 if self.checksum else None
             )
 
     def _extract(self) -> None:
         """Extract the dataset."""
-        pathname = os.path.join(self.root, "*.tar.gz")
+        assert isinstance(self.paths, str)
+        pathname = os.path.join(self.paths, "*.tar.gz")
         for tarfile in glob.iglob(pathname):
             extract_archive(tarfile)
 
diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py
index 105da4042c3..01937dd7a69 100644
--- a/torchgeo/datasets/l8biome.py
+++ b/torchgeo/datasets/l8biome.py
@@ -5,8 +5,8 @@
 
 import glob
 import os
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, cast
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, Union, cast
 
 import matplotlib.pyplot as plt
 from matplotlib.figure import Figure
@@ -90,7 +90,7 @@ class L8Biome(RasterDataset):
 
     def __init__(
         self,
-        root: str = "data",
+        paths: Union[str, Iterable[str]] = "data",
         crs: Optional[CRS] = CRS.from_epsg(3857),
         res: Optional[float] = None,
         bands: Sequence[str] = all_bands,
@@ -102,7 +102,7 @@ def __init__(
         """Initialize a new L8Biome instance.
 
         Args:
-            root: root directory where dataset can be found
+            paths: one or more root directories to search or files to load
             crs: :term:`coordinate reference system (CRS)` to warp to
                 (defaults to EPSG:3857)
             res: resolution of the dataset in units of CRS
@@ -117,15 +117,18 @@ def __init__(
         Raises:
             RuntimeError: if ``download=False`` and data is not found, or checksums
                 don't match
+
+        .. versionchanged:: 0.5
+           *root* was renamed to *paths*.
""" - self.root = root + self.paths = paths self.download = download self.checksum = checksum self._verify() super().__init__( - root, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache + paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache ) def _verify(self) -> None: @@ -135,12 +138,12 @@ def _verify(self) -> None: RuntimeError: if ``download=False`` but dataset is missing or checksum fails """ # Check if the extracted files already exist - pathname = os.path.join(self.root, "**", self.filename_glob) - for fname in glob.iglob(pathname, recursive=True): + if self.files: return # Check if the tar.gz files have already been downloaded - pathname = os.path.join(self.root, "*.tar.gz") + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, "*.tar.gz") if glob.glob(pathname): self._extract() return @@ -148,7 +151,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " + f"Dataset not found in `root={self.paths}` and `download=False`, " "either specify a different `root` directory or use `download=True` " "to automatically download the dataset." ) @@ -161,12 +164,13 @@ def _download(self) -> None: """Download the dataset.""" for biome, md5 in self.md5s.items(): download_url( - self.url.format(biome), self.root, md5=md5 if self.checksum else None + self.url.format(biome), self.paths, md5=md5 if self.checksum else None ) def _extract(self) -> None: """Extract the dataset.""" - pathname = os.path.join(self.root, "*.tar.gz") + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, "*.tar.gz") for tarfile in glob.iglob(pathname): extract_archive(tarfile) diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 79f1b110193..1a5543dbdd8 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -222,20 +222,20 @@ def __init__( """Initialize a new LandCover.ai NonGeo dataset instance. 
Args: - root: root directory where dataset can be found - crs: :term:`coordinate reference system (CRS)` to warp to - (defaults to the CRS of the first file found) - res: resolution of the dataset in units of CRS - (defaults to the resolution of the first file found) - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version - cache: if True, cache file handle to speed up repeated sampling - download: if True, download dataset and store it in the root directory - checksum: if True, check the MD5 of the downloaded files (may be slow) + root: root directory where dataset can be found + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + res: resolution of the dataset in units of CRS + (defaults to the resolution of the first file found) + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + cache: if True, cache file handle to speed up repeated sampling + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match """ LandCoverAIBase.__init__(self, root, download, checksum) RasterDataset.__init__(self, root, crs, res, transforms=transforms, cache=cache) diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 72174114687..097b3a36a94 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -4,8 +4,8 @@ """Landsat datasets.""" import abc -from collections.abc import Sequence -from typing import Any, Callable, Optional +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -58,7 +58,7 @@ def default_bands(self) -> list[str]: def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, @@ -68,7 +68,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -79,12 +79,15 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. 
""" bands = bands or self.default_bands self.filename_glob = self.filename_glob.format(bands[0]) - super().__init__(root, crs, res, bands, transforms, cache) + super().__init__(paths, crs, res, bands, transforms, cache) def plot( self, diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index b7c1eae565f..da6479e289e 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -5,7 +5,8 @@ import glob import os -from typing import Any, Callable, Optional +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt import torch @@ -106,7 +107,7 @@ class NLCD(RasterDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, years: list[int] = [2019], @@ -119,7 +120,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -135,8 +136,11 @@ def __init__( Raises: AssertionError: if ``years`` or ``classes`` are invalid - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` RuntimeError: if ``download=False`` but dataset is missing or checksum fails + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ assert set(years) <= self.md5s.keys(), ( "NLCD data product only exists for the following years: " @@ -147,7 +151,7 @@ def __init__( ), f"Only the following classes are valid: {list(self.cmap.keys())}." assert 0 in classes, "Classes must include the background class: 0" - self.root = root + self.paths = paths self.years = years self.classes = classes self.download = download @@ -157,7 +161,7 @@ def __init__( self._verify() - super().__init__(root, crs, res, transforms=transforms, cache=cache) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) # Map chosen classes to ordinal numbers, all others mapped to background class for v, k in enumerate(self.classes): @@ -187,23 +191,15 @@ def _verify(self) -> None: RuntimeError: if ``download=False`` but dataset is missing or checksum fails """ # Check if the extracted files already exist - exists = [] - for year in self.years: - filename_year = self.filename_glob.replace("*", str(year), 1) - pathname = os.path.join(self.root, "**", filename_year) - if glob.glob(pathname, recursive=True): - exists.append(True) - else: - exists.append(False) - - if all(exists): + if self.files: return # Check if the zip files have already been downloaded exists = [] for year in self.years: zipfile_year = self.zipfile_glob.replace("*", str(year), 1) - pathname = os.path.join(self.root, "**", zipfile_year) + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, "**", zipfile_year) if glob.glob(pathname, recursive=True): exists.append(True) self._extract() @@ -216,7 +212,7 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " + f"Dataset not found in `root={self.paths}` and `download=False`, " "either specify a different `root` directory or use `download=True` " "to automatically download the dataset." 
) @@ -230,7 +226,7 @@ def _download(self) -> None: for year in self.years: download_url( self.url.format(year), - self.root, + self.paths, md5=self.md5s[year] if self.checksum else None, ) @@ -238,8 +234,9 @@ def _extract(self) -> None: """Extract the dataset.""" for year in self.years: zipfile_name = self.zipfile_glob.replace("*", str(year), 1) - pathname = os.path.join(self.root, "**", zipfile_name) - extract_archive(glob.glob(pathname, recursive=True)[0], self.root) + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, "**", zipfile_name) + extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) def plot( self, diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 33e5ed43b07..1c4e2423482 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -3,8 +3,8 @@ """Sentinel datasets.""" -from collections.abc import Sequence -from typing import Any, Callable, Optional +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt import torch @@ -140,7 +140,7 @@ class Sentinel1(Sentinel): def __init__( self, - root: str = "data", + paths: Union[str, list[str]] = "data", crs: Optional[CRS] = None, res: float = 10, bands: Sequence[str] = ["VV", "VH"], @@ -150,7 +150,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -162,7 +162,10 @@ def __init__( Raises: AssertionError: if ``bands`` is invalid - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ assert len(bands) > 0, "'bands' cannot be an empty list" assert len(bands) == len(set(bands)), "'bands' contains duplicate bands" @@ -184,7 +187,7 @@ def __init__( self.filename_glob = self.filename_glob.format(bands[0]) - super().__init__(root, crs, res, bands, transforms, cache) + super().__init__(paths, crs, res, bands, transforms, cache) def plot( self, @@ -293,7 +296,7 @@ class Sentinel2(Sentinel): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: float = 10, bands: Optional[Sequence[str]] = None, @@ -303,7 +306,7 @@ def __init__( """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -314,13 +317,16 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: - FileNotFoundError: if no files are found in ``root`` + FileNotFoundError: if no files are found in ``paths`` + + .. versionchanged:: 0.5 + *root* was renamed to *paths* """ bands = bands or self.all_bands self.filename_glob = self.filename_glob.format(bands[0]) self.filename_regex = self.filename_regex.format(res) - super().__init__(root, crs, res, bands, transforms, cache) + super().__init__(paths, crs, res, bands, transforms, cache) def plot( self,
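
Taken together with the `assert isinstance(self.paths, str)` guards added
to the `_verify`/`_download`/`_extract` methods above, the usage contract
is: any iterable of local directories or files works for loading, but
automatic download and extraction still require a single string root (per
the commit note "paths must be a string if you want autodownload support").
A sketch using CDL and the years exercised in the tests; the directory
names are illustrative, not part of this patch:

    from torchgeo.datasets import CDL

    # Loading already-extracted data: any mix of directories/files is fine.
    cdl = CDL(["data/cdl/2020", "data/cdl/2021"], years=[2020, 2021])

    # Downloading: paths must stay a plain string, because _download and
    # _extract assert isinstance(self.paths, str) before writing into it.
    cdl = CDL("data/cdl", years=[2021], download=True, checksum=True)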