Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNIST and FashionMNIST now have their own 'raw' and 'processed' folders #601

Merged
merged 2 commits into from
Sep 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 57 additions & 60 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from PIL import Image
import os
import os.path
import errno
import gzip
fmassa marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import torch
import codecs
from .utils import download_url
from .utils import download_url, makedir_exist_ok


class MNIST(data.Dataset):
Expand All @@ -32,13 +32,10 @@ class MNIST(data.Dataset):
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
]
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'training.pt'
test_file = 'test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
class_to_idx = {_class: i for i, _class in enumerate(classes)}

def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root)
Expand All @@ -57,7 +54,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
data_file = self.training_file
else:
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.root, self.processed_folder, data_file))
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

def __getitem__(self, index):
"""
Expand All @@ -84,51 +81,61 @@ def __getitem__(self, index):
def __len__(self):
return len(self.data)

@property
def raw_folder(self):
    """Path of the per-dataset raw folder: ``<root>/<ClassName>/raw``.

    The class name of the instance is part of the path, so each subclass
    resolves to its own directory under the same ``root``.
    """
    dataset_name = self.__class__.__name__
    return os.path.join(self.root, dataset_name, 'raw')

@property
def processed_folder(self):
    """Path of the per-dataset processed folder: ``<root>/<ClassName>/processed``.

    The class name of the instance is part of the path, so each subclass
    resolves to its own directory under the same ``root``.
    """
    dataset_name = self.__class__.__name__
    return os.path.join(self.root, dataset_name, 'processed')

@property
def class_to_idx(self):
    """Map each entry of ``self.classes`` to its position in that sequence.

    Computed on access from ``self.classes``, so a subclass that overrides
    ``classes`` gets a matching mapping without redefining this attribute.
    """
    return dict((name, position) for position, name in enumerate(self.classes))

def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \
os.path.exists(os.path.join(self.processed_folder, self.test_file))

@staticmethod
def extract_gzip(gzip_path, remove_finished=False):
    """Decompress a ``.gz`` file next to itself, dropping the ``.gz`` suffix.

    Args:
        gzip_path (str): Path of the gzip archive; must end in ``.gz``.
        remove_finished (bool): If True, delete the archive once it has
            been extracted.

    Raises:
        ValueError: If ``gzip_path`` does not end in ``.gz``. With no suffix
            to strip, the output path would equal the input path and opening
            it for writing would truncate the archive before it is read.
    """
    # Local import keeps this block self-contained.
    import shutil

    suffix = '.gz'
    if not gzip_path.endswith(suffix):
        raise ValueError('Expected a path ending in .gz, got {}'.format(gzip_path))
    # Strip only the trailing suffix: str.replace would also rewrite any
    # '.gz' occurring earlier in the path (e.g. inside a directory name).
    out_path = gzip_path[:-len(suffix)]

    print('Extracting {}'.format(gzip_path))
    with open(out_path, 'wb') as out_f, gzip.GzipFile(gzip_path) as zip_f:
        # Stream in chunks rather than holding the whole decompressed
        # payload in memory at once.
        shutil.copyfileobj(zip_f, out_f)
    if remove_finished:
        os.unlink(gzip_path)

def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
import gzip

if self._check_exists():
return

# download files
try:
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
makedir_exist_ok(self.raw_folder)
makedir_exist_ok(self.processed_folder)

# download files
for url in self.urls:
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
download_url(url, root=os.path.join(self.root, self.raw_folder),
filename=filename, md5=None)
with open(file_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(file_path) as zip_f:
out_f.write(zip_f.read())
os.unlink(file_path)
file_path = os.path.join(self.raw_folder, filename)
download_url(url, root=self.raw_folder, filename=filename, md5=None)
self.extract_gzip(gzip_path=file_path, remove_finished=True)

# process and save as torch files
print('Processing...')

training_set = (
read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
)
test_set = (
read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
)
with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
torch.save(test_set, f)

print('Done!')
Expand Down Expand Up @@ -170,7 +177,6 @@ class FashionMNIST(MNIST):
]
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
class_to_idx = {_class: i for i, _class in enumerate(classes)}


class EMNIST(MNIST):
Expand Down Expand Up @@ -205,64 +211,55 @@ def __init__(self, root, split, **kwargs):
self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs)

def _training_file(self, split):
@staticmethod
def _training_file(split):
return 'training_{}.pt'.format(split)

def _test_file(self, split):
@staticmethod
def _test_file(split):
return 'test_{}.pt'.format(split)

def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import gzip
import shutil
import zipfile

if self._check_exists():
return

# download files
try:
fmassa marked this conversation as resolved.
Show resolved Hide resolved
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
makedir_exist_ok(self.raw_folder)
makedir_exist_ok(self.processed_folder)

# download files
filename = self.url.rpartition('/')[2]
raw_folder = os.path.join(self.root, self.raw_folder)
file_path = os.path.join(raw_folder, filename)
download_url(self.url, root=file_path, filename=filename, md5=None)
file_path = os.path.join(self.raw_folder, filename)
download_url(self.url, root=self.raw_folder, filename=filename, md5=None)

print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f:
zip_f.extractall(raw_folder)
zip_f.extractall(self.raw_folder)
os.unlink(file_path)
gzip_folder = os.path.join(raw_folder, 'gzip')
gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
print('Extracting ' + gzip_file)
with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
out_f.write(zip_f.read())
shutil.rmtree(gzip_folder)
self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))

# process and save as torch files
for split in self.splits:
print('Processing ' + split)
training_set = (
read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
)
test_set = (
read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
)
with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
torch.save(test_set, f)
shutil.rmtree(gzip_folder)

print('Done!')

Expand Down
21 changes: 14 additions & 7 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,27 @@ def check_integrity(fpath, md5=None):
return True


def download_url(url, root, filename, md5):
from six.moves import urllib

root = os.path.expanduser(root)
fpath = os.path.join(root, filename)

def makedir_exist_ok(dirpath):
"""
Python2 support for os.makedirs(.., exist_ok=True)
"""
try:
os.makedirs(root)
os.makedirs(dirpath)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise


def download_url(url, root, filename, md5):
from six.moves import urllib

root = os.path.expanduser(root)
fpath = os.path.join(root, filename)

makedir_exist_ok(root)

# downloads file
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
Expand Down