Skip to content

Commit

Permalink
Revert "Revert "Ported places365 dataset's tests to the new test fram…
Browse files Browse the repository at this point in the history
…ework (#3705)" (#3718)" (#3731)

This reverts commit d419558.
  • Loading branch information
prabhat00155 authored Apr 26, 2021
1 parent 03f94a6 commit ae63bd0
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 201 deletions.
100 changes: 0 additions & 100 deletions test/fakedata_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,103 +208,3 @@ def _make_annotations_archive(root):
_make_annotations_archive(root_base)

yield root


@contextlib.contextmanager
def places365_root(split="train-standard", small=False):
    """Build a temporary directory with fake Places365 archives for tests.

    A devkit (categories file plus file list) and an image folder for the
    requested configuration are generated and packed into archives, and the
    matching ``Places365`` metadata class attributes are mocked with the
    values describing those archives.

    Args:
        split: One of ``"train-standard"``, ``"train-challenge"`` or ``"val"``.
        small: If true, generate 256x256 images and use the ``*_256*`` archive
            and folder names; otherwise generate larger images.

    Yields:
        A ``(root, data)`` tuple. ``root`` is the temporary directory and
        ``data`` is a dict with the keys ``"class_to_idx"``, ``"classes"``
        and ``"imgs"`` holding the values the dataset is expected to expose.
    """
    VARIANTS = {
        "train-standard": "standard",
        "train-challenge": "challenge",
        "val": "standard",
    }
    # {split: file}
    DEVKITS = {
        "train-standard": "filelist_places365-standard.tar",
        "train-challenge": "filelist_places365-challenge.tar",
        "val": "filelist_places365-standard.tar",
    }
    CATEGORIES = "categories_places365.txt"
    # {split: file}
    FILE_LISTS = {
        "train-standard": "places365_train_standard.txt",
        "train-challenge": "places365_train_challenge.txt",
        # The validation split gets its own file list name instead of reusing
        # the train-standard one.
        "val": "places365_val.txt",
    }
    # {(split, small): (archive, folder_default, folder_renamed)}
    IMAGES = {
        ("train-standard", False): ("train_large_places365standard.tar", "data_large", "data_large_standard"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "data_large", "data_large_challenge"),
        ("val", False): ("val_large.tar", "val_large", "val_large"),
        ("train-standard", True): ("train_256_places365standard.tar", "data_256", "data_256_standard"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "data_256", "data_256_challenge"),
        ("val", True): ("val_256.tar", "val_256", "val_256"),
    }

    # (class, idx)
    CATEGORIES_CONTENT = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
    # (file, idx)
    FILE_LIST_CONTENT = (
        ("Places365_val_00000001.png", 0),
        *((f"{category}/Places365_train_00000001.png", idx) for category, idx in CATEGORIES_CONTENT),
    )

    def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
        """Return the dotted path of ``attr`` below ``partial``."""
        return f"{partial}.{attr}"

    def make_txt(root, name, seq):
        """Write one ``"<string> <idx>"`` line per item and return ``(name, md5)``."""
        file = os.path.join(root, name)
        with open(file, "w") as fh:
            for string, idx in seq:
                fh.write(f"{string} {idx}\n")
        return name, compute_md5(file)

    def make_categories_txt(root, name):
        """Write the fake categories file."""
        return make_txt(root, name, CATEGORIES_CONTENT)

    def make_file_list_txt(root, name):
        """Write the fake file list."""
        return make_txt(root, name, FILE_LIST_CONTENT)

    def make_image(file, size):
        """Save an all-zero RGB image of ``size`` to ``file``, creating parent folders."""
        os.makedirs(os.path.dirname(file), exist_ok=True)
        PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)

    def make_devkit_archive(stack, root, split):
        """Create the devkit files and archive and mock the related metadata."""
        archive = DEVKITS[split]
        files = []

        meta = make_categories_txt(root, CATEGORIES)
        mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
        files.append(meta[0])

        meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
        mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
        files.extend([item[0] for item in meta.values()])

        # NOTE(review): make_tar is defined elsewhere; presumably it returns
        # the same (name, md5) shape as make_txt -- confirm.
        meta = {VARIANTS[split]: make_tar(root, archive, *files)}
        mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)

    def make_images_archive(stack, root, split, small):
        """Create the images and their archive and return the expected ``(path, idx)`` list."""
        archive, folder_default, folder_renamed = IMAGES[(split, small)]

        image_size = (256, 256) if small else (512, random.randint(512, 1024))
        files, idcs = zip(*FILE_LIST_CONTENT)
        images = [file.lstrip("/").replace("/", os.sep) for file in files]
        for image in images:
            make_image(os.path.join(root, folder_default, image), image_size)

        meta = {(split, small): make_tar(root, archive, folder_default)}
        mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)

        # Images are written below folder_default but reported below
        # folder_renamed. NOTE(review): presumably the dataset renames the
        # folder after extraction -- confirm against Places365.
        return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]

    with contextlib.ExitStack() as stack, get_tmp_dir() as root:
        make_devkit_archive(stack, root, split)
        class_to_idx = dict(CATEGORIES_CONTENT)
        classes = list(class_to_idx)

        data = {"class_to_idx": class_to_idx, "classes": classes}
        data["imgs"] = make_images_archive(stack, root, split, small)

        # NOTE(review): clean_dir is defined elsewhere; presumably it prunes
        # root so that only entries matching ".tar$" remain -- confirm.
        clean_dir(root, ".tar$")

        yield root, data
