diff --git a/test/test_datasets.py b/test/test_datasets.py
index bea2a2b80b9..33f59f6a763 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -602,6 +602,146 @@ def test_attr_names(self):
         self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
 
 
+class Cub2011TestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.Cub2011
+    FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
+
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        split=("train", "test", "all"),
+        target_type=(["class_label"], ["segmentation"], ["bbox"]),
+    )
+
+    _SPLIT_TO_IDX = dict(train=1, test=0)
+
+    def inject_fake_data(self, tmpdir, config):
+        base_folder = [
+            pathlib.Path(tmpdir) / "Cub2011",
+            pathlib.Path(tmpdir) / "Cub2011" / "segmentations",
+            pathlib.Path(tmpdir) / "Cub2011" / "CUB_200_2011",
+            pathlib.Path(tmpdir) / "Cub2011" / "CUB_200_2011" / "images",
+        ]
+        for folder in base_folder:
+            os.makedirs(folder, exist_ok=True)
+
+        num_images, num_images_per_split = self._create_split_txt(base_folder)
+        target_class = "001.Black_footed_Albatross"
+        datasets_utils.create_image_folder(
+            base_folder[-1], target_class, lambda idx: f"{idx + 1:06d}.jpg", num_images
+        )
+        datasets_utils.create_image_folder(
+            base_folder[1], target_class, lambda idx: f"{idx + 1:06d}.png", num_images
+        )
+        self._create_bbox_txt(base_folder, num_images)
+        self._create_class_labels_txt(base_folder, num_images)
+        self._create_images_txt(base_folder, target_class, num_images)
+
+        return dict(num_examples=num_images_per_split[config["split"]])
+
+    def _create_split_txt(self, root):
+        num_images_per_split = dict(train=5, test=2)
+
+        idx = 0
+        data = []
+        for split, num_images in num_images_per_split.items():
+            for _ in range(num_images):
+                data.append([idx, self._SPLIT_TO_IDX[split]])
+                idx += 1
+        self._create_txt(root, "train_test_split.txt", data)
+
+        num_images = sum(num_images_per_split.values())
+        num_images_per_split["all"] = num_images
+        return num_images, num_images_per_split
+
+    def _create_class_labels_txt(self, root, num_images):
+        data = [[idx, 1] for idx in range(num_images)]
+        self._create_txt(root, "image_class_labels.txt", data)
+
+    def _create_bbox_txt(self, root, num_images):
+        header = ("x_1", "y_1", "width", "height")
+        data = torch.randint(10, size=(num_images, len(header))).float().tolist()
+        data_with_index = [[idx, row] for idx, row in enumerate(data)]
+        self._create_txt(root, "bounding_boxes.txt", data_with_index)
+
+    def _create_images_txt(self, root, target_class, num_images):
+        data = [[idx, f"{target_class}/{idx + 1:06d}.jpg"] for idx in range(num_images)]
+        self._create_txt(root, "images.txt", data)
+
+    def _create_txt(self, root, name, data):
+        with open(pathlib.Path(root[-2]) / name, "w") as fh:
+            for line in data:
+                if isinstance(line[1], list):
+                    meta_data = " ".join(str(value) for value in line[1])
+                else:
+                    meta_data = line[1]
+                fh.write(f"{line[0]} {meta_data}\n")
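+
+    # The helpers above write fake metadata in the plain-text layout the
+    # dataset parses: one "<image_id> <value ...>" pair per line. For example,
+    # the generated images.txt starts with:
+    #
+    #   0 001.Black_footed_Albatross/000001.jpg
+    #   1 001.Black_footed_Albatross/000002.jpg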
+
+    def test_combined_targets(self):
+        target_types = ["class_label", "segmentation", "bbox"]
+
+        individual_targets = []
+        for target_type in target_types:
+            with self.create_dataset(target_type=[target_type]) as (dataset, _):
+                _, target = dataset[0]
+                individual_targets.append(target)
+
+        with self.create_dataset(target_type=target_types) as (dataset, _):
+            _, combined_targets = dataset[0]
+
+        actual = len(individual_targets)
+        expected = len(combined_targets)
+        self.assertEqual(
+            actual,
+            expected,
+            f"The number of returned combined targets does not match the number of targets requested "
+            f"individually: {actual} != {expected}",
+        )
+
+        for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
+            with self.subTest(target_type=target_type):
+                actual = type(combined_target)
+                expected = type(individual_target)
+                self.assertIs(
+                    actual,
+                    expected,
+                    f"Type of the combined target does not match the type of the corresponding individual target: "
+                    f"{actual} is not {expected}",
+                )
+
+    def test_no_target(self):
+        with self.create_dataset(target_type=[]) as (dataset, _):
+            _, target = dataset[0]
+
+        self.assertIsNone(target)
+
+    def test_dataset_length(self):
+        with self.create_dataset() as (dataset, info):
+            self.assertEqual(len(dataset.index_list), info["num_examples"])
+
+
 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
     FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index e67ba08d299..cedcb770154 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -25,6 +25,7 @@
 from .ucf101 import UCF101
 from .places365 import Places365
 from .kitti import Kitti
+from .cub2011 import Cub2011
 
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -35,5 +36,5 @@
            'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
            'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
            'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101',
-           'Places365', 'Kitti',
+           'Places365', 'Kitti', 'Cub2011',
            )
diff --git a/torchvision/datasets/cub2011.py b/torchvision/datasets/cub2011.py
new file mode 100644
index 00000000000..53f144761be
--- /dev/null
+++ b/torchvision/datasets/cub2011.py
@@ -0,0 +1,148 @@
+import os
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Cub2011(VisionDataset):
+    """`CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        split (string, optional): One of {'train', 'test', 'all'}. The subset of
+            the dataset to load.
+        target_type (string or list, optional): Type of target to use, ``class_label``,
+            ``segmentation`` or ``bbox``. Can also be a list to output a tuple with all
+            specified target types. The targets represent:
+
+                - ``class_label`` (torch.Tensor): label of the bird species, in the range 1-200
+                - ``segmentation`` (torch.Tensor): binary segmentation map of the input image
+                - ``bbox`` (torch.Tensor, shape=(4,)): bounding box (x, y, width, height)
+
+            Defaults to ``class_label``. If empty, ``None`` will be returned as target.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
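+
+        Example:
+            A minimal usage sketch; the ``root`` path is a placeholder::
+
+                dataset = Cub2011("./data", split="train",
+                                  target_type=["class_label", "bbox"], download=True)
+                img, (label, bbox) = dataset[0]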
+    """
+
+    base_folder = "Cub2011"
+    file_list = [
+        # File ID,                             MD5 hash,                           filename
+        ("1hbzc_P1FuxMkcabkgn9ZKinBwW683j45", "97eceeb196236b17998738112f37df78", "CUB_200_2011.tgz"),
+        ("1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP", "4d47ba1228eae64f2fa547c47bc65255", "segmentations.tgz"),
+    ]
+
+    meta_data = {
+        "image_lst": "images.txt",
+        "class_labels": "image_class_labels.txt",
+        "split_lst": "train_test_split.txt",
+        "bb_lst": "bounding_boxes.txt",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        target_type: Union[List[str], str] = "class_label",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)
+
+        if isinstance(target_type, str):
+            target_type = [target_type]
+        self.target_type = target_type
+
+        self.split = verify_str_arg(split, "split", ("train", "test", "all"))
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        split_map = {
+            "test": 0,
+            "train": 1,
+            "all": 2,
+        }
+
+        self._meta_data = {key: self.filter_data(key) for key in self.meta_data}
+
+        self.index_list = self.process_indices(self._meta_data["split_lst"], split_map[self.split])
+
+    def _check_integrity(self) -> bool:
+        for (_, md5, filename) in self.file_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        for (file_id, md5, filename) in self.file_list:
+            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
+            extract_archive(os.path.join(self.root, self.base_folder, filename))
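+
+    # ``index_list`` maps positions within the selected split back to row
+    # numbers in the global metadata files, so every per-image lookup in
+    # ``__getitem__`` (filename, class label, bbox) goes through
+    # ``self.index_list[index]`` rather than the raw ``index``.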
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        filename = self._meta_data["image_lst"][self.index_list[index]]
+        X = Image.open(os.path.join(self.root, self.base_folder, "CUB_200_2011", "images", filename)).convert("RGB")
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "class_label":
+                target.append(torch.tensor(float(self._meta_data["class_labels"][self.index_list[index]])))
+            elif t == "segmentation":
+                segmentation = Image.open(
+                    os.path.join(self.root, self.base_folder, "segmentations", filename.replace(".jpg", ".png"))
+                ).convert("1")
+                target.append(torch.tensor(np.asarray(segmentation), dtype=torch.float32))
+            elif t == "bbox":
+                target.append(torch.tensor(np.asarray(self._meta_data["bb_lst"][self.index_list[index]]).astype(np.float64)))
+            else:
+                raise ValueError(f'Target type "{t}" is not recognized.')
+
+        if self.transform is not None:
+            X = self.transform(X)
+
+        if target:
+            target = tuple(target) if len(target) > 1 else target[0]
+
+            if self.target_transform is not None:
+                target = self.target_transform(target)
+        else:
+            target = None
+
+        return X, target
+
+    def __len__(self) -> int:
+        return len(self.index_list)
+
+    def process_indices(self, split_list, target):
+        processed_index = []
+        for i in range(len(split_list)):
+            if target == 2:  # 'all': keep every image regardless of its split flag
+                processed_index.append(i)
+            elif int(split_list[i]) == target:
+                processed_index.append(i)
+        return processed_index
+
+    def filter_data(self, key):
+        filter_data_lst = []
+        with open(os.path.join(self.root, self.base_folder, "CUB_200_2011", self.meta_data[key]), "r") as fh:
+            for line in fh:
+                data = line.strip().split(" ")
+                if len(data) == 2:
+                    filter_data_lst.append(data[1])
+                else:
+                    filter_data_lst.append(data[1:])
+        return filter_data_lst