From d1cc9e962a214665426f4930dbf87c6d1b9fd1e8 Mon Sep 17 00:00:00 2001 From: Daksh Jotwani Date: Thu, 1 Aug 2019 17:02:40 +0530 Subject: [PATCH 1/5] Add VGGFace2 dataset --- torchvision/datasets/__init__.py | 3 +- torchvision/datasets/vggface2.py | 73 ++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 torchvision/datasets/vggface2.py diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index db5b572a469..ba4070c2672 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -22,6 +22,7 @@ from .kinetics import Kinetics400 from .hmdb51 import HMDB51 from .ucf101 import UCF101 +from .vggface2 import VGGFace2 __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', @@ -31,4 +32,4 @@ 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', 'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet', 'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset', - 'USPS', 'Kinetics400', 'HMDB51', 'UCF101') + 'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'VGGFace2') diff --git a/torchvision/datasets/vggface2.py b/torchvision/datasets/vggface2.py new file mode 100644 index 00000000000..2ff0b62fa31 --- /dev/null +++ b/torchvision/datasets/vggface2.py @@ -0,0 +1,73 @@ +import os +import csv +import torchvision.transforms.functional as F +from .folder import ImageFolder + + +class VGGFace2(ImageFolder): + '''`VGGFace2: A large scale image dataset for face recognition + `_ Dataset. + + Args: + root (string): Path to downloaded dataset. + target_type (string or list, optional): Target type for each sample, ``id`` + or ``bbox``. Can also be a list to output a tuple with all specified + target types. + The targets represent: + ``id`` (int): label/id for each person. + ``bbox`` (tuple[int]) bounding box encoded as x, y, width, height + Defaults to ``id``. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.ToTensor``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + bb_target_crop (boolean, optional): Crops bounding box from image as target. + bb_landmarks_csv (string, optional): path to downloaded bb landmarks. Required + if ``bbox`` is in target_type or bb_target_crop is True. 
+ + ''' + + def __init__(self, root, target_type='id', transform=None, + target_transform=None, bb_crop=False, bb_landmarks_csv=None): + super(VGGFace2, self).__init__(root, transform=transform, + target_transform=target_transform) + + if isinstance(target_type, list): + self.target_type = target_type + else: + self.target_type = [target_type] + + self.bb_crop = bb_crop + self.get_bbox = self.bb_crop or 'bbox' in self.target_type + + if self.get_bbox: + self.bb_data = {} + with open(bb_landmarks_csv, newline='') as csvfile: + reader = csv.reader(csvfile) + for path, x, y, w, h in reader: + self.bb_data[path] = (int(x), int(y), int(w), int(h)) + + def __getitem__(self, index): + path, label = self.samples[index] + sample = self.loader(path) + + if self.get_bbox: + bbox = self.bb_data[os.path.join(self.root, path) + '.jpg'] + + if self.bb_crop: + x, y, w, h = bbox + sample = F.crop(sample, x, y, h, w) + + target = [] + for t in self.target_type: + if t == 'id': + target.append(label) + elif t == 'bbox': + target.extend(bbox) + + if self.transform is not None: + sample = self.transform(sample) + + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target From 8cc6ad49c037fe087768861e7ca1a8745fffcdb9 Mon Sep 17 00:00:00 2001 From: Daksh Jotwani Date: Fri, 2 Aug 2019 13:32:37 +0530 Subject: [PATCH 2/5] Add bbox csv support --- torchvision/datasets/vggface2.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/torchvision/datasets/vggface2.py b/torchvision/datasets/vggface2.py index 2ff0b62fa31..0263315dc4d 100644 --- a/torchvision/datasets/vggface2.py +++ b/torchvision/datasets/vggface2.py @@ -20,14 +20,14 @@ class VGGFace2(ImageFolder): transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.ToTensor``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. - bb_target_crop (boolean, optional): Crops bounding box from image as target. - bb_landmarks_csv (string, optional): path to downloaded bb landmarks. Required + bbox_crop (boolean, optional): Crops bounding box from image as target. + bbox_csv (string, optional): path to downloaded bb landmarks. Required if ``bbox`` is in target_type or bb_target_crop is True. 
''' def __init__(self, root, target_type='id', transform=None, - target_transform=None, bb_crop=False, bb_landmarks_csv=None): + target_transform=None, bbox_crop=False, bbox_csv=None): super(VGGFace2, self).__init__(root, transform=transform, target_transform=target_transform) @@ -36,14 +36,20 @@ def __init__(self, root, target_type='id', transform=None, else: self.target_type = [target_type] - self.bb_crop = bb_crop - self.get_bbox = self.bb_crop or 'bbox' in self.target_type + self.bbox_crop = bbox_crop + self.get_bbox = self.bbox_crop or 'bbox' in self.target_type if self.get_bbox: + if bbox_csv is None: + raise ValueError("bbox_csv cannot be None if 'bbox' " + "in target_type or bbox_crop=True") + self.bb_data = {} - with open(bb_landmarks_csv, newline='') as csvfile: + with open(bbox_csv, newline='') as csvfile: reader = csv.reader(csvfile) + next(reader) for path, x, y, w, h in reader: + path = os.path.join(self.root, path) + '.jpg' self.bb_data[path] = (int(x), int(y), int(w), int(h)) def __getitem__(self, index): @@ -51,9 +57,9 @@ def __getitem__(self, index): sample = self.loader(path) if self.get_bbox: - bbox = self.bb_data[os.path.join(self.root, path) + '.jpg'] + bbox = self.bb_data[path] - if self.bb_crop: + if self.bbox_crop: x, y, w, h = bbox sample = F.crop(sample, x, y, h, w) @@ -70,4 +76,4 @@ def __getitem__(self, index): if self.target_transform is not None: target = self.target_transform(target) - return sample, target + return (sample, *target) From 037ef8703446c9d5ebade99e14fda43791c00f6d Mon Sep 17 00:00:00 2001 From: Daksh Jotwani Date: Fri, 2 Aug 2019 15:59:08 +0530 Subject: [PATCH 3/5] Add landmark csv support --- torchvision/datasets/vggface2.py | 58 ++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/torchvision/datasets/vggface2.py b/torchvision/datasets/vggface2.py index 0263315dc4d..d992196df4b 100644 --- a/torchvision/datasets/vggface2.py +++ b/torchvision/datasets/vggface2.py @@ -4,6 +4,31 @@ from .folder import ImageFolder +def read_bbox_csv(root, csv_path): + bb_data = {} + with open(csv_path, newline='') as csvfile: + reader = csv.reader(csvfile) + next(reader) + for path, x, y, w, h in reader: + path = os.path.join(root, path) + '.jpg' + bb_data[path] = (int(x), int(y), int(w), int(h)) + + return bb_data + + +def read_landmark_csv(root, csv_path): + landmark_data = {} + with open(csv_path, newline='') as csvfile: + reader = csv.reader(csvfile) + next(reader) + for row in reader: + path = os.path.join(root, row[0]) + '.jpg' + landmarks = tuple(float(x) for x in row[1:]) + landmark_data[path] = landmarks + + return landmark_data + + class VGGFace2(ImageFolder): '''`VGGFace2: A large scale image dataset for face recognition `_ Dataset. @@ -21,13 +46,15 @@ class VGGFace2(ImageFolder): and returns a transformed version. E.g, ``transforms.ToTensor``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. bbox_crop (boolean, optional): Crops bounding box from image as target. - bbox_csv (string, optional): path to downloaded bb landmarks. Required + bbox_csv (string, optional): path to downloaded bounding box csv. Required if ``bbox`` is in target_type or bb_target_crop is True. - + landmark_csv (string, optional): path to downloaded landmarks csv. Required + if ``landmark`` is in target_type. 
''' def __init__(self, root, target_type='id', transform=None, - target_transform=None, bbox_crop=False, bbox_csv=None): + target_transform=None, bbox_crop=False, bbox_csv=None, + landmark_csv=None): super(VGGFace2, self).__init__(root, transform=transform, target_transform=target_transform) @@ -37,30 +64,25 @@ def __init__(self, root, target_type='id', transform=None, self.target_type = [target_type] self.bbox_crop = bbox_crop - self.get_bbox = self.bbox_crop or 'bbox' in self.target_type - if self.get_bbox: + if self.bbox_crop or 'bbox' in self.target_type: if bbox_csv is None: raise ValueError("bbox_csv cannot be None if 'bbox' " "in target_type or bbox_crop=True") + self.bb_data = read_bbox_csv(self.root, bbox_csv) - self.bb_data = {} - with open(bbox_csv, newline='') as csvfile: - reader = csv.reader(csvfile) - next(reader) - for path, x, y, w, h in reader: - path = os.path.join(self.root, path) + '.jpg' - self.bb_data[path] = (int(x), int(y), int(w), int(h)) + if 'landmark' in target_type: + if landmark_csv is None: + raise ValueError("bbox_csv cannot be None if 'landmark' " + "in target_type") + self.landmark_data = read_landmark_csv(self.root, landmark_csv) def __getitem__(self, index): path, label = self.samples[index] sample = self.loader(path) - if self.get_bbox: - bbox = self.bb_data[path] - if self.bbox_crop: - x, y, w, h = bbox + x, y, w, h = self.bb_data[path] sample = F.crop(sample, x, y, h, w) target = [] @@ -68,7 +90,9 @@ def __getitem__(self, index): if t == 'id': target.append(label) elif t == 'bbox': - target.extend(bbox) + target.append(self.bb_data[path]) + elif t == 'landmark': + target.append(self.landmark_data[path]) if self.transform is not None: sample = self.transform(sample) From 3e780e17c606c4ff03d1dbd901368f9cc39fd3af Mon Sep 17 00:00:00 2001 From: Daksh Jotwani Date: Fri, 2 Aug 2019 16:34:30 +0530 Subject: [PATCH 4/5] Add __len__ and extra_repr methods --- torchvision/datasets/vggface2.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchvision/datasets/vggface2.py b/torchvision/datasets/vggface2.py index d992196df4b..4af3207051d 100644 --- a/torchvision/datasets/vggface2.py +++ b/torchvision/datasets/vggface2.py @@ -101,3 +101,9 @@ def __getitem__(self, index): target = self.target_transform(target) return (sample, *target) + + def __len__(self): + return len(self.samples) + + def extra_repr(self): + return 'Target type: {}'.format(self.target_type) From 6d907ba002e3c468ed9563a73cbb53102e3fe2b6 Mon Sep 17 00:00:00 2001 From: Daksh Jotwani Date: Mon, 5 Aug 2019 11:40:13 +0530 Subject: [PATCH 5/5] Remove tuple unpack for python2 --- torchvision/datasets/vggface2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/vggface2.py b/torchvision/datasets/vggface2.py index 4af3207051d..bb78c3d4bd3 100644 --- a/torchvision/datasets/vggface2.py +++ b/torchvision/datasets/vggface2.py @@ -100,7 +100,7 @@ def __getitem__(self, index): if self.target_transform is not None: target = self.target_transform(target) - return (sample, *target) + return tuple([sample] + target) def __len__(self): return len(self.samples)
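
Usage note (not part of the patch series): the snippet below is a minimal sketch of how the class would be used once all five patches are applied. The image root, CSV path, and transform pipeline are illustrative assumptions, not values taken from the patches; the bounding-box CSV is assumed to follow the official loose_bb_train.csv layout (a header row, then one NAME_ID,X,Y,W,H row per image without the .jpg extension), which is what the header skip and path join in read_bbox_csv expect.

    from torchvision import datasets, transforms

    # Hypothetical paths; point these at wherever the VGGFace2 images and
    # annotation CSVs were extracted.
    dataset = datasets.VGGFace2(
        root='data/vggface2/train',
        target_type=['id', 'bbox'],
        bbox_crop=True,      # crop each face to its box before `transform` runs
        bbox_csv='data/bb_landmark/loose_bb_train.csv',
        transform=transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]),
    )

    # __getitem__ returns tuple([sample] + target), so with two target types
    # each sample unpacks into the image, the identity label, and the box.
    image, person_id, bbox = dataset[0]

Requesting target_type='landmark' instead would additionally require landmark_csv, mirroring the bbox_csv check introduced in patch 3.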