192 changes: 91 additions & 101 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import places365_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
Expand Down Expand Up @@ -41,106 +40,6 @@
HAS_PYAV = False


class DatasetTestcase(unittest.TestCase):
    """Base test case providing shared sanity checks for image datasets."""

    def generic_classification_dataset_test(self, dataset, num_images=1):
        """Check the dataset size and that sample 0 is a ``(PIL image, int)`` pair.

        Args:
            dataset: Object supporting ``len()`` and integer indexing.
            num_images: Expected number of samples.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        # assertIsInstance reports the offending object and type on failure.
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        """Check the dataset size and that sample 0 is a pair of PIL images.

        Args:
            dataset: Object supporting ``len()`` and integer indexing.
            num_images: Expected number of samples.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, PIL.Image.Image)


class Tester(DatasetTestcase):
    """Tests for ``torchvision.datasets.Places365`` using the ``places365_root`` fixture."""

    def test_places365(self):
        """Run the generic classification check for every ``(split, small)`` pair."""
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            # Report each configuration separately instead of stopping at the
            # first failing one, like the other parametrized tests here.
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as (root, data):
                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                    self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))

    def test_places365_transforms(self):
        """Check that ``transform`` and ``target_transform`` are applied to a sample."""
        expected_image = "image"
        expected_target = "target"

        def transform(image):
            return expected_image

        def target_transform(target):
            return expected_target

        with places365_root() as (root, _):
            dataset = torchvision.datasets.Places365(
                root, transform=transform, target_transform=target_transform, download=True
            )
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)

    def test_places365_devkit_download(self):
        """Check ``classes``, ``class_to_idx`` and ``imgs`` against the fixture data."""
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split):
                with places365_root(split=split) as (root, data):
                    dataset = torchvision.datasets.Places365(root, split=split, download=True)

                    with self.subTest("classes"):
                        self.assertSequenceEqual(dataset.classes, data["classes"])

                    with self.subTest("class_to_idx"):
                        self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])

                    with self.subTest("imgs"):
                        self.assertSequenceEqual(dataset.imgs, data["imgs"])

    def test_places365_devkit_no_download(self):
        """Check that constructing with ``download=False`` raises ``RuntimeError``."""
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split):
                with places365_root(split=split) as (root, _):
                    with self.assertRaises(RuntimeError):
                        torchvision.datasets.Places365(root, split=split, download=False)

    def test_places365_images_download(self):
        """Check that every path listed in ``dataset.imgs`` exists on disk."""
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as (root, _):
                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)

                    # A unittest assertion (not a bare assert) still runs
                    # under -O and names the paths that are missing.
                    missing = [item[0] for item in dataset.imgs if not os.path.exists(item[0])]
                    self.assertEqual(missing, [])

    def test_places365_images_download_preexisting(self):
        """Check that a pre-existing image folder makes construction raise ``RuntimeError``."""
        split = "train-standard"
        small = False
        images_dir = "data_large_standard"

        with places365_root(split=split, small=small) as (root, _):
            os.mkdir(os.path.join(root, images_dir))

            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, split=split, small=small, download=True)

    def test_places365_repr_smoke(self):
        """Check that ``repr(dataset)`` returns a string."""
        with places365_root() as (root, _):
            dataset = torchvision.datasets.Places365(root, download=True)
            self.assertIsInstance(repr(dataset), str)


