Skip to content

Commit

Permalink
add tests for (Dataset|Image)Folder (pytorch#3477)
Browse files Browse the repository at this point in the history
* add tests for (Dataset|Image)Folder

* lint

* remove old tests

* cleanup

* more cleanup

* adapt tests

* fix make_dataset

* remove powerset

* readd import
  • Loading branch information
pmeier authored Mar 30, 2021
1 parent 7cc941f commit 20a771e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 62 deletions.
151 changes: 90 additions & 61 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,67 +57,6 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):


class Tester(DatasetTestcase):
def test_imagefolder(self):
# TODO: create the fake data on-the-fly
FAKEDATA_DIR = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
classes = sorted(['a', 'b'])
class_a_image_files = [
os.path.join(root, 'a', file) for file in ('a1.png', 'a2.png', 'a3.png')
]
class_b_image_files = [
os.path.join(root, 'b', file) for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')
]
dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

# test if all classes are present
self.assertEqual(classes, sorted(dataset.classes))

# test if combination of classes and class_to_index functions correctly
for cls in classes:
self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

# test if all images were detected correctly
class_a_idx = dataset.class_to_idx['a']
class_b_idx = dataset.class_to_idx['b']
imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files]
imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files]
imgs = sorted(imgs_a + imgs_b)
self.assertEqual(imgs, dataset.imgs)

# test if the datasets outputs all images correctly
outputs = sorted([dataset[i] for i in range(len(dataset))])
self.assertEqual(imgs, outputs)

# redo all tests with specified valid image files
dataset = torchvision.datasets.ImageFolder(
root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)
self.assertEqual(classes, sorted(dataset.classes))

class_a_idx = dataset.class_to_idx['a']
class_b_idx = dataset.class_to_idx['b']
imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files
if '3' in img_file]
imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files
if '3' in img_file]
imgs = sorted(imgs_a + imgs_b)
self.assertEqual(imgs, dataset.imgs)

outputs = sorted([dataset[i] for i in range(len(dataset))])
self.assertEqual(imgs, outputs)

def test_imagefolder_empty(self):
with get_tmp_dir() as root:
with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder(root, loader=lambda x: x)

with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder(
root, loader=lambda x: x, is_valid_file=lambda x: False
)

@mock.patch('torchvision.datasets.SVHN._check_integrity')
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
def test_svhn(self, mock_check):
Expand Down Expand Up @@ -1673,5 +1612,95 @@ def test_num_examples_test50k(self):
self.assertEqual(len(dataset), info["num_examples"] - 10000)


class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder

# The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader
# that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method.
FEATURE_TYPES = (str, int)

_IMAGE_EXTENSIONS = ("jpg", "png")
_VIDEO_EXTENSIONS = ("avi", "mp4")
_EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS)

# DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of both is required.
# We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the
# 'test_is_valid_file()' method.
DEFAULT_CONFIG = dict(extensions=_EXTENSIONS)
ADDITIONAL_CONFIGS = (
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
dict(extensions=_IMAGE_EXTENSIONS),
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
dict(extensions=_VIDEO_EXTENSIONS),
)

def dataset_args(self, tmpdir, config):
return tmpdir, lambda x: x

def inject_fake_data(self, tmpdir, config):
extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"])

num_examples_total = 0
classes = []
for ext, cls in zip(self._EXTENSIONS, string.ascii_letters):
if ext not in extensions:
continue

create_example_folder = (
datasets_utils.create_image_folder
if ext in self._IMAGE_EXTENSIONS
else datasets_utils.create_video_folder
)

num_examples = torch.randint(1, 3, size=()).item()
create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples)

num_examples_total += num_examples
classes.append(cls)

return dict(num_examples=num_examples_total, classes=classes)

def _file_name_fn(self, cls, ext, idx):
return f"{cls}_{idx}.{ext}"

def _is_valid_file_to_extensions(self, is_valid_file):
return {ext for ext in self._EXTENSIONS if is_valid_file(f"foo.{ext}")}

@datasets_utils.test_all_configs
def test_is_valid_file(self, config):
extensions = config.pop("extensions")
# We need to explicitly pass extensions=None here or otherwise it would be filled by the value from the
# DEFAULT_CONFIG.
with self.create_dataset(
config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions
) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"])

@datasets_utils.test_all_configs
def test_classes(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"])


class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ImageFolder

def inject_fake_data(self, tmpdir, config):
num_examples_total = 0
classes = ("a", "b")
for cls in classes:
num_examples = torch.randint(1, 3, size=()).item()
num_examples_total += num_examples

datasets_utils.create_image_folder(tmpdir, cls, lambda idx: f"{cls}_{idx}.png", num_examples)

return dict(num_examples=num_examples_total, classes=classes)

@datasets_utils.test_all_configs
def test_classes(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"])


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def is_valid_file(x: str) -> bool:
if target_class not in available_classes:
available_classes.add(target_class)

empty_classes = available_classes - set(class_to_idx.keys())
empty_classes = set(class_to_idx.keys()) - available_classes
if empty_classes:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None:
Expand Down

0 comments on commit 20a771e

Please sign in to comment.