diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 5789c8620dc..337a7382366 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -35,7 +35,7 @@ ] -class UsageError(RuntimeError): +class UsageError(Exception): """Should be raised in case an error happens in the setup rather than the test.""" @@ -165,7 +165,8 @@ class DatasetTestCase(unittest.TestCase): Without further configuration, the testcase will test if - 1. the dataset raises a ``RuntimeError`` if the data files are not found, + 1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found or + corrupted, 2. the dataset inherits from `torchvision.datasets.VisionDataset`, 3. the dataset can be turned into a string, 4. the feature types of a returned example matches ``FEATURE_TYPES``, @@ -228,9 +229,25 @@ def test_baz(self): "download_and_extract_archive", } - def inject_fake_data( - self, tmpdir: str, config: Dict[str, Any] - ) -> Union[int, Dict[str, Any], Tuple[Sequence[Any], Union[int, Dict[str, Any]]]]: + def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]: + """Define positional arguments passed to the dataset. + + .. note:: + + The default behavior is only valid if the dataset to be tested has ``root`` as the only required parameter. + Otherwise you need to override this method. + + Args: + tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset + to be created and in turn also for the fake data injected here. + config (Dict[str, Any]): Configuration that will be used to create the dataset. + + Returns: + (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets. + """ + return (tmpdir,) + + def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]: """Inject fake data for dataset into a temporary directory. Args: @@ -240,15 +257,9 @@ def inject_fake_data( Needs to return one of the following: - 1. 
(int): Number of examples in the dataset to be created, + 1. (int): Number of examples in the dataset to be created, or 2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field - ``"num_examples"`` that corresponds to the number of examples in the dataset to be created, or - 3. (Tuple[Sequence[Any], Union[int, Dict[str, Any]]]): Additional required parameters that are passed to - the dataset constructor. The second element corresponds to cases 1. and 2. - - If no ``args`` is returned (case 1. and 2.), the ``tmp_dir`` is passed as first parameter to the dataset - constructor. In most cases this corresponds to ``root``. If the dataset has more parameters without default - values you need to explicitly pass them as explained in case 3. + ``"num_examples"`` that corresponds to the number of examples in the dataset to be created. """ raise NotImplementedError("You need to provide fake data in order for the tests to run.") @@ -287,33 +298,30 @@ def create_dataset( disable_download_extract = inject_fake_data with get_tmp_dir() as tmpdir: - output = self.inject_fake_data(tmpdir, config) if inject_fake_data else None - if output is None: - raise UsageError( - "The method 'inject_fake_data' needs to return at least an integer indicating the number of " - "examples for the current configuration." - ) - - if isinstance(output, collections.abc.Sequence) and len(output) == 2: - args, info = output - else: - args = (tmpdir,) - info = output + args = self.dataset_args(tmpdir, config) - if isinstance(info, int): - info = dict(num_examples=info) - elif isinstance(info, dict): - if "num_examples" not in info: + if inject_fake_data: + info = self.inject_fake_data(tmpdir, config) + if info is None: + raise UsageError( + "The method 'inject_fake_data' needs to return at least an integer indicating the number of " + "examples for the current configuration." 
+ ) + elif isinstance(info, int): + info = dict(num_examples=info) + elif not isinstance(info, dict): + raise UsageError( + f"The additional information returned by the method 'inject_fake_data' must be either an " + f"integer indicating the number of examples for the current configuration or a dictionary with " + f"the same content. Got {type(info)} instead." + ) + elif "num_examples" not in info: raise UsageError( "The information dictionary returned by the method 'inject_fake_data' must contain a " "'num_examples' field that holds the number of examples for the current configuration." ) else: - raise UsageError( - f"The additional information returned by the method 'inject_fake_data' must be either an integer " - f"indicating the number of examples for the current configuration or a dictionary with the the " - f"same content. Got {type(info)} instead." - ) + info = None cm = self._disable_download_extract if disable_download_extract else nullcontext with cm(special_kwargs), disable_console_output(): @@ -395,8 +403,8 @@ def _disable_download_extract(self, special_kwargs): if inject_download_kwarg: del special_kwargs["download"] - def test_not_found(self): - with self.assertRaises(RuntimeError): + def test_not_found_or_corrupted(self): + with self.assertRaises((FileNotFoundError, RuntimeError)): with self.create_dataset(inject_fake_data=False): pass @@ -511,26 +519,20 @@ class VideoDatasetTestCase(DatasetTestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.inject_fake_data = self._set_default_frames_per_clip(self.inject_fake_data) + self.dataset_args = self._set_default_frames_per_clip(self.dataset_args) def _set_default_frames_per_clip(self, inject_fake_data): argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__) args_without_default = argspec.args[1:-len(argspec.defaults)] frames_per_clip_last = args_without_default[-1] == "frames_per_clip" - only_root_and_frames_per_clip = (len(args_without_default) == 2) and 
frames_per_clip_last @functools.wraps(inject_fake_data) def wrapper(tmpdir, config): - output = inject_fake_data(tmpdir, config) - if isinstance(output, collections.abc.Sequence) and len(output) == 2: - args, info = output - if frames_per_clip_last and len(args) == len(args_without_default) - 1: - args = (*args, self.DEFAULT_FRAMES_PER_CLIP) - return args, info - elif isinstance(output, (int, dict)) and only_root_and_frames_per_clip: - return (tmpdir, self.DEFAULT_FRAMES_PER_CLIP) - else: - return output + args = inject_fake_data(tmpdir, config) + if frames_per_clip_last and len(args) == len(args_without_default) - 1: + args = (*args, self.DEFAULT_FRAMES_PER_CLIP) + + return args return wrapper diff --git a/test/test_datasets.py b/test/test_datasets.py index 37651ae7614..096dff97217 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -824,33 +824,44 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): REQUIRED_PACKAGES = ("pycocotools",) + _IMAGE_FOLDER = "images" + _ANNOTATIONS_FOLDER = "annotations" + _ANNOTATIONS_FILE = "annotations.json" + + def dataset_args(self, tmpdir, config): + tmpdir = pathlib.Path(tmpdir) + root = tmpdir / self._IMAGE_FOLDER + annotation_file = tmpdir / self._ANNOTATIONS_FOLDER / self._ANNOTATIONS_FILE + return root, annotation_file + def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) num_images = 3 num_annotations_per_image = 2 - image_folder = tmpdir / "images" files = datasets_utils.create_image_folder( - tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images + tmpdir, name=self._IMAGE_FOLDER, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images ) - file_names = [file.relative_to(image_folder) for file in files] + file_names = [file.relative_to(tmpdir / self._IMAGE_FOLDER) for file in files] - annotation_folder = tmpdir / "annotations" + annotation_folder = tmpdir / self._ANNOTATIONS_FOLDER os.makedirs(annotation_folder) - 
annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image) + info = self._create_annotation_file( + annotation_folder, self._ANNOTATIONS_FILE, file_names, num_annotations_per_image + ) info["num_examples"] = num_images - return (str(image_folder), str(annotation_file)), info + return info - def _create_annotation_file(self, root, file_names, num_annotations_per_image): + def _create_annotation_file(self, root, name, file_names, num_annotations_per_image): image_ids = [int(file_name.stem) for file_name in file_names] images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)] annotations, info = self._create_annotations(image_ids, num_annotations_per_image) + self._create_json(root, name, dict(images=images, annotations=annotations)) - content = dict(images=images, annotations=annotations) - return self._create_json(root, "annotations.json", content), info + return info def _create_annotations(self, image_ids, num_annotations_per_image): annotations = datasets_utils.combinations_grid( @@ -888,18 +899,27 @@ class UCF101TestCase(datasets_utils.VideoDatasetTestCase): CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False)) + _VIDEO_FOLDER = "videos" + _ANNOTATIONS_FOLDER = "annotations" + + def dataset_args(self, tmpdir, config): + tmpdir = pathlib.Path(tmpdir) + root = tmpdir / self._VIDEO_FOLDER + annotation_path = tmpdir / self._ANNOTATIONS_FOLDER + return root, annotation_path + def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) - video_folder = tmpdir / "videos" + video_folder = tmpdir / self._VIDEO_FOLDER os.makedirs(video_folder) video_files = self._create_videos(video_folder) - annotations_folder = annotations_folder = tmpdir / "annotations" + annotations_folder = tmpdir / self._ANNOTATIONS_FOLDER os.makedirs(annotations_folder) num_examples = self._create_annotation_files(annotations_folder, video_files, config["fold"], 
config["train"]) - return (str(video_folder), str(annotations_folder)), num_examples + return num_examples def _create_videos(self, root, num_examples_per_class=3): def file_name_fn(cls, idx, clips_per_group=2):