From 77bb592f33012f640f3f7db637c12d2609fec498 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 29 Sep 2021 07:51:18 +0200 Subject: [PATCH] cleanup prototype datasets (#4471) * cleanup image folder * make shuffling mandatory * rename parameter in home() function * don't show builtin list * make categories optional in dataset info * use pseudo-infinite buffer size for shuffler Co-authored-by: Francisco Massa --- torchvision/prototype/datasets/__init__.py | 2 +- torchvision/prototype/datasets/_api.py | 6 +++--- torchvision/prototype/datasets/_folder.py | 8 ++------ torchvision/prototype/datasets/_home.py | 12 ++++++------ torchvision/prototype/datasets/decoder.py | 6 +++--- torchvision/prototype/datasets/utils/_dataset.py | 10 +++++----- torchvision/prototype/datasets/utils/_internal.py | 4 ++++ 7 files changed, 24 insertions(+), 24 deletions(-) diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index 18217b4bd3d..8661eb65cc5 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -2,5 +2,5 @@ from . import decoder, utils # Load this last, since some parts depend on the above being loaded first -from ._api import register, list, info, load +from ._api import register, _list as list, info, load from ._folder import from_data_folder, from_image_folder diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 9f72baba1f1..29dce26dd0c 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -17,7 +17,8 @@ def register(dataset: Dataset) -> None: DATASETS[dataset.name] = dataset -def list() -> List[str]: +# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list' +def _list() -> List[str]: return sorted(DATASETS.keys()) @@ -45,7 +46,6 @@ def info(name: str) -> DatasetInfo: def load( name: str, *, - shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil, split: str = "train", **options: Any, @@ -55,4 +55,4 @@ def load( config = dataset.info.make_config(split=split, **options) root = home() / name - return dataset.to_datapipe(root, config=config, shuffler=shuffler, decoder=decoder) + return dataset.to_datapipe(root, config=config, decoder=decoder) diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index afdb78bc43c..5626f68650f 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -10,13 +10,11 @@ from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter from torchvision.prototype.datasets.decoder import pil +from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE __all__ = ["from_data_folder", "from_image_folder"] -# pseudo-infinite buffer size until a true infinite buffer is supported -INFINITE = 1_000_000_000 - def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool: rel_path = pathlib.Path(path).relative_to(root) @@ -45,7 +43,6 @@ def _collate_and_decode_data( def from_data_folder( root: Union[str, pathlib.Path], *, - shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = lambda dp: Shuffler(dp, buffer_size=INFINITE), decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None, valid_extensions: Optional[Collection[str]] = None, recursive: bool = True, @@ -55,8 +52,7 @@ def from_data_folder( masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks) dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) - if shuffler: - dp = shuffler(dp) + dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = FileLoader(dp) return ( Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)), diff --git a/torchvision/prototype/datasets/_home.py b/torchvision/prototype/datasets/_home.py index b1d7751192b..535d35294b9 100644 --- a/torchvision/prototype/datasets/_home.py +++ b/torchvision/prototype/datasets/_home.py @@ -7,14 +7,14 @@ HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision" -def home(home: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path: +def home(root: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path: global HOME - if home is not None: - HOME = pathlib.Path(home).expanduser().resolve() + if root is not None: + HOME = pathlib.Path(root).expanduser().resolve() return HOME - home = os.getenv("TORCHVISION_DATASETS_HOME") - if home is not None: - return pathlib.Path(home) + root = os.getenv("TORCHVISION_DATASETS_HOME") + if root is not None: + return pathlib.Path(root) return HOME diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py index cd41ceeac7a..4c10cff1035 100644 --- a/torchvision/prototype/datasets/decoder.py +++ b/torchvision/prototype/datasets/decoder.py @@ -1,12 +1,12 @@ import io -import numpy as np import PIL.Image import torch +from torchvision.transforms.functional import pil_to_tensor + __all__ = ["pil"] def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor: - image = PIL.Image.open(file).convert(mode.upper()) - return torch.from_numpy(np.array(image, copy=True)).permute((2, 0, 1)) + return pil_to_tensor(PIL.Image.open(file).convert(mode.upper())) diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 78ccc0a0186..b3cf53afc8d 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -98,7 +98,7 @@ def __init__( self, name: str, *, - categories: Union[int, Sequence[str], str, pathlib.Path], + categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, citation: Optional[str] = None, homepage: Optional[str] = None, license: Optional[str] = None, @@ -106,7 +106,9 @@ def __init__( ) -> None: self.name = name.lower() - if isinstance(categories, int): + if categories is None: + categories = [] + elif isinstance(categories, int): categories = [str(label) for label in range(categories)] elif isinstance(categories, (str, pathlib.Path)): with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh: @@ -198,7 +200,6 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: pass @@ -208,7 +209,6 @@ def to_datapipe( root: Union[str, pathlib.Path], *, config: Optional[DatasetConfig] = None, - shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None, ) -> IterDataPipe[Dict[str, Any]]: if not config: @@ -217,4 +217,4 @@ def to_datapipe( resource_dps = [ resource.to_datapipe(root) for resource in self.resources(config) ] - return self._make_datapipe(resource_dps, config=config, shuffler=shuffler, decoder=decoder) + return self._make_datapipe(resource_dps, config=config, decoder=decoder) diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index f378c5da29d..ad4f70145d5 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -4,10 +4,14 @@ __all__ = [ + "INFINITE_BUFFER_SIZE", "sequence_to_str", "add_suggestion", ] +# pseudo-infinite until a true infinite buffer is supported by all datapipes +INFINITE_BUFFER_SIZE = 1_000_000_000 + def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: if len(seq) == 1: