-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add USPS dataset * minor fixes * Improvements to the USPS dataset Add it to the documentation, expose it to torchvision.datasets and inherit from VisionDataset
- Loading branch information
Showing
3 changed files
with
93 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from __future__ import print_function | ||
from PIL import Image | ||
import os | ||
import numpy as np | ||
|
||
from .utils import download_url | ||
from .vision import VisionDataset | ||
|
||
|
||
class USPS(VisionDataset): | ||
"""`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset. | ||
The data-format is : [label [index:value ]*256 \n] * num_lines, where ``label`` lies in ``[1, 10]``. | ||
The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]`` | ||
and make pixel values in ``[0, 255]``. | ||
Args: | ||
root (string): Root directory of dataset to store``USPS`` data files. | ||
train (bool, optional): If True, creates dataset from ``usps.bz2``, | ||
otherwise from ``usps.t.bz2``. | ||
transform (callable, optional): A function/transform that takes in an PIL image | ||
and returns a transformed version. E.g, ``transforms.RandomCrop`` | ||
target_transform (callable, optional): A function/transform that takes in the | ||
target and transforms it. | ||
download (bool, optional): If true, downloads the dataset from the internet and | ||
puts it in root directory. If dataset is already downloaded, it is not | ||
downloaded again. | ||
""" | ||
split_list = { | ||
'train': [ | ||
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2", | ||
"usps.bz2", 'ec16c51db3855ca6c91edd34d0e9b197' | ||
], | ||
'test': [ | ||
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2", | ||
"usps.t.bz2", '8ea070ee2aca1ac39742fdd1ef5ed118' | ||
], | ||
} | ||
|
||
def __init__(self, root, train=True, transform=None, target_transform=None, download=False): | ||
super(USPS, self).__init__(root, transform=transform, target_transform=target_transform) | ||
split = 'train' if train else 'test' | ||
url, filename, checksum = self.split_list[split] | ||
full_path = os.path.join(self.root, filename) | ||
|
||
if download and not os.path.exists(full_path): | ||
download_url(url, self.root, filename, md5=checksum) | ||
|
||
import bz2 | ||
with bz2.open(full_path) as fp: | ||
raw_data = [l.decode().split() for l in fp.readlines()] | ||
imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] | ||
imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16)) | ||
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) | ||
targets = [int(d[0]) - 1 for d in raw_data] | ||
|
||
self.data = imgs | ||
self.targets = targets | ||
|
||
def __getitem__(self, index): | ||
""" | ||
Args: | ||
index (int): Index | ||
Returns: | ||
tuple: (image, target) where target is index of the target class. | ||
""" | ||
img, target = self.data[index], int(self.targets[index]) | ||
|
||
# doing this so that it is consistent with all other datasets | ||
# to return a PIL Image | ||
img = Image.fromarray(img, mode='L') | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
if self.target_transform is not None: | ||
target = self.target_transform(target) | ||
|
||
return img, target | ||
|
||
def __len__(self): | ||
return len(self.data) |