Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly fix dataset test that passes by accident #3434

Merged
merged 17 commits into from
Feb 25, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 46 additions & 44 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
]


class UsageError(RuntimeError):
class UsageError(Exception):
    """Raised when the test *setup* is broken, as opposed to the behavior under test failing."""


Expand Down Expand Up @@ -228,9 +228,25 @@ def test_baz(self):
"download_and_extract_archive",
}

def inject_fake_data(
self, tmpdir: str, config: Dict[str, Any]
) -> Union[int, Dict[str, Any], Tuple[Sequence[Any], Union[int, Dict[str, Any]]]]:
def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
    """Positional arguments used to instantiate the dataset under test.

    .. note::

        By default the dataset is constructed with ``tmpdir`` as its single positional argument, which
        matches datasets whose only required parameter is ``root``. Datasets with additional required
        positional parameters must override this method.

    Args:
        tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
            to be created and in turn also for the fake data injected here.
        config (Dict[str, Any]): Configuration that will be used to create the dataset.

    Returns:
        (Tuple[str]): A one-element tuple holding ``tmpdir``, i.e. ``root`` for most datasets.
    """
    return (tmpdir,)

def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]:
"""Inject fake data for dataset into a temporary directory.

Args:
Expand All @@ -240,15 +256,9 @@ def inject_fake_data(

Needs to return one of the following:

1. (int): Number of examples in the dataset to be created,
1. (int): Number of examples in the dataset to be created, or
2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
``"num_examples"`` that corresponds to the number of examples in the dataset to be created, or
3. (Tuple[Sequence[Any], Union[int, Dict[str, Any]]]): Additional required parameters that are passed to
the dataset constructor. The second element corresponds to cases 1. and 2.

If no ``args`` is returned (case 1. and 2.), the ``tmp_dir`` is passed as first parameter to the dataset
constructor. In most cases this corresponds to ``root``. If the dataset has more parameters without default
values you need to explicitly pass them as explained in case 3.
``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
"""
raise NotImplementedError("You need to provide fake data in order for the tests to run.")

Expand Down Expand Up @@ -287,33 +297,30 @@ def create_dataset(
disable_download_extract = inject_fake_data

with get_tmp_dir() as tmpdir:
output = self.inject_fake_data(tmpdir, config) if inject_fake_data else None
if output is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)

if isinstance(output, collections.abc.Sequence) and len(output) == 2:
args, info = output
else:
args = (tmpdir,)
info = output
args = self.dataset_args(tmpdir, config)

if isinstance(info, int):
info = dict(num_examples=info)
elif isinstance(info, dict):
if "num_examples" not in info:
if inject_fake_data:
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
else:
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an integer "
f"indicating the number of examples for the current configuration or a dictionary with the the "
f"same content. Got {type(info)} instead."
)
info = None

cm = self._disable_download_extract if disable_download_extract else nullcontext
with cm(special_kwargs), disable_console_output():
Expand Down Expand Up @@ -511,7 +518,7 @@ class VideoDatasetTestCase(DatasetTestCase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.inject_fake_data = self._set_default_frames_per_clip(self.inject_fake_data)
self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)

def _set_default_frames_per_clip(self, inject_fake_data):
argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
Expand All @@ -521,16 +528,11 @@ def _set_default_frames_per_clip(self, inject_fake_data):

@functools.wraps(inject_fake_data)
def wrapper(tmpdir, config):
output = inject_fake_data(tmpdir, config)
if isinstance(output, collections.abc.Sequence) and len(output) == 2:
args, info = output
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
args = (*args, self.DEFAULT_FRAMES_PER_CLIP)
return args, info
elif isinstance(output, (int, dict)) and only_root_and_frames_per_clip:
return (tmpdir, self.DEFAULT_FRAMES_PER_CLIP)
else:
return output
args = inject_fake_data(tmpdir, config)
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
args = (*args, self.DEFAULT_FRAMES_PER_CLIP)

return args

return wrapper

Expand Down
40 changes: 30 additions & 10 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,33 +824,44 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):

REQUIRED_PACKAGES = ("pycocotools",)

