From be8978b6d91d9f0a7cda6c4b0dc8d511dac585fb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Feb 2021 07:44:30 +0100 Subject: [PATCH] remove old CIFAR tests and fake data generation --- test/fakedata_generation.py | 55 ------------------------------------- test/test_datasets.py | 34 +---------------------- 2 files changed, 1 insertion(+), 88 deletions(-) diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py index dac415df110..4249dedd54e 100644 --- a/test/fakedata_generation.py +++ b/test/fakedata_generation.py @@ -88,61 +88,6 @@ def _make_label_file(filename, num_images): yield tmp_dir -@contextlib.contextmanager -def cifar_root(version): - def _get_version_params(version): - if version == 'CIFAR10': - return { - 'base_folder': 'cifar-10-batches-py', - 'train_files': ['data_batch_{}'.format(batch) for batch in range(1, 6)], - 'test_file': 'test_batch', - 'target_key': 'labels', - 'meta_file': 'batches.meta', - 'classes_key': 'label_names', - } - elif version == 'CIFAR100': - return { - 'base_folder': 'cifar-100-python', - 'train_files': ['train'], - 'test_file': 'test', - 'target_key': 'fine_labels', - 'meta_file': 'meta', - 'classes_key': 'fine_label_names', - } - else: - raise ValueError - - def _make_pickled_file(obj, file): - with open(file, 'wb') as fh: - pickle.dump(obj, fh, 2) - - def _make_data_file(file, target_key): - obj = { - 'data': np.zeros((1, 32 * 32 * 3), dtype=np.uint8), - target_key: [0] - } - _make_pickled_file(obj, file) - - def _make_meta_file(file, classes_key): - obj = { - classes_key: ['fakedata'], - } - _make_pickled_file(obj, file) - - params = _get_version_params(version) - with get_tmp_dir() as root: - base_folder = os.path.join(root, params['base_folder']) - os.mkdir(base_folder) - - for file in list(params['train_files']) + [params['test_file']]: - _make_data_file(os.path.join(base_folder, file), params['target_key']) - - _make_meta_file(os.path.join(base_folder, params['meta_file']), - params['classes_key']) - - yield root - - @contextlib.contextmanager def imagenet_root(): import scipy.io as sio diff --git a/test/test_datasets.py b/test/test_datasets.py index 37651ae7614..47ff659d10d 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -10,7 +10,7 @@ import torchvision from torchvision.datasets import utils from common_utils import get_tmp_dir -from fakedata_generation import mnist_root, cifar_root, imagenet_root, \ +from fakedata_generation import mnist_root, imagenet_root, \ cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root import xml.etree.ElementTree as ET from urllib.request import Request, urlopen @@ -171,38 +171,6 @@ def test_widerface(self, mock_check_integrity): img, target = dataset[0] self.assertTrue(isinstance(img, PIL.Image.Image)) - @mock.patch('torchvision.datasets.cifar.check_integrity') - @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') - def test_cifar10(self, mock_ext_check, mock_int_check): - mock_ext_check.return_value = True - mock_int_check.return_value = True - with cifar_root('CIFAR10') as root: - dataset = torchvision.datasets.CIFAR10(root, train=True, download=True) - self.generic_classification_dataset_test(dataset, num_images=5) - img, target = dataset[0] - self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) - - dataset = torchvision.datasets.CIFAR10(root, train=False, download=True) - self.generic_classification_dataset_test(dataset) - img, target = dataset[0] - self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) - - @mock.patch('torchvision.datasets.cifar.check_integrity') - @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') - def test_cifar100(self, mock_ext_check, mock_int_check): - mock_ext_check.return_value = True - mock_int_check.return_value = True - with cifar_root('CIFAR100') as root: - dataset = torchvision.datasets.CIFAR100(root, train=True, download=True) - self.generic_classification_dataset_test(dataset) - img, target = dataset[0] - self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) - - dataset = torchvision.datasets.CIFAR100(root, train=False, download=True) - self.generic_classification_dataset_test(dataset) - img, target = dataset[0] - self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_cityscapes(self): with cityscapes_root() as root: