Skip to content

Commit

Permalink
[WIP] Add tests for datasets (#966)
Browse files Browse the repository at this point in the history
* WIP

* WIP: minor improvements

* Add tests

* Fix typo

* Use download_and_extract on caltech, cifar and omniglot

* Add a print message during extraction

* Remove EMNIST from test
  • Loading branch information
fmassa authored May 29, 2019
1 parent 2b3a1b6 commit c59f047
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 68 deletions.
38 changes: 38 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import PIL
import shutil
import tempfile
import unittest

import torchvision


class Tester(unittest.TestCase):
    """Download-and-load smoke tests for the MNIST-family datasets.

    Each test downloads a dataset into a fresh temporary directory,
    checks the first sample, and removes the directory again.
    """

    def setUp(self):
        # Fresh download root per test.  Cleanup happens in tearDown so the
        # directory is removed even when an assertion fails (the original
        # inline rmtree leaked the directory on failure).
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    def _check_first_sample(self, dataset):
        # All datasets in this family yield (PIL image, int label) pairs.
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def test_mnist(self):
        dataset = torchvision.datasets.MNIST(self.tmp_dir, download=True)
        self.assertEqual(len(dataset), 60000)
        self._check_first_sample(dataset)

    def test_kmnist(self):
        dataset = torchvision.datasets.KMNIST(self.tmp_dir, download=True)
        self._check_first_sample(dataset)

    def test_fashionmnist(self):
        dataset = torchvision.datasets.FashionMNIST(self.tmp_dir, download=True)
        self._check_first_sample(dataset)


if __name__ == '__main__':
    unittest.main()
44 changes: 44 additions & 0 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import tempfile
import torchvision.datasets.utils as utils
import unittest
import zipfile
import tarfile
import gzip

TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'grace_hopper_517x606.jpg')
Expand Down Expand Up @@ -41,6 +44,47 @@ def test_download_url_retry_http(self):
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
shutil.rmtree(temp_dir)

def test_extract_zip(self):
    """extract_file on a .zip archive must restore the archived file."""
    temp_dir = tempfile.mkdtemp()
    try:
        with tempfile.NamedTemporaryFile(suffix='.zip') as f:
            with zipfile.ZipFile(f, 'w') as zf:
                zf.writestr('file.tst', 'this is the content')
            utils.extract_file(f.name, temp_dir)
            extracted = os.path.join(temp_dir, 'file.tst')
            assert os.path.exists(extracted)
            with open(extracted, 'r') as nf:
                assert nf.read() == 'this is the content'
    finally:
        # Clean up even when an assertion above fails; the original
        # left the temporary directory behind on failure.
        shutil.rmtree(temp_dir)

def test_extract_tar(self):
    """extract_file handles both plain and gzip-compressed tarballs."""
    for ext, mode in [('.tar', 'w'), ('.tar.gz', 'w:gz')]:
        temp_dir = tempfile.mkdtemp()
        try:
            with tempfile.NamedTemporaryFile() as bf:
                bf.write('this is the content'.encode())
                bf.seek(0)
                with tempfile.NamedTemporaryFile(suffix=ext) as f:
                    with tarfile.open(f.name, mode=mode) as tf:
                        tf.add(bf.name, arcname='file.tst')
                    utils.extract_file(f.name, temp_dir)
                    extracted = os.path.join(temp_dir, 'file.tst')
                    assert os.path.exists(extracted)
                    with open(extracted, 'r') as nf:
                        assert nf.read() == 'this is the content', nf
        finally:
            # Clean up even when an assertion fails; the original left
            # the temporary directory behind on failure.
            shutil.rmtree(temp_dir)

def test_extract_gzip(self):
    """extract_file on a bare .gz writes <to_path>/<stem> with the data."""
    temp_dir = tempfile.mkdtemp()
    try:
        with tempfile.NamedTemporaryFile(suffix='.gz') as f:
            with gzip.GzipFile(f.name, 'wb') as zf:
                zf.write('this is the content'.encode())
            utils.extract_file(f.name, temp_dir)
            # A bare gzip is decompressed to the archive name minus '.gz'.
            stem = os.path.splitext(os.path.basename(f.name))[0]
            extracted = os.path.join(temp_dir, stem)
            assert os.path.exists(extracted)
            with open(extracted, 'r') as nf:
                assert nf.read() == 'this is the content', extracted
    finally:
        # Clean up even when an assertion fails; the original left
        # the temporary directory behind on failure.
        shutil.rmtree(temp_dir)


if __name__ == '__main__':
unittest.main()
44 changes: 16 additions & 28 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path

from .vision import VisionDataset
from .utils import download_url, makedir_exist_ok
from .utils import download_and_extract, makedir_exist_ok


class Caltech101(VisionDataset):
Expand Down Expand Up @@ -109,27 +109,20 @@ def __len__(self):
return len(self.index)

def download(self):
    """Download and extract the Caltech101 image and annotation archives.

    No-op (beyond a message) when the files are already present and their
    checksums verify.
    """
    if self._check_integrity():
        print('Files already downloaded and verified')
        return

    download_and_extract(
        "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
        self.root,
        "101_ObjectCategories.tar.gz",
        "b224c7392d521a49829488ab0f1120d9")
    download_and_extract(
        "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
        self.root,
        "101_Annotations.tar",
        "6f83eeb1f24d99cab4eb377263132c91")

def extra_repr(self):
    # Surface the configured annotation target type in the dataset repr.
    return "Target type: {}".format(self.target_type)
Expand Down Expand Up @@ -204,17 +197,12 @@ def __len__(self):
return len(self.index)

