Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added LFW Dataset #4255

Merged
merged 31 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
959880f
Added LFW Dataset
ABD-01 Aug 5, 2021
5f4503e
Merge branch 'master' into add_lfw
ABD-01 Aug 6, 2021
4b012c5
Merge branch 'master' into add_lfw
ABD-01 Aug 12, 2021
bd414d4
Merge branch 'pytorch:main' into add_lfw
ABD-01 Aug 23, 2021
ac4b4ad
Added dataset to list in __init__.py
Aug 23, 2021
cfef8c8
Updated lfw.py
Aug 23, 2021
66f43c9
Merge branch 'pytorch:main' into add_lfw
ABD-01 Aug 26, 2021
c7bf4ae
Added docstrings and updated datasets.rst
ABD-01 Aug 26, 2021
749308a
Wrote tests for LFWPeople and LFWPairs
ABD-01 Aug 28, 2021
3f4f214
Merge branch 'main' into add_lfw
ABD-01 Aug 28, 2021
6590da3
Resolved mypy error: Need type annotation for "data"
ABD-01 Aug 28, 2021
acb68ae
Updated inject_fake_data method for LFWPeople
ABD-01 Aug 28, 2021
10fffb0
Updated tests for LFW
ABD-01 Aug 29, 2021
0434472
Merge branch 'pytorch:main' into add_lfw
ABD-01 Aug 31, 2021
87197f5
Updated LFW tests and minor changes in lfw.py
ABD-01 Aug 31, 2021
b34a173
Merge branch 'main' into add_lfw
ABD-01 Sep 2, 2021
7549517
Updated LFW
ABD-01 Sep 4, 2021
c52c891
Updated lfw.py and tests
ABD-01 Sep 9, 2021
a5146eb
Merge branch 'main' into add_lfw
ABD-01 Sep 9, 2021
df96b44
resolved py lint errors
ABD-01 Sep 9, 2021
e93cd21
Merge branch 'add_lfw' of github.com:ABD-01/vision into add_lfw
ABD-01 Sep 9, 2021
baf5556
Merge branch 'main' into add_lfw
ABD-01 Sep 9, 2021
4feed66
Added checksums for annotation files
ABD-01 Sep 9, 2021
e9cb48e
Minor changes in test
ABD-01 Sep 9, 2021
cc475cc
Updated docstrings, defaults and minor changes in test
ABD-01 Sep 10, 2021
1466334
Removed 'os.path.exists' check
ABD-01 Sep 10, 2021
f5c41c4
Merge branch 'main' into add_lfw
ABD-01 Sep 10, 2021
69e8f2f
Merge branch 'main' into add_lfw
ABD-01 Sep 10, 2021
9860672
Merge branch 'main' into add_lfw
pmeier Sep 13, 2021
3e17463
Merge branch 'main' into add_lfw
ABD-01 Sep 13, 2021
4c4b826
Merge branch 'main' into add_lfw
fmassa Sep 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ KMNIST

.. autoclass:: KMNIST

LFW
~~~~~

.. autoclass:: LFWPeople
:members: __getitem__
:special-members:

.. autoclass:: LFWPairs
:members: __getitem__
:special-members:

LSUN
~~~~

Expand Down
82 changes: 82 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,5 +1801,87 @@ def test_targets(self):
assert item[6] == i // 3


class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
    """Fake-data test case for ``datasets.LFWPeople``."""

    DATASET_CLASS = datasets.LFWPeople
    FEATURE_TYPES = (PIL.Image.Image, int)
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=('10fold', 'train', 'test'),
        image_set=('original', 'funneled', 'deepfunneled')
    )
    # Name of the extracted images directory for each alignment variant.
    _IMAGES_DIR = {
        "original": "lfw",
        "funneled": "lfw_funneled",
        "deepfunneled": "lfw-deepfunneled"
    }
    # Suffix of the annotation file name for each split.
    _split = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'}

    def inject_fake_data(self, tmpdir, config):
        # The dataset expects everything under a "lfw-py" base folder.
        base_dir = pathlib.Path(tmpdir) / "lfw-py"
        os.makedirs(base_dir, exist_ok=True)
        count = self._create_images_dir(base_dir, self._IMAGES_DIR[config["image_set"]], config["split"])
        return dict(num_examples=count, split=config["split"])

    def _create_images_dir(self, root, idir, split):
        images_dir = os.path.join(root, idir)
        os.makedirs(images_dir, exist_ok=True)
        # The 10fold annotation file starts with the fold count; the Dev
        # splits contain a single unlabeled block.
        if split == "10fold":
            num_folds, lines = 10, ["10\n"]
        else:
            num_folds, lines = 1, []
        total = 0
        name_lines = []
        for _ in range(num_folds):
            people_in_fold = random.randint(2, 5)
            lines.append(f"{people_in_fold}\n")
            for _ in range(people_in_fold):
                identity = self._create_random_id()
                num_images = random.randint(1, 10)
                lines.append(f"{identity}\t{num_images}\n")
                name_lines.append(f"{identity}\t{num_images}\n")
                datasets_utils.create_image_folder(
                    images_dir, identity, lambda i: f"{identity}_{i + 1:04d}.jpg", num_images, 250
                )
                total += num_images
        with open(pathlib.Path(root) / f"people{self._split[split]}.txt", "w") as f:
            f.writelines(lines)
        with open(pathlib.Path(root) / "lfw-names.txt", "w") as f:
            f.writelines(sorted(name_lines))

        return total

    def _create_random_id(self):
        # LFW identities look like "Firstname_Lastname".
        first = datasets_utils.create_random_string(random.randint(5, 7))
        last = datasets_utils.create_random_string(random.randint(4, 7))
        return f"{first}_{last}"


