Commit

Properly fix dataset test that passes by accident (#3434)
* make UsageError an Exception rather than RuntimeError

* separate fake data injection and dataset args handling

* adapt tests for Coco

* fix Coco implementation

* add documentation

* fix VideoDatasetTestCase

* adapt UCF101 tests

* cleanup

* allow FileNotFoundError for test without fake data

* Revert "fix Coco implementation"

This reverts commit e2b6938.

* lint

* fix UCF101 tests
pmeier authored Feb 25, 2021
1 parent fc33c46 commit 13c4470
Showing 2 changed files with 82 additions and 60 deletions.
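
For orientation before the diff: the commit splits the test hooks so that dataset_args supplies the positional constructor arguments while inject_fake_data only creates fake files and reports how many examples they form; making UsageError a plain Exception (no longer a RuntimeError subclass) appears to be what stops a broken test setup from accidentally satisfying the assertRaises check in the not-found test. A minimal sketch of the new contract, assuming a hypothetical MNISTLikeTestCase that lives next to test/datasets_utils.py (it is not part of this patch):

import torchvision

import datasets_utils


class MNISTLikeTestCase(datasets_utils.ImageDatasetTestCase):
    """Hypothetical test case for illustration only; not part of this patch."""

    DATASET_CLASS = torchvision.datasets.MNIST

    def inject_fake_data(self, tmpdir, config):
        # dataset_args is not overridden: MNIST needs only ``root``, so the
        # default return value ``(tmpdir,)`` already matches the constructor.
        # This hook only creates fake files and reports how many examples
        # exist, either as a plain int or as a dict containing "num_examples".
        num_examples = 4
        # ... write fake MNIST files under tmpdir here ...
        return num_examples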
98 changes: 50 additions & 48 deletions test/datasets_utils.py
@@ -35,7 +35,7 @@
]


class UsageError(RuntimeError):
class UsageError(Exception):
"""Should be raised in case an error happens in the setup rather than the test."""


@@ -165,7 +165,8 @@ class DatasetTestCase(unittest.TestCase):
Without further configuration, the testcase will test if
1. the dataset raises a ``RuntimeError`` if the data files are not found,
1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found or
corrupted,
2. the dataset inherits from `torchvision.datasets.VisionDataset`,
3. the dataset can be turned into a string,
4. the feature types of a returned example matches ``FEATURE_TYPES``,
@@ -228,9 +229,25 @@ def test_baz(self):
"download_and_extract_archive",
}

def inject_fake_data(
self, tmpdir: str, config: Dict[str, Any]
) -> Union[int, Dict[str, Any], Tuple[Sequence[Any], Union[int, Dict[str, Any]]]]:
def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
"""Define positional arguments passed to the dataset.
.. note::
The default behavior is only valid if the dataset to be tested has ``root`` as the only required parameter.
Otherwise you need to overwrite this method.
Args:
tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
to be created and in turn also for the fake data injected here.
config (Dict[str, Any]): Configuration that will be used to create the dataset.
Returns:
(Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
"""
return (tmpdir,)

def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]:
"""Inject fake data for dataset into a temporary directory.
Args:
@@ -240,15 +257,9 @@ def inject_fake_data(
Needs to return one of the following:
1. (int): Number of examples in the dataset to be created,
1. (int): Number of examples in the dataset to be created, or
2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
``"num_examples"`` that corresponds to the number of examples in the dataset to be created, or
3. (Tuple[Sequence[Any], Union[int, Dict[str, Any]]]): Additional required parameters that are passed to
the dataset constructor. The second element corresponds to cases 1. and 2.
If no ``args`` is returned (case 1. and 2.), the ``tmp_dir`` is passed as first parameter to the dataset
constructor. In most cases this corresponds to ``root``. If the dataset has more parameters without default
values you need to explicitly pass them as explained in case 3.
``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
"""
raise NotImplementedError("You need to provide fake data in order for the tests to run.")

@@ -287,33 +298,30 @@ def create_dataset(
disable_download_extract = inject_fake_data

with get_tmp_dir() as tmpdir:
output = self.inject_fake_data(tmpdir, config) if inject_fake_data else None
if output is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)

if isinstance(output, collections.abc.Sequence) and len(output) == 2:
args, info = output
else:
args = (tmpdir,)
info = output
args = self.dataset_args(tmpdir, config)

if isinstance(info, int):
info = dict(num_examples=info)
elif isinstance(info, dict):
if "num_examples" not in info:
if inject_fake_data:
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
else:
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an integer "
f"indicating the number of examples for the current configuration or a dictionary with the the "
f"same content. Got {type(info)} instead."
)
info = None

cm = self._disable_download_extract if disable_download_extract else nullcontext
with cm(special_kwargs), disable_console_output():
@@ -395,8 +403,8 @@ def _disable_download_extract(self, special_kwargs):
if inject_download_kwarg:
del special_kwargs["download"]

def test_not_found(self):
with self.assertRaises(RuntimeError):
def test_not_found_or_corrupted(self):
with self.assertRaises((FileNotFoundError, RuntimeError)):
with self.create_dataset(inject_fake_data=False):
pass

@@ -511,26 +519,20 @@ class VideoDatasetTestCase(DatasetTestCase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.inject_fake_data = self._set_default_frames_per_clip(self.inject_fake_data)
self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)

def _set_default_frames_per_clip(self, inject_fake_data):
argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
args_without_default = argspec.args[1:-len(argspec.defaults)]
frames_per_clip_last = args_without_default[-1] == "frames_per_clip"
only_root_and_frames_per_clip = (len(args_without_default) == 2) and frames_per_clip_last

@functools.wraps(inject_fake_data)
def wrapper(tmpdir, config):
output = inject_fake_data(tmpdir, config)
if isinstance(output, collections.abc.Sequence) and len(output) == 2:
args, info = output
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
args = (*args, self.DEFAULT_FRAMES_PER_CLIP)
return args, info
elif isinstance(output, (int, dict)) and only_root_and_frames_per_clip:
return (tmpdir, self.DEFAULT_FRAMES_PER_CLIP)
else:
return output
args = inject_fake_data(tmpdir, config)
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
args = (*args, self.DEFAULT_FRAMES_PER_CLIP)

return args

return wrapper

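To illustrate the VideoDatasetTestCase change just above: the frames_per_clip default is now injected by wrapping dataset_args rather than inject_fake_data. A minimal sketch, assuming a hypothetical SomeVideoTestCase (inject_fake_data would still have to be implemented for it to run); it mirrors the UCF101TestCase change in test/test_datasets.py below:

import pathlib

import torchvision

import datasets_utils


class SomeVideoTestCase(datasets_utils.VideoDatasetTestCase):
    """Hypothetical sketch; not part of this patch."""

    DATASET_CLASS = torchvision.datasets.UCF101

    def dataset_args(self, tmpdir, config):
        # UCF101 requires (root, annotation_path, frames_per_clip). Returning
        # one argument fewer than that lets the wrapper installed in
        # VideoDatasetTestCase.__init__ append DEFAULT_FRAMES_PER_CLIP.
        tmpdir = pathlib.Path(tmpdir)
        return tmpdir / "videos", tmpdir / "annotations"
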
44 changes: 32 additions & 12 deletions test/test_datasets.py
@@ -824,33 +824,44 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):

REQUIRED_PACKAGES = ("pycocotools",)

_IMAGE_FOLDER = "images"
_ANNOTATIONS_FOLDER = "annotations"
_ANNOTATIONS_FILE = "annotations.json"

def dataset_args(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)
root = tmpdir / self._IMAGE_FOLDER
annotation_file = tmpdir / self._ANNOTATIONS_FOLDER / self._ANNOTATIONS_FILE
return root, annotation_file

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)

num_images = 3
num_annotations_per_image = 2

image_folder = tmpdir / "images"
files = datasets_utils.create_image_folder(
tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
tmpdir, name=self._IMAGE_FOLDER, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
)
file_names = [file.relative_to(image_folder) for file in files]
file_names = [file.relative_to(tmpdir / self._IMAGE_FOLDER) for file in files]

annotation_folder = tmpdir / "annotations"
annotation_folder = tmpdir / self._ANNOTATIONS_FOLDER
os.makedirs(annotation_folder)
annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image)
info = self._create_annotation_file(
annotation_folder, self._ANNOTATIONS_FILE, file_names, num_annotations_per_image
)

info["num_examples"] = num_images
return (str(image_folder), str(annotation_file)), info
return info

def _create_annotation_file(self, root, file_names, num_annotations_per_image):
def _create_annotation_file(self, root, name, file_names, num_annotations_per_image):
image_ids = [int(file_name.stem) for file_name in file_names]
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]

annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
self._create_json(root, name, dict(images=images, annotations=annotations))

content = dict(images=images, annotations=annotations)
return self._create_json(root, "annotations.json", content), info
return info

def _create_annotations(self, image_ids, num_annotations_per_image):
annotations = datasets_utils.combinations_grid(
@@ -888,18 +899,27 @@ class UCF101TestCase(datasets_utils.VideoDatasetTestCase):

CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))

_VIDEO_FOLDER = "videos"
_ANNOTATIONS_FOLDER = "annotations"

def dataset_args(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)
root = tmpdir / self._VIDEO_FOLDER
annotation_path = tmpdir / self._ANNOTATIONS_FOLDER
return root, annotation_path

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)

video_folder = tmpdir / "videos"
video_folder = tmpdir / self._VIDEO_FOLDER
os.makedirs(video_folder)
video_files = self._create_videos(video_folder)

annotations_folder = annotations_folder = tmpdir / "annotations"
annotations_folder = tmpdir / self._ANNOTATIONS_FOLDER
os.makedirs(annotations_folder)
num_examples = self._create_annotation_files(annotations_folder, video_files, config["fold"], config["train"])

return (str(video_folder), str(annotations_folder)), num_examples
return num_examples

def _create_videos(self, root, num_examples_per_class=3):
def file_name_fn(cls, idx, clips_per_group=2):
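
Finally, a sketch of how a test method would consume these hooks on any DatasetTestCase subclass, assuming create_dataset yields the dataset together with the normalized info dict (that part of the file is not shown in this diff): whatever inject_fake_data returned has been converted into a dict, so info["num_examples"] is always available.

    def test_length_matches_fake_data(self):
        # Sketch only: the constructor arguments come from dataset_args, the
        # files from inject_fake_data, and the value returned by that hook has
        # been normalized into a dict by create_dataset.
        with self.create_dataset() as (dataset, info):
            self.assertEqual(len(dataset), info["num_examples"])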