_IMAGE_FOLDER = "images"
_ANNOTATIONS_FOLDER = "annotations"
_ANNOTATIONS_FILE = "annotations.json"

def dataset_args(self, tmpdir, config):
    # CocoDetection takes (root, annFile): the image folder and the JSON
    # annotation file, both located inside the temporary directory.
    base = pathlib.Path(tmpdir)
    return base / self._IMAGE_FOLDER, base / self._ANNOTATIONS_FOLDER / self._ANNOTATIONS_FILE

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)

num_images = 3
num_annotations_per_image = 2

image_folder = tmpdir / "images"
files = datasets_utils.create_image_folder(
tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
tmpdir, name=self._IMAGE_FOLDER, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
)
file_names = [file.relative_to(image_folder) for file in files]
file_names = [file.relative_to(tmpdir / self._IMAGE_FOLDER) for file in files]

annotation_folder = tmpdir / "annotations"
annotation_folder = tmpdir / self._ANNOTATIONS_FOLDER
os.makedirs(annotation_folder)
annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image)
info = self._create_annotation_file(
annotation_folder, self._ANNOTATIONS_FILE, file_names, num_annotations_per_image
)

info["num_examples"] = num_images
return (str(image_folder), str(annotation_file)), info
return info

def _create_annotation_file(self, root, file_names, num_annotations_per_image):
def _create_annotation_file(self, root, name, file_names, num_annotations_per_image):
image_ids = [int(file_name.stem) for file_name in file_names]
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]

annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
self._create_json(root, name, dict(images=images, annotations=annotations))

content = dict(images=images, annotations=annotations)
return self._create_json(root, "annotations.json", content), info
return info

def _create_annotations(self, image_ids, num_annotations_per_image):
annotations = datasets_utils.combinations_grid(
Expand Down Expand Up @@ -888,6 +899,15 @@ class UCF101TestCase(datasets_utils.VideoDatasetTestCase):

CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))

_VIDEO_FOLDER = "videos"
_ANNOTATIONS_FOLDER = "annotations"

def dataset_args(self, tmpdir, config):
    # UCF101 takes (root, annotation_path): the video folder and the folder
    # holding the train/test split annotation files.
    base = pathlib.Path(tmpdir)
    return base / self._VIDEO_FOLDER, base / self._ANNOTATIONS_FOLDER

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)

Expand All @@ -899,7 +919,7 @@ def inject_fake_data(self, tmpdir, config):
os.makedirs(annotations_folder)
num_examples = self._create_annotation_files(annotations_folder, video_files, config["fold"], config["train"])

return (str(video_folder), str(annotations_folder)), num_examples
return num_examples

def _create_videos(self, root, num_examples_per_class=3):
def file_name_fn(cls, idx, clips_per_group=2):
Expand Down
6 changes: 5 additions & 1 deletion torchvision/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def __init__(
super().__init__(root, transforms, transform, target_transform)
from pycocotools.coco import COCO

self.coco = COCO(annFile)
try:
self.coco = COCO(annFile)
except FileNotFoundError as error:
raise RuntimeError(f"The file {annFile} does not exist or is corrupt.") from error
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should be changing the call-sites so that the current test infra works. Plus, FileNotFoundError is a meaningful type of error to be raised in here. Would your plan be to change all locations in the code to always raise RuntimeError?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should be changing the call-sites so that the current test infra works.

IMO every dataset should handle corrupted or non-existent files properly. I wouldn't do that to satisfy our tests, but rather because it can make the life of the user easier with a descriptive error message. Plus, this is what you agreed to

def test_not_found(self):
with self.assertRaises(RuntimeError):
with self.create_dataset(inject_fake_data=False):
pass

Plus, FileNotFoundError is a meaningful type of error to be raised in here.

I agree. I'll update the test infrastructure.

Would your plan be to change all locations in the code to always raise RuntimeError?

That was the original plan, yes. Maybe we can differentiate between RuntimeError for corrupted and FileNotFoundError for non-existent files. If we want to make sure we could also opt to create a FileCorruptedError(Exception) and check for that instead of a more general RuntimeError.


self.ids = list(sorted(self.coco.imgs.keys()))

def _load_image(self, id: int) -> Image.Image:
Expand Down