From 0b07d608cb05ad1753309e66e21df7eb0ed50764 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 1 Mar 2021 14:56:28 +0100 Subject: [PATCH 1/9] add tests for (Dataset|Image)Folder --- test/datasets_utils.py | 14 +++++++ test/test_datasets.py | 90 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 374ab48b4b7..59571277e1b 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -723,3 +723,17 @@ def create_random_string(length: int, *digits: str) -> str: digits = "".join(itertools.chain(*digits)) return "".join(random.choice(digits) for _ in range(length)) + + +def powerset(iterable): + """Create the powerset from given iterable. + + E.g.: powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) + + This function is taken from the + `itertools recipes `_. + + """ + s = list(iterable) + return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s)+1)) + diff --git a/test/test_datasets.py b/test/test_datasets.py index 59a292e319b..5af5184942e 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1275,5 +1275,95 @@ def test_not_found_or_corrupted(self): self.skipTest("The data is generated at creation and thus cannot be non-existent or corrupted.") +class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.DatasetFolder + + # The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader + # that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method. + FEATURE_TYPES = (str, int) + + _IMAGE_EXTENSIONS = ("jpg", "png") + _VIDEO_EXTENSIONS = ("avi", "mp4") + _EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS) + + # DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of both is required. + # We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the + # 'test_is_valid_file()' method. + # This is also the reason we are forced to overwrite the default tests that are not tested against all configs, + # i.e. 'test_not_found_or_corrupted()' and 'test_smoke()'. 
+ CONFIGS = datasets_utils.combinations_grid(extensions=datasets_utils.powerset(_EXTENSIONS)) + + def dataset_args(self, tmpdir, config): + return tmpdir, lambda x: x + + def inject_fake_data(self, tmpdir, config): + num_examples = {} + classes = {} + for ext, cls in zip(self._EXTENSIONS, string.ascii_letters): + create_example_folder = ( + datasets_utils.create_image_folder + if ext in self._IMAGE_EXTENSIONS + else datasets_utils.create_video_folder + ) + + num_examples[ext] = torch.randint(1, 3, size=()).item() + classes[ext] = cls + + create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples[ext]) + + extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"]) + return dict(num_examples=sum(num_examples[ext] for ext in extensions), classes=[classes[ext] for ext in extensions]) + + def _file_name_fn(self, cls, ext, idx): + return f"{cls}_{idx}.{ext}" + + def _is_valid_file_to_extensions(self, is_valid_file): + return {ext for ext in self._EXTENSIONS if is_valid_file(f"foo.{ext}")} + + def test_not_found_or_corrupted(self): + with self.assertRaises((FileNotFoundError, RuntimeError)): + with self.create_dataset(inject_fake_data=False, extensions=self._EXTENSIONS): + pass + + def test_smoke(self): + with self.create_dataset(extensions=self._EXTENSIONS) as (dataset, _): + self.assertIsInstance(dataset, torchvision.datasets.VisionDataset) + + def test_is_valid_file(self): + for config in self.CONFIGS: + config = config.copy() + extensions = config.pop("extensions") + with self.subTest(extensions=extensions): + with self.create_dataset( + config, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions + ) as (dataset, info): + self.assertEqual(len(dataset), info["num_examples"]) + + @datasets_utils.test_all_configs + def test_classes(self, config): + with self.create_dataset(config) as (dataset, info): + self.assertSequenceEqual(dataset.classes, info["classes"]) + + +class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.ImageFolder + + def inject_fake_data(self, tmpdir, config): + num_examples_total = 0 + classes = ("a", "b") + for cls in classes: + num_examples = torch.randint(1, 3, size=()).item() + num_examples_total += num_examples + + datasets_utils.create_image_folder(tmpdir, cls, lambda idx: f"{cls}_{idx}.png", num_examples) + + return dict(num_examples=num_examples_total, classes=classes) + + @datasets_utils.test_all_configs + def test_classes(self, config): + with self.create_dataset(config) as (dataset, info): + self.assertSequenceEqual(dataset.classes, info["classes"]) + + if __name__ == "__main__": unittest.main() From bc1d2d1b5565c38824bab73829aec91416ab8f30 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 1 Mar 2021 15:02:49 +0100 Subject: [PATCH 2/9] lint --- test/datasets_utils.py | 3 +-- test/test_datasets.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 59571277e1b..503eebcee37 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -735,5 +735,4 @@ def powerset(iterable): """ s = list(iterable) - return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s)+1)) - + return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s) + 1)) diff --git a/test/test_datasets.py b/test/test_datasets.py index 5af5184942e..5d4bf5e6aa9 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1312,7 +1312,9 @@ 
def inject_fake_data(self, tmpdir, config): create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples[ext]) extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"]) - return dict(num_examples=sum(num_examples[ext] for ext in extensions), classes=[classes[ext] for ext in extensions]) + return dict( + num_examples=sum(num_examples[ext] for ext in extensions), classes=[classes[ext] for ext in extensions], + ) def _file_name_fn(self, cls, ext, idx): return f"{cls}_{idx}.{ext}" From c58cbd3b65cd37e24691c002322cb120235f567a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 2 Mar 2021 12:49:40 +0100 Subject: [PATCH 3/9] remove old tests --- test/test_datasets.py | 61 ------------------------------------------- 1 file changed, 61 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 5d4bf5e6aa9..ccdcefdf49f 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -57,67 +57,6 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1): class Tester(DatasetTestcase): - def test_imagefolder(self): - # TODO: create the fake data on-the-fly - FAKEDATA_DIR = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata') - - with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: - classes = sorted(['a', 'b']) - class_a_image_files = [ - os.path.join(root, 'a', file) for file in ('a1.png', 'a2.png', 'a3.png') - ] - class_b_image_files = [ - os.path.join(root, 'b', file) for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png') - ] - dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x) - - # test if all classes are present - self.assertEqual(classes, sorted(dataset.classes)) - - # test if combination of classes and class_to_index functions correctly - for cls in classes: - self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]]) - - # test if all images were detected correctly - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files] - imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files] - imgs = sorted(imgs_a + imgs_b) - self.assertEqual(imgs, dataset.imgs) - - # test if the datasets outputs all images correctly - outputs = sorted([dataset[i] for i in range(len(dataset))]) - self.assertEqual(imgs, outputs) - - # redo all tests with specified valid image files - dataset = torchvision.datasets.ImageFolder( - root, loader=lambda x: x, is_valid_file=lambda x: '3' in x) - self.assertEqual(classes, sorted(dataset.classes)) - - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files - if '3' in img_file] - imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files - if '3' in img_file] - imgs = sorted(imgs_a + imgs_b) - self.assertEqual(imgs, dataset.imgs) - - outputs = sorted([dataset[i] for i in range(len(dataset))]) - self.assertEqual(imgs, outputs) - - def test_imagefolder_empty(self): - with get_tmp_dir() as root: - with self.assertRaises(RuntimeError): - torchvision.datasets.ImageFolder(root, loader=lambda x: x) - - with self.assertRaises(RuntimeError): - torchvision.datasets.ImageFolder( - root, loader=lambda x: x, is_valid_file=lambda x: False - ) - @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_mnist(self, mock_download_extract): num_examples = 30 From 
70f261a21be3ebab158a08cf1400836687cdaf98 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Mar 2021 13:44:18 +0100 Subject: [PATCH 4/9] cleanup --- test/test_datasets.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 74484d0d33b..9569ccf2181 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -59,7 +59,8 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1): class Tester(DatasetTestcase): @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') - def test_mnist(self, mock_download_extract): + @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) + def test_mnist(self, mock_download_extract, mock_check_integrity): num_examples = 30 with mnist_root(num_examples, "MNIST") as root: dataset = torchvision.datasets.MNIST(root, download=True) @@ -68,7 +69,8 @@ def test_mnist(self, mock_download_extract): self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') - def test_kmnist(self, mock_download_extract): + @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) + def test_kmnist(self, mock_download_extract, mock_check_integrity): num_examples = 30 with mnist_root(num_examples, "KMNIST") as root: dataset = torchvision.datasets.KMNIST(root, download=True) @@ -77,7 +79,8 @@ def test_kmnist(self, mock_download_extract): self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') - def test_fashionmnist(self, mock_download_extract): + @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) + def test_fashionmnist(self, mock_download_extract, mock_check_integrity): num_examples = 30 with mnist_root(num_examples, "FashionMNIST") as root: dataset = torchvision.datasets.FashionMNIST(root, download=True) @@ -85,7 +88,7 @@ def test_fashionmnist(self, mock_download_extract): img, target = dataset[0] self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') + @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows') def test_cityscapes(self): with cityscapes_root() as root: From 0f9265d95e4f504f321ccb34cb338afefa40cd95 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Mar 2021 13:44:53 +0100 Subject: [PATCH 5/9] more cleanup --- test/test_datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_datasets.py b/test/test_datasets.py index 9569ccf2181..bceb8bbbbfd 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1350,6 +1350,7 @@ def test_feature_types(self, config): finally: self.FEATURE_TYPES = feature_types + class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Flickr8k From aba227f31e8cd03fbdd73ded1544e61698a2e107 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Mar 2021 17:50:17 +0100 Subject: [PATCH 6/9] adapt tests --- test/test_datasets.py | 58 +++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index bceb8bbbbfd..b5cef3f659f 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1452,32 +1452,39 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): # DatasetFolder has two mutually exclusive parameters: 'extensions' and 
'is_valid_file'. One of both is required. # We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the # 'test_is_valid_file()' method. - # This is also the reason we are forced to overwrite the default tests that are not tested against all configs, - # i.e. 'test_not_found_or_corrupted()' and 'test_smoke()'. - CONFIGS = datasets_utils.combinations_grid(extensions=datasets_utils.powerset(_EXTENSIONS)) + DEFAULT_CONFIG = dict(extensions=_EXTENSIONS) + ADDITIONAL_CONFIGS = ( + *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]), + dict(extensions=_IMAGE_EXTENSIONS), + *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]), + dict(extensions=_VIDEO_EXTENSIONS), + ) def dataset_args(self, tmpdir, config): return tmpdir, lambda x: x def inject_fake_data(self, tmpdir, config): - num_examples = {} - classes = {} + extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"]) + + num_examples_total = 0 + classes = [] for ext, cls in zip(self._EXTENSIONS, string.ascii_letters): + if ext not in extensions: + continue + create_example_folder = ( datasets_utils.create_image_folder if ext in self._IMAGE_EXTENSIONS else datasets_utils.create_video_folder ) - num_examples[ext] = torch.randint(1, 3, size=()).item() - classes[ext] = cls + num_examples = torch.randint(1, 3, size=()).item() + create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples) - create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples[ext]) + num_examples_total += num_examples + classes.append(cls) - extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"]) - return dict( - num_examples=sum(num_examples[ext] for ext in extensions), classes=[classes[ext] for ext in extensions], - ) + return dict(num_examples=num_examples_total, classes=classes) def _file_name_fn(self, cls, ext, idx): return f"{cls}_{idx}.{ext}" @@ -1485,24 +1492,15 @@ def _file_name_fn(self, cls, ext, idx): def _is_valid_file_to_extensions(self, is_valid_file): return {ext for ext in self._EXTENSIONS if is_valid_file(f"foo.{ext}")} - def test_not_found_or_corrupted(self): - with self.assertRaises((FileNotFoundError, RuntimeError)): - with self.create_dataset(inject_fake_data=False, extensions=self._EXTENSIONS): - pass - - def test_smoke(self): - with self.create_dataset(extensions=self._EXTENSIONS) as (dataset, _): - self.assertIsInstance(dataset, torchvision.datasets.VisionDataset) - - def test_is_valid_file(self): - for config in self.CONFIGS: - config = config.copy() - extensions = config.pop("extensions") - with self.subTest(extensions=extensions): - with self.create_dataset( - config, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions - ) as (dataset, info): - self.assertEqual(len(dataset), info["num_examples"]) + @datasets_utils.test_all_configs + def test_is_valid_file(self, config): + extensions = config.pop("extensions") + # We need to explicitly pass extensions=None here or otherwise it would be filled by the value from the + # DEFAULT_CONFIG. 
+ with self.create_dataset( + config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions + ) as (dataset, info): + self.assertEqual(len(dataset), info["num_examples"]) @datasets_utils.test_all_configs def test_classes(self, config): From 465ab83fed04e13ad78d7fee22dff09367b4f764 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Mar 2021 17:50:36 +0100 Subject: [PATCH 7/9] fix make_dataset --- torchvision/datasets/folder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index fb4861e637a..d121bad7a19 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -129,7 +129,7 @@ def is_valid_file(x: str) -> bool: if target_class not in available_classes: available_classes.add(target_class) - empty_classes = available_classes - set(class_to_idx.keys()) + empty_classes = set(class_to_idx.keys()) - available_classes if empty_classes: msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " if extensions is not None: From 74a83b8bbf789f2d56959d242bba7333e489f1a8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Mar 2021 17:53:52 +0100 Subject: [PATCH 8/9] remove powerset --- test/datasets_utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 9fc544dc0ca..dad6af6544d 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -1,4 +1,3 @@ -import collections.abc import contextlib import functools import importlib @@ -849,16 +848,3 @@ def create_random_string(length: int, *digits: str) -> str: digits = "".join(itertools.chain(*digits)) return "".join(random.choice(digits) for _ in range(length)) - - -def powerset(iterable): - """Create the powerset from given iterable. - - E.g.: powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) - - This function is taken from the - `itertools recipes `_. - - """ - s = list(iterable) - return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s) + 1)) From ef5efff42001751a550d73a3510ce76f466c3aac Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Mar 2021 17:54:18 +0100 Subject: [PATCH 9/9] readd import --- test/datasets_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index dad6af6544d..8ba55c21f60 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -1,3 +1,4 @@ +import collections.abc import contextlib import functools import importlib
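
For reference, the one-line change in patch 7 ("fix make_dataset") flips the direction of a set difference in torchvision/datasets/folder.py. A minimal sketch of that logic, using made-up class names and counts rather than anything from the patch itself, shows why the order matters: the classes actually found on disk are always a subset of the classes declared in class_to_idx, so the old expression was always empty and the "found no valid file" error could never be raised.

# Minimal sketch of the set-difference fix from patch 7 ("fix make_dataset").
# The class names and membership below are invented for illustration; only the
# direction of the subtraction mirrors the actual change in
# torchvision/datasets/folder.py.

class_to_idx = {"a": 0, "b": 1, "c": 2}   # every class folder discovered under root
available_classes = {"a", "b"}            # classes for which at least one valid file was found

# Old (buggy) direction: found classes minus declared classes.
# available_classes can never contain a class that is missing from class_to_idx,
# so this is always the empty set and the error below never triggers.
old_empty_classes = available_classes - set(class_to_idx.keys())
assert old_empty_classes == set()

# New direction: declared classes that yielded no valid file at all.
empty_classes = set(class_to_idx.keys()) - available_classes
assert empty_classes == {"c"}

if empty_classes:
    msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}."
    print(msg)  # -> Found no valid file for the classes c.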