class STL10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.STL10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
Expand Down Expand Up @@ -1763,5 +1662,96 @@ def inject_fake_data(self, tmpdir, config):
return num_examples


class Places365TestCase(datasets_utils.ImageDatasetTestCase):
    """Places365 tests running on fake data written by ``inject_fake_data``."""

    DATASET_CLASS = datasets.Places365
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train-standard", "train-challenge", "val"),
        small=(False, True),
    )
    _CATEGORIES = "categories_places365.txt"
    # {split: file}
    _FILE_LISTS = {
        "train-standard": "places365_train_standard.txt",
        "train-challenge": "places365_train_challenge.txt",
        "val": "places365_val.txt",
    }
    # {(split, small): folder_name}
    _IMAGES = {
        ("train-standard", False): "data_large_standard",
        ("train-challenge", False): "data_large_challenge",
        ("val", False): "val_large",
        ("train-standard", True): "data_256_standard",
        ("train-challenge", True): "data_256_challenge",
        ("val", True): "val_256",
    }
    # (class, idx)
    _CATEGORIES_CONTENT = (
        ("/a/airfield", 0),
        ("/a/apartment_building/outdoor", 8),
        ("/b/badlands", 30),
    )
    # (file, idx): one validation entry followed by one train entry per category
    _FILE_LIST_CONTENT = (("Places365_val_00000001.png", 0),) + tuple(
        (f"{category}/Places365_train_00000001.png", idx) for category, idx in _CATEGORIES_CONTENT
    )

    @staticmethod
    def _make_txt(root, name, seq):
        """Write one ``"<text> <idx>"`` line per item of ``seq`` to ``root/name``."""
        target = os.path.join(root, name)
        with open(target, "w") as fh:
            fh.writelines(f"{text} {idx}\n" for text, idx in seq)

    @staticmethod
    def _make_categories_txt(root, name):
        """Write the fake categories file."""
        Places365TestCase._make_txt(root, name, Places365TestCase._CATEGORIES_CONTENT)

    @staticmethod
    def _make_file_list_txt(root, name):
        """Write the fake file list."""
        Places365TestCase._make_txt(root, name, Places365TestCase._FILE_LIST_CONTENT)

    @staticmethod
    def _make_image(file_name, size):
        """Save an all-zero RGB image of ``size`` to ``file_name``, creating parent folders."""
        parent = os.path.dirname(file_name)
        os.makedirs(parent, exist_ok=True)
        pixels = np.zeros((*size, 3), dtype=np.uint8)
        PIL.Image.fromarray(pixels).save(file_name)

    @staticmethod
    def _make_devkit_archive(root, split):
        """Write the categories file and the file list of ``split`` into ``root``."""
        Places365TestCase._make_categories_txt(root, Places365TestCase._CATEGORIES)
        Places365TestCase._make_file_list_txt(root, Places365TestCase._FILE_LISTS[split])

    @staticmethod
    def _make_images_archive(root, split, small):
        """Create the fake images and return their ``(path, idx)`` pairs."""
        folder = os.path.join(root, Places365TestCase._IMAGES[(split, small)])
        if small:
            image_size = (256, 256)
        else:
            image_size = (512, random.randint(512, 1024))

        samples = []
        for file, idx in Places365TestCase._FILE_LIST_CONTENT:
            # File list entries may start with "/" and always use "/" separators.
            path = os.path.join(folder, file.lstrip("/").replace("/", os.sep))
            Places365TestCase._make_image(path, image_size)
            samples.append((path, idx))
        return samples

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir`` for ``config`` and return the number of images created."""
        split = config["split"]
        self._make_devkit_archive(tmpdir, split)
        samples = self._make_images_archive(tmpdir, split, config["small"])
        return len(samples)

    def test_classes(self):
        """Check that ``dataset.classes`` lists the fake category names."""
        expected = [name for name, _ in self._CATEGORIES_CONTENT]
        with self.create_dataset() as (dataset, _):
            self.assertEqual(dataset.classes, expected)

    def test_class_to_idx(self):
        """Check that ``dataset.class_to_idx`` maps the fake categories to their indices."""
        expected = dict(self._CATEGORIES_CONTENT)
        with self.create_dataset() as (dataset, _):
            self.assertEqual(dataset.class_to_idx, expected)

    def test_images_download_preexisting(self):
        """Check that ``download=True`` raises ``RuntimeError`` with the fake data in place."""
        with self.assertRaises(RuntimeError), self.create_dataset({"download": True}):
            pass


# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit ae63bd0

Please sign in to comment.