Skip to content

Commit

Permalink
Usps dataset (#961)
Browse files Browse the repository at this point in the history
* add USPS dataset

* minor fixes

* Improvements to the USPS dataset

Add it to the documentation, expose it to torchvision.datasets
and inherit from VisionDataset
  • Loading branch information
fmassa authored May 27, 2019
1 parent b45cdbf commit d4a126b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
7 changes: 7 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,10 @@ SBD
.. autoclass:: SBDataset
  :members: __getitem__
  :special-members:

USPS
~~~~~

.. autoclass:: USPS
  :members: __getitem__
  :special-members:
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .celeba import CelebA
from .sbd import SBDataset
from .vision import VisionDataset
from .usps import USPS

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
Expand All @@ -26,4 +27,5 @@
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset')
'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
'USPS')
83 changes: 83 additions & 0 deletions torchvision/datasets/usps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import print_function
from PIL import Image
import os
import numpy as np

from .utils import download_url
from .vision import VisionDataset


class USPS(VisionDataset):
    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.

    The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
    The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
    and make pixel values in ``[0, 255]``.

    Args:
        root (string): Root directory of dataset to store ``USPS`` data files.
        train (bool, optional): If True, creates dataset from ``usps.bz2``,
            otherwise from ``usps.t.bz2``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Attributes:
        data (numpy.ndarray): ``uint8`` array of shape ``(N, 16, 16)`` holding the images.
        targets (list): Integer class labels, one per image, shifted down by one
            from the labels stored in the file.
    """
    # For each split: [download URL, local file name, MD5 checksum].
    split_list = {
        'train': [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
            "usps.bz2", 'ec16c51db3855ca6c91edd34d0e9b197'
        ],
        'test': [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
            "usps.t.bz2", '8ea070ee2aca1ac39742fdd1ef5ed118'
        ],
    }

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Load the requested split into memory, downloading it first if asked to.

        No existence check is made when ``download`` is False, so if the
        archive is missing the error raised by ``bz2.open`` propagates.
        """
        super(USPS, self).__init__(root, transform=transform, target_transform=target_transform)
        split = 'train' if train else 'test'
        url, filename, checksum = self.split_list[split]
        full_path = os.path.join(self.root, filename)

        # Only fetch when the archive is not already present on disk.
        if download and not os.path.exists(full_path):
            download_url(url, self.root, filename, md5=checksum)

        import bz2
        raw_data = []
        with bz2.open(full_path) as fp:
            # Iterate the decompressed stream line by line instead of
            # materialising every line with readlines().
            for line in fp:
                fields = line.decode().split()
                # Skip blank lines (e.g. a trailing empty line) so they cannot
                # produce a ragged pixel array or a missing label.
                if fields:
                    raw_data.append(fields)

        # Every token after the label looks like "index:value"; keep the value.
        imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
        imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
        # Rescale the pixel values to 0..255 and store them as uint8.
        imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
        # Shift the labels stored in the file down by one.
        targets = [int(d[0]) - 1 for d in raw_data]

        self.data = imgs
        self.targets = targets

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img, mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(self.data)

0 comments on commit d4a126b

Please sign in to comment.