From 1dbc6f9b3f65977bbae90d6b629f05155aeb3780 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 10 Feb 2021 01:57:55 -0800
Subject: [PATCH] Add tests for the STL10 dataset (#3345)

Summary:
* extract some functionality from the places365 fakedata for common use
* add a common DatasetTestcase
* add fakedata generation and tests for STL10 (a usage sketch follows this list)
* lint
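
A minimal usage sketch of the new fixture (all names come from this patch; the
expected count is read from the data dict the fixture yields, so no literal
number is assumed). It mirrors what test_splits below does:

    from fakedata_generation import stl10_root
    import torchvision

    # stl10_root patches the STL10 class attributes (train_list, test_list,
    # tgz_md5) to point at the fake archive, so download=True resolves against
    # the fake data instead of hitting the network.
    with stl10_root() as (root, data):
        dataset = torchvision.datasets.STL10(root, split="train", download=True)
        assert len(dataset) == data["num_images_in_split"]["train"]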

Reviewed By: fmassa

Differential Revision: D26341418

fbshipit-source-id: 05f8a60c986c32f64339197ea377efc6c4d5b238

Co-authored-by: Francisco Massa
---
 test/fakedata_generation.py | 165 ++++++++++++++++++++++++++++--------
 test/test_datasets.py       |  84 +++++++++++++++++-
 2 files changed, 209 insertions(+), 40 deletions(-)

diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py
index 29a62d79a61..020b073febb 100644
--- a/test/fakedata_generation.py
+++ b/test/fakedata_generation.py
@@ -13,6 +13,48 @@
 import unittest.mock
 import hashlib
 from distutils import dir_util
+import re
+
+
+def mock_class_attribute(stack, target, new):
+    mock = unittest.mock.patch(target, new_callable=unittest.mock.PropertyMock, return_value=new)
+    stack.enter_context(mock)
+    return mock
+
+
+def compute_md5(file):
+    with open(file, "rb") as fh:
+        return hashlib.md5(fh.read()).hexdigest()
+
+
+def make_tar(root, name, *files, compression=None):
+    ext = ".tar"
+    mode = "w"
+    if compression is not None:
+        ext = f"{ext}.{compression}"
+        mode = f"{mode}:{compression}"
+
+    name = os.path.splitext(name)[0] + ext
+    archive = os.path.join(root, name)
+
+    with tarfile.open(archive, mode) as fh:
+        for file in files:
+            fh.add(os.path.join(root, file), arcname=file)
+
+    return name, compute_md5(archive)
+
+
+def clean_dir(root, *keep):
+    pattern = re.compile(f"({f')|('.join(keep)})")
+    for file_or_dir in os.listdir(root):
+        if pattern.search(file_or_dir):
+            continue
+
+        file_or_dir = os.path.join(root, file_or_dir)
+        if os.path.isfile(file_or_dir):
+            os.remove(file_or_dir)
+        else:
+            dir_util.remove_tree(file_or_dir)
 
 
 @contextlib.contextmanager
@@ -385,7 +427,7 @@ def ucf101_root():
 
 
 @contextlib.contextmanager
-def places365_root(split="train-standard", small=False, extract_images=True):
+def places365_root(split="train-standard", small=False):
     VARIANTS = {
         "train-standard": "standard",
         "train-challenge": "challenge",
@@ -425,15 +467,6 @@
     def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
         return f"{partial}.{attr}"
 
-    def mock_class_attribute(stack, attr, new):
-        mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
-        stack.enter_context(mock)
-        return mock
-
-    def compute_md5(file):
-        with open(file, "rb") as fh:
-            return hashlib.md5(fh.read()).hexdigest()
-
     def make_txt(root, name, seq):
         file = os.path.join(root, name)
         with open(file, "w") as fh:
@@ -451,37 +484,20 @@ def make_image(file, size):
         os.makedirs(os.path.dirname(file), exist_ok=True)
         PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)
 
-    def make_tar(root, name, *files, remove_files=True):
-        name = f"{os.path.splitext(name)[0]}.tar"
-        archive = os.path.join(root, name)
-
-        with tarfile.open(archive, "w") as fh:
-            for file in files:
-                fh.add(os.path.join(root, file), arcname=file)
-
-        if remove_files:
-            for file in [os.path.join(root, file) for file in files]:
-                if os.path.isdir(file):
-                    dir_util.remove_tree(file)
-                else:
-                    os.remove(file)
-
-        return name, compute_md5(archive)
-
     def make_devkit_archive(stack, root, split):
         archive = DEVKITS[split]
         files = []
 
         meta = make_categories_txt(root, CATEGORIES)
-        mock_class_attribute(stack, "_CATEGORIES_META", meta)
+        mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
         files.append(meta[0])
 
         meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
-        mock_class_attribute(stack, "_FILE_LIST_META", meta)
+        mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
         files.extend([item[0] for item in meta.values()])
 
         meta = {VARIANTS[split]: make_tar(root, archive, *files)}
-        mock_class_attribute(stack, "_DEVKIT_META", meta)
+        mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)
 
     def make_images_archive(stack, root, split, small):
         archive, folder_default, folder_renamed = IMAGES[(split, small)]
@@ -493,7 +509,7 @@ def make_images_archive(stack, root, split, small):
             make_image(os.path.join(root, folder_default, image), image_size)
 
         meta = {(split, small): make_tar(root, archive, folder_default)}
-        mock_class_attribute(stack, "_IMAGES_META", meta)
+        mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)
 
         return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]
 
@@ -501,12 +517,89 @@
     with contextlib.ExitStack() as stack, get_tmp_dir() as root:
         make_devkit_archive(stack, root, split)
         class_to_idx = dict(CATEGORIES_CONTENT)
        classes = list(class_to_idx.keys())
 
         data = {"class_to_idx": class_to_idx, "classes": classes}
+        data["imgs"] = make_images_archive(stack, root, split, small)
 
-        if extract_images:
-            data["imgs"] = make_images_archive(stack, root, split, small)
-        else:
-            stack.enter_context(unittest.mock.patch(mock_target("download_images")))
-            data["imgs"] = None
+        clean_dir(root, ".tar$")
 
         yield root, data
+
+
+@contextlib.contextmanager
+def stl10_root():
+    CLASS_NAMES = ("airplane", "bird")
+    ARCHIVE_NAME = "stl10_binary"
+    NUM_FOLDS = 10
+
+    def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
+        return f"{partial}.{attr}"
+
+    def make_binary_file(num_elements, root, name):
+        file = os.path.join(root, name)
+        np.zeros(num_elements, dtype=np.uint8).tofile(file)
+        return name, compute_md5(file)
+
+    def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
+        return make_binary_file(num_images * num_channels * height * width, root, name)
+
+    def make_label_file(num_images, root, name):
+        return make_binary_file(num_images, root, name)
+
+    def make_class_names_file(root, name="class_names.txt"):
+        with open(os.path.join(root, name), "w") as fh:
+            for name in CLASS_NAMES:
+                fh.write(f"{name}\n")
+
+    def make_fold_indices_file(root):
+        offset = 0
+        with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
+            for fold in range(NUM_FOLDS):
+                line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
+                fh.write(f"{line}\n")
+                offset += fold + 1
+
+        return tuple(range(1, NUM_FOLDS + 1))
+
+    def make_train_files(stack, root, num_unlabeled_images=1):
+        num_images_in_fold = make_fold_indices_file(root)
+        num_train_images = sum(num_images_in_fold)
+
+        train_list = [
+            list(make_image_file(num_train_images, root, "train_X.bin")),
+            list(make_label_file(num_train_images, root, "train_y.bin")),
+            list(make_image_file(num_unlabeled_images, root, "unlabeled_X.bin")),
+        ]
+        mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)
+
+        return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)
+
+    def make_test_files(stack, root, num_images=2):
+        test_list = [
+            list(make_image_file(num_images, root, "test_X.bin")),
+            list(make_label_file(num_images, root, "test_y.bin")),
+        ]
+        mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)
+
+        return dict(test=num_images)
+
+    def make_archive(stack, root, name):
+        archive, md5 = make_tar(root, name, name, compression="gz")
+        mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
+        return archive
+
+    with contextlib.ExitStack() as stack, get_tmp_dir() as root:
+        archive_folder = os.path.join(root, ARCHIVE_NAME)
+        os.mkdir(archive_folder)
+
+        num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
+        num_images_in_split.update(make_test_files(stack, archive_folder))
+
+        make_class_names_file(archive_folder)
+
+        archive = make_archive(stack, root, ARCHIVE_NAME)
+
+        dir_util.remove_tree(archive_folder)
+        data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)
+
+        yield root, data
diff --git a/test/test_datasets.py b/test/test_datasets.py
index 7184e755a67..ff8e0281e7c 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -1,3 +1,4 @@
+import contextlib
 import sys
 import os
 import unittest
@@ -7,9 +8,10 @@
 from PIL import Image
 from torch._utils_internal import get_file_path_2
 import torchvision
+from torchvision.datasets import utils
 from common_utils import get_tmp_dir
 from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
-    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
+    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
 import xml.etree.ElementTree as ET
 from urllib.request import Request, urlopen
 import itertools
@@ -28,7 +30,7 @@
     HAS_PYAV = False
 
 
-class Tester(unittest.TestCase):
+class DatasetTestcase(unittest.TestCase):
     def generic_classification_dataset_test(self, dataset, num_images=1):
         self.assertEqual(len(dataset), num_images)
         img, target = dataset[0]
@@ -41,6 +43,8 @@
         self.assertTrue(isinstance(img, PIL.Image.Image))
         self.assertTrue(isinstance(target, PIL.Image.Image))
 
+
+class Tester(DatasetTestcase):
     def test_imagefolder(self):
         # TODO: create the fake data on-the-fly
         FAKEDATA_DIR = get_file_path_2(
@@ -354,7 +358,7 @@ def test_places365_devkit_download(self):
     def test_places365_devkit_no_download(self):
         for split in ("train-standard", "train-challenge", "val"):
             with self.subTest(split=split):
-                with places365_root(split=split, extract_images=False) as places365:
+                with places365_root(split=split) as places365:
                     root, data = places365
 
                     with self.assertRaises(RuntimeError):
@@ -383,12 +387,84 @@ def test_places365_images_download_preexisting(self):
             torchvision.datasets.Places365(root, split=split, small=small, download=True)
 
     def test_places365_repr_smoke(self):
-        with places365_root(extract_images=False) as places365:
+        with places365_root() as places365:
             root, data = places365
 
             dataset = torchvision.datasets.Places365(root, download=True)
             self.assertIsInstance(repr(dataset), str)
 
 
+class STL10Tester(DatasetTestcase):
+    @contextlib.contextmanager
+    def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
+        with stl10_root() as (root, data):
+            if pre_extract:
+                utils.extract_archive(os.path.join(root, data["archive"]))
+            dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
+            yield dataset, data
+
+    def test_not_found(self):
+        with self.assertRaises(RuntimeError):
+            with self.mocked_dataset(download=False):
+                pass
+
+    def test_splits(self):
+        for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
+            with self.mocked_dataset(split=split) as (dataset, data):
+                num_images = sum([data["num_images_in_split"][part] for part in split.split("+")])
+                self.generic_classification_dataset_test(dataset, num_images=num_images)
+
+    def test_folds(self):
+        for fold in range(10):
+            with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
+                num_images = data["num_images_in_folds"][fold]
+                self.assertEqual(len(dataset), num_images)
+
+    def test_invalid_folds1(self):
+        with self.assertRaises(ValueError):
+            with self.mocked_dataset(folds=10):
+                pass
+
+    def test_invalid_folds2(self):
+        with self.assertRaises(ValueError):
+            with self.mocked_dataset(folds="0"):
+                pass
+
+    def test_transforms(self):
+        expected_image = "image"
+        expected_target = "target"
+
+        def transform(image):
+            return expected_image
+
+        def target_transform(target):
+            return expected_target
+
+        with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
+            actual_image, actual_target = dataset[0]
+
+        self.assertEqual(actual_image, expected_image)
+        self.assertEqual(actual_target, expected_target)
+
+    def test_unlabeled(self):
+        with self.mocked_dataset(split="unlabeled") as (dataset, _):
+            labels = [dataset[idx][1] for idx in range(len(dataset))]
+            self.assertTrue(all([label == -1 for label in labels]))
+
+    @unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
+    def test_download_preexisting(self, mock):
+        with self.mocked_dataset(pre_extract=True) as (dataset, data):
+            mock.assert_not_called()
+
+    def test_repr_smoke(self):
+        with self.mocked_dataset() as (dataset, _):
+            self.assertIsInstance(repr(dataset), str)
+
+
 if __name__ == '__main__':
     unittest.main()
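
A closing note on the fold bookkeeping the tests above rely on:
make_fold_indices_file writes one line per fold, with fold k (0-based) listing
k + 1 consecutive indices, so the fixture reports fold sizes (1, 2, ..., 10)
and a train split of 55 images in total. A standalone sketch of that
arithmetic (it assumes nothing beyond the code in this patch):

    NUM_FOLDS = 10

    # Fold k holds k + 1 images; test_folds checks exactly this via
    # data["num_images_in_folds"][fold].
    fold_sizes = tuple(range(1, NUM_FOLDS + 1))
    assert sum(fold_sizes) == 55  # 1 + 2 + ... + 10 fake train images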