Skip to content

Commit

Permalink
CelebA: track attr names, support split="all", code cleanup (#1008)
Browse files Browse the repository at this point in the history
* CelebA: track attr names, support split="all", code cleanup

* fix typo
  • Loading branch information
djsutherland authored and fmassa committed Jun 11, 2019
1 parent b5db97b commit e4e167a
Showing 1 changed file with 18 additions and 21 deletions.
39 changes: 18 additions & 21 deletions torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
import torch
import os
import PIL
Expand All @@ -10,7 +11,7 @@ class CelebA(VisionDataset):
Args:
root (string): Root directory where images are downloaded to.
split (string): One of {'train', 'valid', 'test'}.
split (string): One of {'train', 'valid', 'test', 'all'}.
Accordingly dataset is selected.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
or ``landmarks``. Can also be a list to output a tuple with all specified target types.
Expand Down Expand Up @@ -78,32 +79,28 @@ def __init__(self, root,
split = 1
elif split.lower() == "test":
split = 2
elif split.lower() == "all":
split = None
else:
raise ValueError('Wrong split entered! Please use split="train" '
'or split="valid" or split="test"')
raise ValueError('Wrong split entered! Please use "train", '
'"valid", "test", or "all"')

with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f:
splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
fn = partial(os.path.join, self.root, self.base_folder)
splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)

with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f:
self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
mask = slice(None) if split is None else (splits[1] == split)

with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)

with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)

with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)

mask = (splits[1] == split)
self.filename = splits[mask].index.values
self.identity = torch.as_tensor(self.identity[mask].values)
self.bbox = torch.as_tensor(self.bbox[mask].values)
self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values)
self.attr = torch.as_tensor(self.attr[mask].values)
self.identity = torch.as_tensor(identity[mask].values)
self.bbox = torch.as_tensor(bbox[mask].values)
self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
self.attr = torch.as_tensor(attr[mask].values)
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
self.attr_names = list(attr.columns)

def _check_integrity(self):
for (_, md5, filename) in self.file_list:
Expand Down

0 comments on commit e4e167a

Please sign in to comment.