Skip to content

Commit

Permalink
cleanup prototype datasets (#4471)
Browse files Browse the repository at this point in the history
* cleanup image folder

* make shuffling mandatory

* rename parameter in home() function

* don't show builtin list

* make categories optional in dataset info

* use pseudo-infinite buffer size for shuffler

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
  • Loading branch information
pmeier and fmassa authored Sep 29, 2021
1 parent 932ca5a commit a068602
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 24 deletions.
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from . import decoder, utils

# Load this last, since some parts depend on the above being loaded first
from ._api import register, list, info, load
from ._api import register, _list as list, info, load
from ._folder import from_data_folder, from_image_folder
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def register(dataset: Dataset) -> None:
DATASETS[dataset.name] = dataset


def list() -> List[str]:
# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
def _list() -> List[str]:
return sorted(DATASETS.keys())


Expand Down Expand Up @@ -45,7 +46,6 @@ def info(name: str) -> DatasetInfo:
def load(
name: str,
*,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
split: str = "train",
**options: Any,
Expand All @@ -55,4 +55,4 @@ def load(
config = dataset.info.make_config(split=split, **options)
root = home() / name

return dataset.to_datapipe(root, config=config, shuffler=shuffler, decoder=decoder)
return dataset.to_datapipe(root, config=config, decoder=decoder)
8 changes: 2 additions & 6 deletions torchvision/prototype/datasets/_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@
from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter

from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE


__all__ = ["from_data_folder", "from_image_folder"]

# pseudo-infinite buffer size until a true infinite buffer is supported
INFINITE = 1_000_000_000


def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
rel_path = pathlib.Path(path).relative_to(root)
Expand Down Expand Up @@ -45,7 +43,6 @@ def _collate_and_decode_data(
def from_data_folder(
root: Union[str, pathlib.Path],
*,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = lambda dp: Shuffler(dp, buffer_size=INFINITE),
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
valid_extensions: Optional[Collection[str]] = None,
recursive: bool = True,
Expand All @@ -55,8 +52,7 @@ def from_data_folder(
masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks)
dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
if shuffler:
dp = shuffler(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = FileLoader(dp)
return (
Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datasets/_home.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"


def home(home: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
def home(root: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
global HOME
if home is not None:
HOME = pathlib.Path(home).expanduser().resolve()
if root is not None:
HOME = pathlib.Path(root).expanduser().resolve()
return HOME

home = os.getenv("TORCHVISION_DATASETS_HOME")
if home is not None:
return pathlib.Path(home)
root = os.getenv("TORCHVISION_DATASETS_HOME")
if root is not None:
return pathlib.Path(root)

return HOME
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/decoder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import io

import numpy as np
import PIL.Image
import torch

from torchvision.transforms.functional import pil_to_tensor

__all__ = ["pil"]


def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor:
image = PIL.Image.open(file).convert(mode.upper())
return torch.from_numpy(np.array(image, copy=True)).permute((2, 0, 1))
return pil_to_tensor(PIL.Image.open(file).convert(mode.upper()))
10 changes: 5 additions & 5 deletions torchvision/prototype/datasets/utils/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,17 @@ def __init__(
self,
name: str,
*,
categories: Union[int, Sequence[str], str, pathlib.Path],
categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
citation: Optional[str] = None,
homepage: Optional[str] = None,
license: Optional[str] = None,
valid_options: Optional[Dict[str, Sequence]] = None,
) -> None:
self.name = name.lower()

if isinstance(categories, int):
if categories is None:
categories = []
elif isinstance(categories, int):
categories = [str(label) for label in range(categories)]
elif isinstance(categories, (str, pathlib.Path)):
with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
Expand Down Expand Up @@ -198,7 +200,6 @@ def _make_datapipe(
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]],
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
pass
Expand All @@ -208,7 +209,6 @@ def to_datapipe(
root: Union[str, pathlib.Path],
*,
config: Optional[DatasetConfig] = None,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
) -> IterDataPipe[Dict[str, Any]]:
if not config:
Expand All @@ -217,4 +217,4 @@ def to_datapipe(
resource_dps = [
resource.to_datapipe(root) for resource in self.resources(config)
]
return self._make_datapipe(resource_dps, config=config, shuffler=shuffler, decoder=decoder)
return self._make_datapipe(resource_dps, config=config, decoder=decoder)
4 changes: 4 additions & 0 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@


__all__ = [
"INFINITE_BUFFER_SIZE",
"sequence_to_str",
"add_suggestion",
]

# pseudo-infinite until a true infinite buffer is supported by all datapipes
INFINITE_BUFFER_SIZE = 1_000_000_000


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if len(seq) == 1:
Expand Down

0 comments on commit a068602

Please sign in to comment.