class SBUTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.SBU`` that runs against locally generated fake data."""

    DATASET_CLASS = datasets.SBU
    # Presumably the types of one sample's (image, caption) pair -- confirm against datasets.SBU.
    FEATURE_TYPES = (PIL.Image.Image, str)

    def inject_fake_data(self, tmpdir, config):
        """Create fake images and the two companion text files below ``tmpdir``.

        Three images are generated in ``tmpdir/dataset`` through
        ``datasets_utils.create_image_folder``; the URL list and the caption
        list are then written into that same folder. ``config`` is not used.

        Returns:
            The number of images that were requested.
        """
        image_count = 3
        folder_name = "dataset"

        created_images = datasets_utils.create_image_folder(
            tmpdir, folder_name, self._create_file_name, image_count
        )
        target_dir = pathlib.Path(tmpdir) / folder_name

        self._create_urls_txt(target_dir, created_images)
        self._create_captions_txt(target_dir, image_count)

        return image_count

    def _create_file_name(self, idx):
        """Build a random ``<part>_<part>.jpg`` file name; ``idx`` is ignored.

        The first part is drawn from digits only, the second from lowercase
        letters plus the digits ``0``-``5``.
        """
        random_string = datasets_utils.create_random_string
        numeric_part = random_string(10, string.digits)
        mixed_part = random_string(10, string.ascii_lowercase, string.digits[:6])
        return f"{numeric_part}_{mixed_part}.jpg"

    def _create_urls_txt(self, root, images):
        """Write one fake flickr URL per entry of ``images`` into the URLs file in ``root``.

        Each URL ends with the image's file name (``image.name``).
        """
        urls_path = root / "SBU_captioned_photo_dataset_urls.txt"

        def url_line(image):
            # Random path component placed between the host and the file name.
            random_segment = datasets_utils.create_random_string(4, string.digits)
            return f"http://static.flickr.com/{random_segment}/{image.name}\n"

        with open(urls_path, "w") as fh:
            fh.writelines(url_line(image) for image in images)

    def _create_captions_txt(self, root, num_images):
        """Write ``num_images`` random caption lines into the captions file in ``root``."""
        captions_path = root / "SBU_captioned_photo_dataset_captions.txt"
        with open(captions_path, "w") as fh:
            fh.writelines(f"{datasets_utils.create_random_string(10)}\n" for _ in range(num_images))


if __name__ == "__main__":
    unittest.main()