diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index ee4f30d4f18..19a5c6d3b20 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -3,11 +3,11 @@ from PIL import Image import os import os.path -import errno +import gzip import numpy as np import torch import codecs -from .utils import download_url +from .utils import download_url, makedir_exist_ok class MNIST(data.Dataset): @@ -32,13 +32,10 @@ class MNIST(data.Dataset): 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', ] - raw_folder = 'raw' - processed_folder = 'processed' training_file = 'training.pt' test_file = 'test.pt' classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] - class_to_idx = {_class: i for i, _class in enumerate(classes)} def __init__(self, root, train=True, transform=None, target_transform=None, download=False): self.root = os.path.expanduser(root) @@ -57,7 +54,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down data_file = self.training_file else: data_file = self.test_file - self.data, self.targets = torch.load(os.path.join(self.root, self.processed_folder, data_file)) + self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) def __getitem__(self, index): """ @@ -84,51 +81,61 @@ def __getitem__(self, index): def __len__(self): return len(self.data) + @property + def raw_folder(self): + return os.path.join(self.root, self.__class__.__name__, 'raw') + + @property + def processed_folder(self): + return os.path.join(self.root, self.__class__.__name__, 'processed') + + @property + def class_to_idx(self): + return {_class: i for i, _class in enumerate(self.classes)} + def _check_exists(self): - return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ - os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) + return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \ + os.path.exists(os.path.join(self.processed_folder, self.test_file)) + + @staticmethod + def extract_gzip(gzip_path, remove_finished=False): + print('Extracting {}'.format(gzip_path)) + with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \ + gzip.GzipFile(gzip_path) as zip_f: + out_f.write(zip_f.read()) + if remove_finished: + os.unlink(gzip_path) def download(self): """Download the MNIST data if it doesn't exist in processed_folder already.""" - import gzip if self._check_exists(): return - # download files - try: - os.makedirs(os.path.join(self.root, self.raw_folder)) - os.makedirs(os.path.join(self.root, self.processed_folder)) - except OSError as e: - if e.errno == errno.EEXIST: - pass - else: - raise + makedir_exist_ok(self.raw_folder) + makedir_exist_ok(self.processed_folder) + # download files for url in self.urls: filename = url.rpartition('/')[2] - file_path = os.path.join(self.root, self.raw_folder, filename) - download_url(url, root=os.path.join(self.root, self.raw_folder), - filename=filename, md5=None) - with open(file_path.replace('.gz', ''), 'wb') as out_f, \ - gzip.GzipFile(file_path) as zip_f: - out_f.write(zip_f.read()) - os.unlink(file_path) + file_path = os.path.join(self.raw_folder, filename) + download_url(url, root=self.raw_folder, filename=filename, md5=None) + self.extract_gzip(gzip_path=file_path, remove_finished=True) # process and save as torch files print('Processing...') training_set = ( - read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')), - read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte')) + read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')), + read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')) ) test_set = ( - read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')), - read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte')) + read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')), + read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')) ) - with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f: + with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f: torch.save(training_set, f) - with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f: + with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f: torch.save(test_set, f) print('Done!') @@ -170,7 +177,6 @@ class FashionMNIST(MNIST): ] classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] - class_to_idx = {_class: i for i, _class in enumerate(classes)} class EMNIST(MNIST): @@ -205,64 +211,55 @@ def __init__(self, root, split, **kwargs): self.test_file = self._test_file(split) super(EMNIST, self).__init__(root, **kwargs) - def _training_file(self, split): + @staticmethod + def _training_file(split): return 'training_{}.pt'.format(split) - def _test_file(self, split): + @staticmethod + def _test_file(split): return 'test_{}.pt'.format(split) def download(self): """Download the EMNIST data if it doesn't exist in processed_folder already.""" - import gzip import shutil import zipfile if self._check_exists(): return - # download files - try: - os.makedirs(os.path.join(self.root, self.raw_folder)) - os.makedirs(os.path.join(self.root, self.processed_folder)) - except OSError as e: - if e.errno == errno.EEXIST: - pass - else: - raise + makedir_exist_ok(self.raw_folder) + makedir_exist_ok(self.processed_folder) + # download files filename = self.url.rpartition('/')[2] - raw_folder = os.path.join(self.root, self.raw_folder) - file_path = os.path.join(raw_folder, filename) - download_url(self.url, root=file_path, filename=filename, md5=None) + file_path = os.path.join(self.raw_folder, filename) + download_url(self.url, root=self.raw_folder, filename=filename, md5=None) print('Extracting zip archive') with zipfile.ZipFile(file_path) as zip_f: - zip_f.extractall(raw_folder) + zip_f.extractall(self.raw_folder) os.unlink(file_path) - gzip_folder = os.path.join(raw_folder, 'gzip') + gzip_folder = os.path.join(self.raw_folder, 'gzip') for gzip_file in os.listdir(gzip_folder): if gzip_file.endswith('.gz'): - print('Extracting ' + gzip_file) - with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \ - gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f: - out_f.write(zip_f.read()) - shutil.rmtree(gzip_folder) + self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file)) # process and save as torch files for split in self.splits: print('Processing ' + split) training_set = ( - read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))), - read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split))) + read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))), + read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split))) ) test_set = ( - read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))), - read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split))) + read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))), + read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split))) ) - with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f: + with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f: torch.save(training_set, f) - with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f: + with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f: torch.save(test_set, f) + shutil.rmtree(gzip_folder) print('Done!') diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 43e5896801a..39b5a4173e5 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -31,20 +31,27 @@ def check_integrity(fpath, md5=None): return True -def download_url(url, root, filename, md5): - from six.moves import urllib - - root = os.path.expanduser(root) - fpath = os.path.join(root, filename) - +def makedir_exist_ok(dirpath): + """ + Python2 support for os.makedirs(.., exist_ok=True) + """ try: - os.makedirs(root) + os.makedirs(dirpath) except OSError as e: if e.errno == errno.EEXIST: pass else: raise + +def download_url(url, root, filename, md5): + from six.moves import urllib + + root = os.path.expanduser(root) + fpath = os.path.join(root, filename) + + makedir_exist_ok(root) + # downloads file if os.path.isfile(fpath) and check_integrity(fpath, md5): print('Using downloaded and verified file: ' + fpath)