From cb6cce385c82c98df9ae6e42fe6764d2f445bb46 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 2 Apr 2021 07:34:11 -0700 Subject: [PATCH] [fbsync] add tests for (Dataset|Image)Folder (#3477) Summary: * add tests for (Dataset|Image)Folder * lint * remove old tests * cleanup * more cleanup * adapt tests * fix make_dataset * remove powerset * readd import Reviewed By: fmassa Differential Revision: D27433923 fbshipit-source-id: 6ea3fb79f41e255045a642dcadedd8fa813e9dcc --- test/test_datasets.py | 151 ++++++++++++++++++++------------- torchvision/datasets/folder.py | 2 +- 2 files changed, 91 insertions(+), 62 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 2dde215e32b..db80b55a90f 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -57,67 +57,6 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1): class Tester(DatasetTestcase): - def test_imagefolder(self): - # TODO: create the fake data on-the-fly - FAKEDATA_DIR = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata') - - with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: - classes = sorted(['a', 'b']) - class_a_image_files = [ - os.path.join(root, 'a', file) for file in ('a1.png', 'a2.png', 'a3.png') - ] - class_b_image_files = [ - os.path.join(root, 'b', file) for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png') - ] - dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x) - - # test if all classes are present - self.assertEqual(classes, sorted(dataset.classes)) - - # test if combination of classes and class_to_index functions correctly - for cls in classes: - self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]]) - - # test if all images were detected correctly - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files] - imgs_b = [(img_file, class_b_idx) for img_file in 
class_b_image_files] - imgs = sorted(imgs_a + imgs_b) - self.assertEqual(imgs, dataset.imgs) - - # test if the datasets outputs all images correctly - outputs = sorted([dataset[i] for i in range(len(dataset))]) - self.assertEqual(imgs, outputs) - - # redo all tests with specified valid image files - dataset = torchvision.datasets.ImageFolder( - root, loader=lambda x: x, is_valid_file=lambda x: '3' in x) - self.assertEqual(classes, sorted(dataset.classes)) - - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files - if '3' in img_file] - imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files - if '3' in img_file] - imgs = sorted(imgs_a + imgs_b) - self.assertEqual(imgs, dataset.imgs) - - outputs = sorted([dataset[i] for i in range(len(dataset))]) - self.assertEqual(imgs, outputs) - - def test_imagefolder_empty(self): - with get_tmp_dir() as root: - with self.assertRaises(FileNotFoundError): - torchvision.datasets.ImageFolder(root, loader=lambda x: x) - - with self.assertRaises(FileNotFoundError): - torchvision.datasets.ImageFolder( - root, loader=lambda x: x, is_valid_file=lambda x: False - ) - @mock.patch('torchvision.datasets.SVHN._check_integrity') @unittest.skipIf(not HAS_SCIPY, "scipy unavailable") def test_svhn(self, mock_check): @@ -1673,5 +1612,95 @@ def test_num_examples_test50k(self): self.assertEqual(len(dataset), info["num_examples"] - 10000) +class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.DatasetFolder + + # The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader + # that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method. 
+ FEATURE_TYPES = (str, int) + + _IMAGE_EXTENSIONS = ("jpg", "png") + _VIDEO_EXTENSIONS = ("avi", "mp4") + _EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS) + + # DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. Exactly one is required. + # We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the + # 'test_is_valid_file()' method. + DEFAULT_CONFIG = dict(extensions=_EXTENSIONS) + ADDITIONAL_CONFIGS = ( + *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]), + dict(extensions=_IMAGE_EXTENSIONS), + *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]), + dict(extensions=_VIDEO_EXTENSIONS), + ) + + def dataset_args(self, tmpdir, config): + return tmpdir, lambda x: x + + def inject_fake_data(self, tmpdir, config): + extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"]) + + num_examples_total = 0 + classes = [] + for ext, cls in zip(self._EXTENSIONS, string.ascii_letters): + if ext not in extensions: + continue + + create_example_folder = ( + datasets_utils.create_image_folder + if ext in self._IMAGE_EXTENSIONS + else datasets_utils.create_video_folder + ) + + num_examples = torch.randint(1, 3, size=()).item() + create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples) + + num_examples_total += num_examples + classes.append(cls) + + return dict(num_examples=num_examples_total, classes=classes) + + def _file_name_fn(self, cls, ext, idx): + return f"{cls}_{idx}.{ext}" + + def _is_valid_file_to_extensions(self, is_valid_file): + return {ext for ext in self._EXTENSIONS if is_valid_file(f"foo.{ext}")} + + @datasets_utils.test_all_configs + def test_is_valid_file(self, config): + extensions = config.pop("extensions") + # We need to explicitly pass extensions=None here or otherwise it would be filled by the value from the + # DEFAULT_CONFIG. 
+ with self.create_dataset( + config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions + ) as (dataset, info): + self.assertEqual(len(dataset), info["num_examples"]) + + @datasets_utils.test_all_configs + def test_classes(self, config): + with self.create_dataset(config) as (dataset, info): + self.assertSequenceEqual(dataset.classes, info["classes"]) + + +class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.ImageFolder + + def inject_fake_data(self, tmpdir, config): + num_examples_total = 0 + classes = ("a", "b") + for cls in classes: + num_examples = torch.randint(1, 3, size=()).item() + num_examples_total += num_examples + + datasets_utils.create_image_folder(tmpdir, cls, lambda idx: f"{cls}_{idx}.png", num_examples) + + return dict(num_examples=num_examples_total, classes=classes) + + @datasets_utils.test_all_configs + def test_classes(self, config): + with self.create_dataset(config) as (dataset, info): + self.assertSequenceEqual(dataset.classes, info["classes"]) + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index fb4861e637a..d121bad7a19 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -129,7 +129,7 @@ def is_valid_file(x: str) -> bool: if target_class not in available_classes: available_classes.add(target_class) - empty_classes = available_classes - set(class_to_idx.keys()) + empty_classes = set(class_to_idx.keys()) - available_classes if empty_classes: msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " if extensions is not None: