Skip to content

Commit

Permalink
[fbsync] Port test_datasets_utils to pytest (#4114)
Browse files Browse the repository at this point in the history
Reviewed By: fmassa

Differential Revision: D29395325

fbshipit-source-id: f0313af872dc410cd9d5923c2f360133e6dc82b4
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Jun 28, 2021
1 parent 4673788 commit ecd7404
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 159 deletions.
17 changes: 0 additions & 17 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,23 +240,6 @@ def disable_console_output():
yield


def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
    """Convert a recorded ``(args, kwargs)`` call into a keyword-only mapping.

    Args:
        call_args: A two-item sequence that unpacks into ``(args, kwargs)``,
            where ``args`` holds the positional values and ``kwargs`` is a
            dict of keyword values.
        *callable_or_arg_names: Either a callable as the first item, whose
            positional parameter names are read with
            ``inspect.getfullargspec``, or the positional parameter names
            themselves.

    Returns:
        A new dict: a copy of ``kwargs`` updated with each positional value
        keyed by its parameter name. ``kwargs`` itself is not modified.
        Positional values beyond the known names are ignored, because the
        names and values are paired with ``zip``.

    Raises:
        IndexError: If neither a callable nor any argument name is passed.
    """
    callable_or_arg_name = callable_or_arg_names[0]
    if callable(callable_or_arg_name):
        arg_names = list(inspect.getfullargspec(callable_or_arg_name).args)
        # getfullargspec still lists the implicit first parameter
        # (self / cls) for classes and for bound methods. It is never part
        # of the recorded positional values, so drop it. Slicing keeps this
        # safe when there are no positional parameters at all.
        if isinstance(callable_or_arg_name, type) or inspect.ismethod(callable_or_arg_name):
            arg_names = arg_names[1:]
    else:
        arg_names = callable_or_arg_names

    args, kwargs = call_args
    kwargs_only = kwargs.copy()
    kwargs_only.update(zip(arg_names, args))
    return kwargs_only


def cpu_and_gpu():
    """Return the device values ``'cpu'`` and ``'cuda'`` as pytest parameters.

    Returns:
        A 2-tuple: the plain string ``'cpu'`` and a ``pytest.param`` wrapping
        ``'cuda'`` that carries the ``needs_cuda`` mark.
    """
    # NOTE(review): pytest is imported inside the function, presumably so this
    # module can be imported without pytest installed -- confirm before moving
    # the import to module level.
    import pytest  # noqa
    return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda))
Expand Down
211 changes: 69 additions & 142 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import bz2
import os
import torchvision.datasets.utils as utils
import unittest
import unittest.mock
import pytest
import zipfile
import tarfile
import gzip
Expand All @@ -12,31 +11,32 @@
import itertools
import lzma

from common_utils import get_tmp_dir, call_args_to_kwargs_only
from common_utils import get_tmp_dir
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS


TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):
class TestDatasetsUtils:

def test_check_md5(self):
fpath = TEST_FILE
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = ''
self.assertTrue(utils.check_md5(fpath, correct_md5))
self.assertFalse(utils.check_md5(fpath, false_md5))
assert utils.check_md5(fpath, correct_md5)
assert not utils.check_md5(fpath, false_md5)

def test_check_integrity(self):
existing_fpath = TEST_FILE
nonexisting_fpath = ''
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = ''
self.assertTrue(utils.check_integrity(existing_fpath, correct_md5))
self.assertFalse(utils.check_integrity(existing_fpath, false_md5))
self.assertTrue(utils.check_integrity(existing_fpath))
self.assertFalse(utils.check_integrity(nonexisting_fpath))
assert utils.check_integrity(existing_fpath, correct_md5)
assert not utils.check_integrity(existing_fpath, false_md5)
assert utils.check_integrity(existing_fpath)
assert not utils.check_integrity(nonexisting_fpath)

def test_get_google_drive_file_id(self):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
Expand All @@ -50,44 +50,38 @@ def test_get_google_drive_file_id_invalid_url(self):

assert utils._get_google_drive_file_id(url) is None

def test_detect_file_type(self):
for file, expected in [
("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
("foo.tar", (".tar", ".tar", None)),
("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.tbz", (".tbz", ".tar", ".bz2")),
("foo.tbz2", (".tbz2", ".tar", ".bz2")),
("foo.tgz", (".tgz", ".tar", ".gz")),
("foo.bz2", (".bz2", None, ".bz2")),
("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")),
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.bar.gz", (".gz", None, ".gz")),
("foo.bar.zip", (".zip", ".zip", None)),
]:
with self.subTest(file=file):
self.assertSequenceEqual(utils._detect_file_type(file), expected)