def download(self):
    """Download and extract the Caltech256 image archive.

    Skips the download when the files are already present and verified.
    """
    if self._check_integrity():
        print('Files already downloaded and verified')
        return

    download_and_extract(
        "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
        self.root,
        "256_ObjectCategories.tar",
        "67b4f42ca05d46448c6bb8ecd2220f6d")
11 changes: 2 additions & 9 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pickle

from .vision import VisionDataset
from .utils import download_url, check_integrity
from .utils import check_integrity, download_and_extract


class CIFAR10(VisionDataset):
Expand Down Expand Up @@ -144,17 +144,10 @@ def _check_integrity(self):
return True

def download(self):
    """Download and extract the CIFAR archive into ``self.root``.

    Skips the download when the archive is already present and its
    checksum verifies.
    """
    if self._check_integrity():
        print('Files already downloaded and verified')
        return
    download_and_extract(self.url, self.root, self.filename, self.tgz_md5)

def extra_repr(self):
    # Report which split this instance serves in the dataset repr.
    split = "Train" if self.train is True else "Test"
    return "Split: {}".format(split)
Expand Down
29 changes: 5 additions & 24 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from PIL import Image
import os
import os.path
import gzip
import numpy as np
import torch
import codecs
from .utils import download_url, makedir_exist_ok
from .utils import download_and_extract, extract_file, makedir_exist_ok


class MNIST(VisionDataset):
Expand Down Expand Up @@ -120,15 +119,6 @@ def _check_exists(self):
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))

@staticmethod
def extract_gzip(gzip_path, remove_finished=False):
print('Extracting {}'.format(gzip_path))
with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(gzip_path) as zip_f:
out_f.write(zip_f.read())
if remove_finished:
os.unlink(gzip_path)

def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""

Expand All @@ -141,9 +131,7 @@ def download(self):
# download files
for url in self.urls:
filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
download_url(url, root=self.raw_folder, filename=filename, md5=None)
self.extract_gzip(gzip_path=file_path, remove_finished=True)
download_and_extract(url, root=self.raw_folder, filename=filename)

# process and save as torch files
print('Processing...')
Expand Down Expand Up @@ -262,7 +250,6 @@ def _test_file(split):
def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil
import zipfile

if self._check_exists():
return
Expand All @@ -271,18 +258,12 @@ def download(self):
makedir_exist_ok(self.processed_folder)

# download files
filename = self.url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
download_url(self.url, root=self.raw_folder, filename=filename, md5=None)

print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f:
zip_f.extractall(self.raw_folder)
os.unlink(file_path)
print('Downloading and extracting zip archive')
download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True)
gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))
extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder)

# process and save as torch files
for split in self.splits:
Expand Down
9 changes: 2 additions & 7 deletions torchvision/datasets/omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from os.path import join
import os
from .vision import VisionDataset
from .utils import download_url, check_integrity, list_dir, list_files
from .utils import download_and_extract, check_integrity, list_dir, list_files


class Omniglot(VisionDataset):
Expand Down Expand Up @@ -81,19 +81,14 @@ def _check_integrity(self):
return True

def download(self):
    """Download and extract this split's zip archive into ``self.root``.

    Skips the download when the files are already present and verified.
    """
    if self._check_integrity():
        print('Files already downloaded and verified')
        return

    filename = self._get_target_folder()
    zip_filename = filename + '.zip'
    url = self.download_url_prefix + '/' + zip_filename
    download_and_extract(url, self.root, zip_filename, self.zips_md5[filename])

def _get_target_folder(self):
    # The archive/folder name follows the selected alphabet split.
    if self.background:
        return 'images_background'
    return 'images_evaluation'
47 changes: 47 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import errno
import gzip
import hashlib
import os
import os.path
import shutil
import tarfile
import zipfile

from torch.utils.model_zoo import tqdm


Expand Down Expand Up @@ -189,3 +193,46 @@ def _save_response_content(response, destination, chunk_size=32768):
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()


def _is_tar(filename):
    """Return True for an uncompressed ``.tar`` archive name."""
    return filename.endswith(".tar")


def _is_targz(filename):
    """Return True for a gzip-compressed tarball (``.tar.gz``)."""
    return filename.endswith(".tar.gz")


def _is_gzip(filename):
    """Return True for a bare gzip file: ``.gz`` but not ``.tar.gz``."""
    return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename):
    """Return True for a ``.zip`` archive name."""
    return filename.endswith(".zip")


def extract_file(from_path, to_path, remove_finished=False):
    """Extract the archive at ``from_path`` into the directory ``to_path``.

    The format is detected from the file name; ``.zip``, ``.tar``,
    ``.tar.gz`` and plain ``.gz`` are supported.  A bare gzip file is
    decompressed to ``to_path/<basename minus its final extension>``.

    Args:
        from_path (str): path of the archive to extract.
        to_path (str): directory that receives the extracted contents.
        remove_finished (bool): if True, delete the archive afterwards.

    Raises:
        ValueError: if the file name matches none of the known formats.
    """
    if _is_tar(from_path):
        with tarfile.open(from_path, 'r:') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        # Stream the decompression instead of zip_f.read(): the raw dataset
        # files can be hundreds of MB and need not fit in memory at once.
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.unlink(from_path)


def download_and_extract(url, root, filename, md5=None, remove_finished=False):
    """Download ``url`` to ``root/filename`` (verifying ``md5``) and extract it into ``root``."""
    archive = os.path.join(root, filename)
    download_url(url, root, filename, md5)
    print("Extracting {} to {}".format(archive, root))
    extract_file(archive, root, remove_finished)

0 comments on commit c59f047

Please sign in to comment.