Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Caltech101, Caltech256, and CelebA #775

Merged
merged 10 commits into from
Mar 7, 2019
5 changes: 4 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
from .flickr import Flickr8k, Flickr30k
from .voc import VOCSegmentation, VOCDetection
from .cityscapes import Cityscapes
from .caltech import Caltech101, Caltech256
from .celeba import CelebA

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes')
'VOCSegmentation', 'VOCDetection', 'Cityscapes',
'Caltech101', 'Caltech256', 'CelebA')
245 changes: 245 additions & 0 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
import collections

import torch.utils.data as data
from .utils import download_url, check_integrity, makedir_exist_ok


class Caltech101(data.Dataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified target types.
``category`` represents the target class, and ``annotation`` is a list of points
from a hand-generated outline. Defaults to ``category``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

def __init__(self, root, target_type="category",
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech101")
makedir_exist_ok(self.root)
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
self.categories.remove("BACKGROUND_Google") # this is not a real class

# For some reason, the category names in "101_ObjectCategories" and
# "Annotations" do not always match. This is a manual map between the
# two. Defaults to using same name, since most names are fine.
name_map = {"Faces": "Faces_2",
"Faces_easy": "Faces_3",
"Motorbikes": "Motorbikes_16",
"airplanes": "Airplanes_Side_2"}
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))

self.index = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where the type of target specified by target_type.
"""
import scipy.io

img = Image.open(os.path.join(self.root,
"101_ObjectCategories",
self.categories[self.y[index]],
"image_{:04d}.jpg".format(self.index[index])))

target = []
for t in self.target_type:
if t == "category":
target.append(self.y[index])
elif t == "annotation":
data = scipy.io.loadmat(os.path.join(self.root,
"Annotations",
self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index])))
target.append(data["obj_contour"])
else:
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def _check_integrity(self):
# can be more robust and check hash of files
return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

def __len__(self):
return len(self.index)

def download(self):
import tarfile

if self._check_integrity():
print('Files already downloaded and verified')
return

download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root,
"101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9")
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root,
"101_Annotations.tar",
"6f83eeb1f24d99cab4eb377263132c91")

# extract file
with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
tar.extractall(path=self.root)

with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Target type: {}\n'.format(self.target_type)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str


class Caltech256(data.Dataset):
"""`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

def __init__(self, root,
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech256")
makedir_exist_ok(self.root)
self.transform = transform
self.target_transform = target_transform

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
img = Image.open(os.path.join(self.root,
"256_ObjectCategories",
self.categories[self.y[index]],
"{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index])))

target = self.y[index]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def _check_integrity(self):
# can be more robust and check hash of files
return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

def __len__(self):
return len(self.index)

def download(self):
import tarfile

if self._check_integrity():
print('Files already downloaded and verified')
return

download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root,
"256_ObjectCategories.tar",
"67b4f42ca05d46448c6bb8ecd2220f6d")

# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
Loading