class LFWPairsTestCase(LFWPeopleTestCase):
    """Fake-data test case for ``datasets.LFWPairs``."""

    DATASET_CLASS = datasets.LFWPairs
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, int)

    def _create_images_dir(self, root, idir, split):
        images_dir = os.path.join(root, idir)
        os.makedirs(images_dir, exist_ok=True)
        num_pairs = 7  # effectively 7*2*n = 14*n
        if split == "10fold":
            num_folds, self.flines = 10, [f"10\t{num_pairs}"]
        else:
            num_folds, self.flines = 1, [str(num_pairs)]
        for _ in range(num_folds):
            # Each fold lists the matched pairs first, then the unmatched ones.
            self._inject_pairs(images_dir, num_pairs, True)
            self._inject_pairs(images_dir, num_pairs, False)
        with open(pathlib.Path(root) / f"pairs{self._split[split]}.txt", "w") as f:
            f.writelines(self.flines)

        return num_pairs * 2 * num_folds

    def _inject_pairs(self, root, num_pairs, same):
        for _ in range(num_pairs):
            first = self._create_random_id()
            second = first if same else self._create_random_id()
            no1, no2 = random.randint(1, 100), random.randint(1, 100)
            # Matched lines have 3 tab-separated fields, unmatched have 4.
            if same:
                self.flines.append(f"\n{first}\t{no1}\t{no2}")
            else:
                self.flines.append(f"\n{first}\t{no1}\t{second}\t{no2}")

            datasets_utils.create_image_folder(root, first, lambda _: f"{first}_{no1:04d}.jpg", 1, 250)
            datasets_utils.create_image_folder(root, second, lambda _: f"{second}_{no2:04d}.jpg", 1, 250)


# Allow running this test module directly (pytest also discovers it).
if __name__ == "__main__":
    unittest.main()
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .places365 import Places365
from .kitti import Kitti
from .inaturalist import INaturalist
from .lfw import LFWPeople, LFWPairs

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
Expand All @@ -36,5 +37,5 @@
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101',
'Places365', 'Kitti', "INaturalist"
'Places365', 'Kitti', "INaturalist", "LFWPeople", "LFWPairs"
)
260 changes: 260 additions & 0 deletions torchvision/datasets/lfw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import os
from typing import Any, Callable, List, Optional, Tuple
from PIL import Image
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg


class _LFW(VisionDataset):
    """Common machinery shared by :class:`LFWPeople` and :class:`LFWPairs` (internal)."""

    base_folder = 'lfw-py'
    download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"

    # image_set -> (extracted images directory, archive filename, archive md5).
    file_dict = {
        'original': ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
        'funneled': ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
        'deepfunneled': ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201")
    }
    # md5 checksums for the annotation files published alongside the archives.
    checksums = {
        'pairs.txt': '9f1ba174e4e1c508ff7cdf10ac338a7d',
        'pairsDevTest.txt': '5132f7440eb68cf58910c8a45a2ac10b',
        'pairsDevTrain.txt': '4f27cbf15b2da4a85c1907eb4181ad21',
        'people.txt': '450f0863dd89e85e73936a6d71a3474b',
        'peopleDevTest.txt': 'e4bf5be0a43b5dcd9dc5ccfcb8fb19c5',
        'peopleDevTrain.txt': '54eaac34beb6d042ed3a7d883e247a21',
        'lfw-names.txt': 'a6d0a479bd074669f656265a6e693f6d'
    }
    # split -> suffix used in the annotation filename, e.g. "peopleDevTrain.txt".
    annot_file = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'}
    names = "lfw-names.txt"

    def __init__(
        self,
        root: str,
        split: str,
        image_set: str,
        view: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        super().__init__(os.path.join(root, self.base_folder),
                         transform=transform, target_transform=target_transform)

        # Validate arguments before touching the filesystem.
        self.image_set = verify_str_arg(image_set.lower(), 'image_set', self.file_dict.keys())
        image_subdir, self.filename, self.md5 = self.file_dict[self.image_set]

        self.view = verify_str_arg(view.lower(), 'view', ['people', 'pairs'])
        self.split = verify_str_arg(split.lower(), 'split', ['10fold', 'train', 'test'])
        self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
        self.data: List[Any] = []

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')

        self.images_dir = os.path.join(self.root, image_subdir)

    def _loader(self, path: str) -> Image.Image:
        # Eagerly convert to RGB so the file handle can be closed right away.
        with open(path, 'rb') as fh:
            return Image.open(fh).convert('RGB')

    def _check_integrity(self):
        # Both the image archive and the annotation file must verify.
        archive_ok = check_integrity(os.path.join(self.root, self.filename), self.md5)
        labels_ok = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
        if not (archive_ok and labels_ok):
            return False
        if self.view != "people":
            return True
        # The "people" view additionally requires the identity-names file.
        return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(f"{self.download_url_prefix}{self.filename}",
                                     self.root, filename=self.filename, md5=self.md5)
        download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
        if self.view == "people":
            download_url(f"{self.download_url_prefix}{self.names}", self.root)

    def _get_path(self, identity, no):
        # Images are stored as "<images_dir>/<identity>/<identity>_<no:04d>.jpg".
        return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")

    def extra_repr(self) -> str:
        return f"Alignment: {self.image_set}\nSplit: {self.split}"

    def __len__(self):
        return len(self.data)


class LFWPeople(_LFW):
    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``lfw-py`` exists or will be saved to if download is set to True.
        split (string, optional): The image split to use. Can be one of ``train`` (default), ``test``,
            ``10fold``.
        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
            ``deepfunneled``. Defaults to ``funneled``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomRotation``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        image_set: str = "funneled",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        super().__init__(root, split, image_set, "people",
                         transform, target_transform, download)

        self.class_to_idx = self._get_classes()
        self.data, self.targets = self._get_people()

    def _get_people(self):
        """Parse the people annotation file into parallel lists of image paths and class indices."""
        data, targets = [], []
        with open(os.path.join(self.root, self.labels_file), 'r') as f:
            lines = f.readlines()
        # 10fold files start with the fold count; Dev files start directly
        # with the number of people in the (single) block.
        n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)

        for _ in range(n_folds):
            n_lines = int(lines[s])
            people = [line.strip().split("\t") for line in lines[s + 1: s + n_lines + 1]]
            s += n_lines + 1
            for identity, num_imgs in people:
                for num in range(1, int(num_imgs) + 1):
                    img = self._get_path(identity, num)
                    # Some annotated images may be missing on disk; skip them.
                    if os.path.exists(img):
                        data.append(img)
                        targets.append(self.class_to_idx[identity])

        return data, targets

    def _get_classes(self):
        """Build the identity-name -> class-index mapping from ``lfw-names.txt``."""
        with open(os.path.join(self.root, self.names), 'r') as f:
            lines = f.readlines()
        names = [line.strip().split()[0] for line in lines]
        class_to_idx = {name: i for i, name in enumerate(names)}
        return class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target) where target is the identity of the person.
        """
        img = self._loader(self.data[index])
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def extra_repr(self) -> str:
        return super().extra_repr() + "\nClasses (identities): {}".format(len(self.class_to_idx))


class LFWPairs(_LFW):
    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``lfw-py`` exists or will be saved to if download is set to True.
        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
            ``10fold``. Defaults to ``10fold``.
        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
            ``deepfunneled``. Defaults to ``funneled``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomRotation``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    def __init__(
        self,
        root: str,
        split: str = "10fold",
        image_set: str = "funneled",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        super().__init__(root, split, image_set, "pairs",
                         transform, target_transform, download)

        self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)

    def _get_pairs(self, images_dir):
        """Parse the pairs annotation file.

        Returns:
            tuple: (pair_names, data, targets) — identity-name pairs, image-path
            pairs, and 1/0 same/different labels, in matching order.
        """
        pair_names, data, targets = [], [], []
        with open(os.path.join(self.root, self.labels_file), 'r') as f:
            lines = f.readlines()
        # 10fold headers are "<n_folds>\t<n_pairs>"; Dev headers are "<n_pairs>".
        if self.split == "10fold":
            n_folds, n_pairs = lines[0].split("\t")
            n_folds, n_pairs = int(n_folds), int(n_pairs)
        else:
            n_folds, n_pairs = 1, int(lines[0])
        s = 1

        for _ in range(n_folds):
            # Each fold lists n_pairs matched lines followed by n_pairs unmatched lines.
            matched_pairs = [line.strip().split("\t") for line in lines[s: s + n_pairs]]
            unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs: s + (2 * n_pairs)]]
            s += (2 * n_pairs)
            for pair in matched_pairs:
                # Matched line: "<name>\t<img_no1>\t<img_no2>"
                img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
                # Some annotated images may be missing on disk; skip those pairs.
                if os.path.exists(img1) and os.path.exists(img2):
                    pair_names.append((pair[0], pair[0]))
                    data.append((img1, img2))
                    targets.append(same)
            for pair in unmatched_pairs:
                # Unmatched line: "<name1>\t<img_no1>\t<name2>\t<img_no2>"
                img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[2], pair[3]), 0
                if os.path.exists(img1) and os.path.exists(img2):
                    pair_names.append((pair[0], pair[2]))
                    data.append((img1, img2))
                    targets.append(same)

        return pair_names, data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image1, image2, target) where target is `0` for different identities and `1` for same identities.
        """
        img1, img2 = self.data[index]
        img1, img2 = self._loader(img1), self._loader(img2)
        target = self.targets[index]

        if self.transform is not None:
            img1, img2 = self.transform(img1), self.transform(img2)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img1, img2, target