Add tests for the STL10 dataset (#3345)
Summary:
* extract some functionality from places365 fakedata for common use

* add a common DatasetTestcase

* add fakedata generation and tests for STL10

* lint

Reviewed By: fmassa

Differential Revision: D26341418

fbshipit-source-id: 05f8a60c986c32f64339197ea377efc6c4d5b238

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
2 people authored and facebook-github-bot committed Feb 10, 2021
1 parent 9df3022 commit 1dbc6f9
Showing 2 changed files with 209 additions and 40 deletions.
165 changes: 129 additions & 36 deletions test/fakedata_generation.py
@@ -13,6 +13,48 @@
import unittest.mock
import hashlib
from distutils import dir_util
import re


def mock_class_attribute(stack, target, new):
mock = unittest.mock.patch(target, new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock


def compute_md5(file):
with open(file, "rb") as fh:
return hashlib.md5(fh.read()).hexdigest()


def make_tar(root, name, *files, compression=None):
ext = ".tar"
mode = "w"
if compression is not None:
ext = f"{ext}.{compression}"
mode = f"{mode}:{compression}"

name = os.path.splitext(name)[0] + ext
archive = os.path.join(root, name)

with tarfile.open(archive, mode) as fh:
for file in files:
fh.add(os.path.join(root, file), arcname=file)

return name, compute_md5(archive)


def clean_dir(root, *keep):
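    # Join the "keep" patterns into a single alternation; e.g. keep=(".tar$",)
    # yields the regex "(.tar$)". Entries matching it are preserved.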
pattern = re.compile(f"({f')|('.join(keep)})")
for file_or_dir in os.listdir(root):
if pattern.search(file_or_dir):
continue

file_or_dir = os.path.join(root, file_or_dir)
if os.path.isfile(file_or_dir):
os.remove(file_or_dir)
else:
dir_util.remove_tree(file_or_dir)


@contextlib.contextmanager
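
As a usage note, a minimal sketch of how these shared helpers compose (the "images" folder and file names here are hypothetical; get_tmp_dir comes from common_utils, as elsewhere in this file):

    # Sketch only: pack a fake folder into a .tar.gz and keep just the archive.
    with get_tmp_dir() as root:
        os.makedirs(os.path.join(root, "images"))
        with open(os.path.join(root, "images", "fake.txt"), "w") as fh:
            fh.write("fake")

        # make_tar returns the archive name and its MD5, ready for mocking.
        archive, md5 = make_tar(root, "images.tar", "images", compression="gz")

        # Remove everything in root except the archive itself.
        clean_dir(root, r"\.tar")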
@@ -385,7 +427,7 @@ def ucf101_root():


@contextlib.contextmanager
-def places365_root(split="train-standard", small=False, extract_images=True):
+def places365_root(split="train-standard", small=False):
VARIANTS = {
"train-standard": "standard",
"train-challenge": "challenge",
@@ -425,15 +467,6 @@ def places365_root(split="train-standard", small=False, extract_images=True):
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"

-    def mock_class_attribute(stack, attr, new):
-        mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
-        stack.enter_context(mock)
-        return mock
-
-    def compute_md5(file):
-        with open(file, "rb") as fh:
-            return hashlib.md5(fh.read()).hexdigest()

def make_txt(root, name, seq):
file = os.path.join(root, name)
with open(file, "w") as fh:
@@ -451,37 +484,20 @@ def make_image(file, size):
os.makedirs(os.path.dirname(file), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)

-    def make_tar(root, name, *files, remove_files=True):
-        name = f"{os.path.splitext(name)[0]}.tar"
-        archive = os.path.join(root, name)
-
-        with tarfile.open(archive, "w") as fh:
-            for file in files:
-                fh.add(os.path.join(root, file), arcname=file)
-
-        if remove_files:
-            for file in [os.path.join(root, file) for file in files]:
-                if os.path.isdir(file):
-                    dir_util.remove_tree(file)
-                else:
-                    os.remove(file)
-
-        return name, compute_md5(archive)

def make_devkit_archive(stack, root, split):
archive = DEVKITS[split]
files = []

meta = make_categories_txt(root, CATEGORIES)
mock_class_attribute(stack, "_CATEGORIES_META", meta)
mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
files.append(meta[0])

meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
mock_class_attribute(stack, "_FILE_LIST_META", meta)
mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
files.extend([item[0] for item in meta.values()])

meta = {VARIANTS[split]: make_tar(root, archive, *files)}
mock_class_attribute(stack, "_DEVKIT_META", meta)
mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)

def make_images_archive(stack, root, split, small):
archive, folder_default, folder_renamed = IMAGES[(split, small)]
@@ -493,20 +509,97 @@ def make_images_archive(stack, root, split, small):
make_image(os.path.join(root, folder_default, image), image_size)

meta = {(split, small): make_tar(root, archive, folder_default)}
mock_class_attribute(stack, "_IMAGES_META", meta)
mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)

return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]

with contextlib.ExitStack() as stack, get_tmp_dir() as root:
make_devkit_archive(stack, root, split)
class_to_idx = dict(CATEGORIES_CONTENT)
classes = list(class_to_idx.keys())

data = {"class_to_idx": class_to_idx, "classes": classes}
data["imgs"] = make_images_archive(stack, root, split, small)

if extract_images:
data["imgs"] = make_images_archive(stack, root, split, small)
else:
stack.enter_context(unittest.mock.patch(mock_target("download_images")))
data["imgs"] = None
clean_dir(root, ".tar$")

yield root, data


@contextlib.contextmanager
def stl10_root(_extracted=False):
CLASS_NAMES = ("airplane", "bird")
ARCHIVE_NAME = "stl10_binary"
NUM_FOLDS = 10

def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
return f"{partial}.{attr}"

def make_binary_file(num_elements, root, name):
file = os.path.join(root, name)
np.zeros(num_elements, dtype=np.uint8).tofile(file)
return name, compute_md5(file)

def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
return make_binary_file(num_images * num_channels * height * width, root, name)

def make_label_file(num_images, root, name):
return make_binary_file(num_images, root, name)

    def make_class_names_file(root, name="class_names.txt"):
        with open(os.path.join(root, name), "w") as fh:
            for class_name in CLASS_NAMES:
                fh.write(f"{class_name}\n")

def make_fold_indices_file(root):
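        # Fold k receives the next k + 1 consecutive indices, so the fold
        # sizes are 1, 2, ..., NUM_FOLDS and the labeled train split holds
        # 1 + 2 + ... + NUM_FOLDS images in total.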
offset = 0
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
for fold in range(NUM_FOLDS):
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
fh.write(f"{line}\n")
offset += fold + 1

return tuple(range(1, NUM_FOLDS + 1))

def make_train_files(stack, root, num_unlabeled_images=1):
num_images_in_fold = make_fold_indices_file(root)
num_train_images = sum(num_images_in_fold)

        train_list = [
            list(make_image_file(num_train_images, root, "train_X.bin")),
            list(make_label_file(num_train_images, root, "train_y.bin")),
            list(make_image_file(num_unlabeled_images, root, "unlabeled_X.bin")),
        ]
mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)

return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)

def make_test_files(stack, root, num_images=2):
test_list = [
list(make_image_file(num_images, root, "test_X.bin")),
list(make_label_file(num_images, root, "test_y.bin")),
]
mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)

return dict(test=num_images)

def make_archive(stack, root, name):
archive, md5 = make_tar(root, name, name, compression="gz")
mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
return archive

with contextlib.ExitStack() as stack, get_tmp_dir() as root:
archive_folder = os.path.join(root, ARCHIVE_NAME)
os.mkdir(archive_folder)

num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
num_images_in_split.update(make_test_files(stack, archive_folder))

make_class_names_file(archive_folder)

archive = make_archive(stack, root, ARCHIVE_NAME)

dir_util.remove_tree(archive_folder)
data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)

yield root, data
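
For reference, a minimal sketch of how the stl10_root fixture is consumed (mirroring the STL10Tester tests below; the mocked md5 lets download=True pass the integrity check against the fake archive):

    import torchvision
    from fakedata_generation import stl10_root

    with stl10_root() as (root, data):
        dataset = torchvision.datasets.STL10(root, split="train", download=True)
        assert len(dataset) == data["num_images_in_split"]["train"]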
84 changes: 80 additions & 4 deletions test/test_datasets.py
@@ -1,3 +1,4 @@
import contextlib
import sys
import os
import unittest
@@ -7,9 +8,10 @@
from PIL import Image
from torch._utils_internal import get_file_path_2
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
-    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
+    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
@@ -28,7 +30,7 @@
HAS_PYAV = False


-class Tester(unittest.TestCase):
+class DatasetTestcase(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
@@ -41,6 +43,8 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, PIL.Image.Image))


class Tester(DatasetTestcase):
def test_imagefolder(self):
# TODO: create the fake data on-the-fly
FAKEDATA_DIR = get_file_path_2(
@@ -354,7 +358,7 @@ def test_places365_devkit_download(self):
def test_places365_devkit_no_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
-                with places365_root(split=split, extract_images=False) as places365:
+                with places365_root(split=split) as places365:
root, data = places365

with self.assertRaises(RuntimeError):
@@ -383,12 +387,84 @@ def test_places365_images_download_preexisting(self):
torchvision.datasets.Places365(root, split=split, small=small, download=True)

def test_places365_repr_smoke(self):
-        with places365_root(extract_images=False) as places365:
+        with places365_root() as places365:
root, data = places365

dataset = torchvision.datasets.Places365(root, download=True)
self.assertIsInstance(repr(dataset), str)


class STL10Tester(DatasetTestcase):
@contextlib.contextmanager
def mocked_root(self):
with stl10_root() as (root, data):
yield root, data

@contextlib.contextmanager
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
with self.mocked_root() as (root, data):
if pre_extract:
utils.extract_archive(os.path.join(root, data["archive"]))
dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
yield dataset, data

def test_not_found(self):
with self.assertRaises(RuntimeError):
with self.mocked_dataset(download=False):
pass

def test_splits(self):
for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
with self.mocked_dataset(split=split) as (dataset, data):
num_images = sum([data["num_images_in_split"][part] for part in split.split("+")])
self.generic_classification_dataset_test(dataset, num_images=num_images)

def test_folds(self):
for fold in range(10):
with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
num_images = data["num_images_in_folds"][fold]
self.assertEqual(len(dataset), num_images)

def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds=10):
pass

def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds="0"):
pass

def test_transforms(self):
expected_image = "image"
expected_target = "target"

def transform(image):
return expected_image

def target_transform(target):
return expected_target

with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
actual_image, actual_target = dataset[0]

self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)

def test_unlabeled(self):
with self.mocked_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all([label == -1 for label in labels]))

@unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
def test_download_preexisting(self, mock):
with self.mocked_dataset(pre_extract=True) as (dataset, data):
mock.assert_not_called()

def test_repr_smoke(self):
with self.mocked_dataset() as (dataset, _):
self.assertIsInstance(repr(dataset), str)


if __name__ == '__main__':
unittest.main()
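
As a usage note, other dataset suites can now reuse the shared base class. A hypothetical sketch (fake_dataset_root and FakeDataset are placeholders, not part of this commit):

    class FakeDatasetTester(DatasetTestcase):
        def test_train_split(self):
            with fake_dataset_root() as (root, data):
                dataset = FakeDataset(root, split="train")
                self.generic_classification_dataset_test(dataset, num_images=data["num_images"])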
