diff --git a/test/test_datasets.py b/test/test_datasets.py
index 5b7eabc4cb1..e8b5a3b1df6 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -512,7 +512,7 @@ def inject_fake_data(self, tmpdir, config):
         return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)
 
     def _create_split_txt(self, root):
-        num_images_per_split = dict(train=3, valid=2, test=1)
+        num_images_per_split = dict(train=4, valid=3, test=2)
 
         data = [
             [self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
@@ -595,6 +595,19 @@ def test_attr_names(self):
         with self.create_dataset() as (dataset, info):
             assert tuple(dataset.attr_names) == info["attr_names"]
 
+    def test_images_names_split(self):
+        with self.create_dataset(split="all") as (dataset, _):
+            all_imgs_names = set(dataset.filename)
+
+        merged_imgs_names = set()
+        for split in ["train", "valid", "test"]:
+            with self.create_dataset(split=split) as (dataset, info):
+                # Without this check the test also passes when filename is left unfiltered.
+                assert len(dataset.filename) == info["num_examples"]
+                merged_imgs_names.update(dataset.filename)
+
+        assert merged_imgs_names == all_imgs_names
+
 
 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py
index 56588aaef57..f2fcdb74dfe 100644
--- a/torchvision/datasets/celeba.py
+++ b/torchvision/datasets/celeba.py
@@ -99,7 +99,11 @@ def __init__(
 
         mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
 
-        self.filename = splits.index
+        if split_ is None:  # split == "all": keep every file name
+            self.filename = splits.index
+        else:
+            # flatten(), not squeeze(): a split with a single image must still yield a 1-D index tensor
+            self.filename = [splits.index[i] for i in torch.nonzero(mask).flatten()]
         self.identity = identity.data[mask]
         self.bbox = bbox.data[mask]
         self.landmarks_align = landmarks_align.data[mask]