def test_detect_file_type_no_ext(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo")

def test_detect_file_type_unknown_compression(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.tar.baz")

def test_detect_file_type_unknown_partial_ext(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar")

def test_decompress_bz2(self):
@pytest.mark.parametrize('file, expected', [
    ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
    ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
    ("foo.tar", (".tar", ".tar", None)),
    ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
    ("foo.tbz", (".tbz", ".tar", ".bz2")),
    ("foo.tbz2", (".tbz2", ".tar", ".bz2")),
    ("foo.tgz", (".tgz", ".tar", ".gz")),
    ("foo.bz2", (".bz2", None, ".bz2")),
    ("foo.gz", (".gz", None, ".gz")),
    ("foo.zip", (".zip", ".zip", None)),
    ("foo.xz", (".xz", None, ".xz")),
    ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
    ("foo.bar.gz", (".gz", None, ".gz")),
    ("foo.bar.zip", (".zip", ".zip", None))])
def test_detect_file_type(self, file, expected):
    """Check that ``utils._detect_file_type`` returns the expected triple.

    Each case pairs a file name with the 3-tuple the helper must return.
    NOTE(review): judging by the cases, the triple looks like
    (matched suffix, archive type, compression), with ``None`` where a part
    is absent -- confirm against ``_detect_file_type``.
    """
    assert utils._detect_file_type(file) == expected

@pytest.mark.parametrize('file', ["foo", "foo.tar.baz", "foo.bar"])
def test_detect_file_type_incompatible(self, file):
    """Check that ``utils._detect_file_type`` rejects unsupported file names.

    The cases cover, in order: no extension, an unknown compression suffix,
    and an unknown partial extension. Each must raise ``RuntimeError``.
    """
    # tests detect file type for no extension, unknown compression and unknown partial extension
    with pytest.raises(RuntimeError):
        utils._detect_file_type(file)

@pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
def test_decompress(self, extension):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.bz2"
compressed = f"{file}{extension}"
compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

with bz2.open(compressed, "wb") as fh:
with compressed_file_opener(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content
Expand All @@ -97,53 +91,13 @@ def create_compressed(root, content="this is the content"):

utils._decompress(compressed)

self.assertTrue(os.path.exists(file))
assert os.path.exists(file)

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_decompress_gzip(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"

with gzip.open(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content

with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)

utils._decompress(compressed)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_decompress_lzma(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.xz"

with lzma.open(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content

with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)

utils.extract_archive(compressed, temp_dir)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
assert fh.read() == content

def test_decompress_no_compression(self):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
utils._decompress("foo.tar")

def test_decompress_remove_finished(self):
Expand All @@ -161,21 +115,18 @@ def create_compressed(root, content="this is the content"):

utils.extract_archive(compressed, temp_dir, remove_finished=True)

self.assertFalse(os.path.exists(compressed))
assert not os.path.exists(compressed)

def test_extract_archive_defer_to_decompress(self):
@pytest.mark.parametrize('extension', [".gz", ".xz"])
@pytest.mark.parametrize('remove_finished', [True, False])
def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
filename = "foo"
for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)):
with self.subTest(ext=ext, remove_finished=remove_finished):
with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock:
file = f"{filename}{ext}"
utils.extract_archive(file, remove_finished=remove_finished)

mock.assert_called_once()
self.assertEqual(
call_args_to_kwargs_only(mock.call_args, utils._decompress),
dict(from_path=file, to_path=filename, remove_finished=remove_finished),
)
file = f"{filename}{extension}"

mocked = mocker.patch("torchvision.datasets.utils._decompress")
utils.extract_archive(file, remove_finished=remove_finished)

mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

def test_extract_zip(self):
def create_archive(root, content="this is the content"):
Expand All @@ -192,16 +143,18 @@ def create_archive(root, content="this is the content"):

utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))
assert os.path.exists(file)

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
assert fh.read() == content

def test_extract_tar(self):
def create_archive(root, ext, mode, content="this is the content"):
@pytest.mark.parametrize('extension, mode', [
('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
def test_extract_tar(self, extension, mode):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")
archive = os.path.join(root, f"archive{extension}")

with open(src, "w") as fh:
fh.write(content)
Expand All @@ -211,47 +164,21 @@ def create_archive(root, ext, mode, content="this is the content"):

return archive, dst, content

for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']):
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, ext, mode)

utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_extract_tar_xz(self):
def create_archive(root, ext, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")

with open(src, "w") as fh:
fh.write(content)

with tarfile.open(archive, mode=mode) as fh:
fh.add(src, arcname=os.path.basename(dst))

return archive, dst, content

for ext, mode in zip(['.tar.xz'], ['w:xz']):
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, ext, mode)
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, extension, mode)

utils.extract_archive(archive, temp_dir)
utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))
assert os.path.exists(file)

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
with open(file, "r") as fh:
assert fh.read() == content

def test_verify_str_arg(self):
self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
self.assertRaises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
assert "a" == utils.verify_str_arg("a", "arg", ("a",))
pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")


if __name__ == '__main__':
unittest.main()
pytest.main([__file__])

0 comments on commit ecd7404

Please sign in to comment.