def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
    """Normalize a recorded call into a keyword-only mapping.

    Positional arguments of the call are paired, in order, with parameter names so that the result can be
    compared against a plain ``dict`` no matter whether an argument was passed positionally or by keyword.

    Args:
        call_args: Pair ``(args, kwargs)`` describing the call.
        *callable_or_arg_names: Either a single callable whose signature provides the names of the positional
            parameters, or the parameter names themselves.

    Returns:
        dict: Copy of ``kwargs`` updated with the positional arguments keyed by their parameter names. Positional
        arguments that have no matching name are dropped.

    Raises:
        TypeError: If neither a callable nor any argument name is passed.
    """
    if not callable_or_arg_names:
        raise TypeError("call_args_to_kwargs_only() needs a callable or at least one argument name.")

    callable_or_arg_name = callable_or_arg_names[0]
    if callable(callable_or_arg_name):
        # inspect.signature() already omits the implicit first parameter of classes and bound methods, so no
        # manual removal of 'self' is needed.
        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        arg_names = [
            name
            for name, parameter in inspect.signature(callable_or_arg_name).parameters.items()
            if parameter.kind in positional_kinds
        ]
    else:
        arg_names = callable_or_arg_names

    args, kwargs = call_args
    kwargs_only = dict(kwargs)
    kwargs_only.update(zip(arg_names, args))
    return kwargs_only
def test_detect_file_type(self):
    # Each case maps a file name to the expected (suffix, archive type, compression) triple.
    for file, expected in [
        ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
        ("foo.tar", (".tar", ".tar", None)),
        ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
        ("foo.tgz", (".tgz", ".tar", ".gz")),
        ("foo.gz", (".gz", None, ".gz")),
        ("foo.zip", (".zip", ".zip", None)),
        ("foo.xz", (".xz", None, ".xz")),
    ]:
        with self.subTest(file=file):
            self.assertSequenceEqual(utils._detect_file_type(file), expected)


def test_detect_file_type_no_ext(self):
    with self.assertRaises(RuntimeError):
        utils._detect_file_type("foo")


def test_detect_file_type_to_many_exts(self):
    # More than two suffixes cannot be split into archive type and compression.
    with self.assertRaises(RuntimeError):
        utils._detect_file_type("foo.bar.tar.gz")


def test_detect_file_type_unknown_archive_type(self):
    with self.assertRaises(RuntimeError):
        utils._detect_file_type("foo.bar.gz")


def test_detect_file_type_unknown_compression(self):
    with self.assertRaises(RuntimeError):
        utils._detect_file_type("foo.tar.baz")


def test_detect_file_type_unknown_partial_ext(self):
    with self.assertRaises(RuntimeError):
        utils._detect_file_type("foo.bar")


def test_decompress_gzip(self):
    def create_compressed(root, content="this is the content"):
        file = os.path.join(root, "file")
        compressed = f"{file}.gz"

        with gzip.open(compressed, "wb") as fh:
            fh.write(content.encode())

        return compressed, file, content

    with get_tmp_dir() as temp_dir:
        compressed, file, content = create_compressed(temp_dir)

        utils._decompress(compressed)

        self.assertTrue(os.path.exists(file))

        with open(file, "r") as fh:
            self.assertEqual(fh.read(), content)


def test_decompress_lzma(self):
    def create_compressed(root, content="this is the content"):
        file = os.path.join(root, "file")
        compressed = f"{file}.xz"

        with lzma.open(compressed, "wb") as fh:
            fh.write(content.encode())

        return compressed, file, content

    with get_tmp_dir() as temp_dir:
        compressed, file, content = create_compressed(temp_dir)

        # Call _decompress() directly, like the gzip variant, so that this test covers the function it is
        # named after instead of the extract_archive() dispatch.
        utils._decompress(compressed)

        self.assertTrue(os.path.exists(file))

        with open(file, "r") as fh:
            self.assertEqual(fh.read(), content)


def test_decompress_no_compression(self):
    # A plain tar archive has an archive type but no compression to undo.
    with self.assertRaises(RuntimeError):
        utils._decompress("foo.tar")


def test_decompress_remove_finished(self):
    def create_compressed(root, content="this is the content"):
        file = os.path.join(root, "file")
        compressed = f"{file}.gz"

        with gzip.open(compressed, "wb") as fh:
            fh.write(content.encode())

        return compressed, file, content

    with get_tmp_dir() as temp_dir:
        compressed, file, content = create_compressed(temp_dir)

        utils._decompress(compressed, remove_finished=True)

        self.assertTrue(os.path.exists(file))
        self.assertFalse(os.path.exists(compressed))


def test_extract_archive_defer_to_decompress(self):
    filename = "foo"
    for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)):
        with self.subTest(ext=ext, remove_finished=remove_finished):
            with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock:
                file = f"{filename}{ext}"
                utils.extract_archive(file, remove_finished=remove_finished)

            # The patch is undone here, so utils._decompress is the real function again and its signature can
            # be used to name the positional arguments of the recorded call.
            mock.assert_called_once()
            self.assertEqual(
                call_args_to_kwargs_only(mock.call_args, utils._decompress),
                dict(from_path=file, to_path=filename, remove_finished=remove_finished),
            )
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
    """Extract a tar archive into a directory.

    Args:
        from_path (str): Path to the tar archive.
        to_path (str): Path to the directory the members are extracted to.
        compression (str, optional): Compression suffix including the leading dot, e.g. ``".gz"``. ``None`` means
            the archive is opened without naming a compression.
    """
    # tarfile selects the codec through the mode string, which takes the suffix without its leading dot.
    mode = f"r:{compression[1:]}" if compression else "r"
    with tarfile.open(from_path, mode) as tar:
        # NOTE(review): member paths are not validated here; confirm that only trusted archives are passed or
        # add an extraction filter once the minimum supported Python version provides one.
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
    """Extract a zip archive into a directory.

    Args:
        from_path (str): Path to the zip archive.
        to_path (str): Path to the directory the members are extracted to.
        compression (str, optional): Compression suffix including the leading dot. ``None`` selects
            ``zipfile.ZIP_STORED``.

    Raises:
        RuntimeError: If ``compression`` has no zip equivalent.
    """
    if compression is None:
        method = zipfile.ZIP_STORED
    elif compression in _ZIP_COMPRESSION_MAP:
        method = _ZIP_COMPRESSION_MAP[compression]
    else:
        # A suffix can pass the generic compression check (e.g. ".gz") without being usable for zip files.
        valid_types = "', '".join(_ZIP_COMPRESSION_MAP)
        raise RuntimeError(
            f"Compression '{compression}' is not supported for zip archives. "
            f"Supported compressions are '{valid_types}'."
        )

    with zipfile.ZipFile(from_path, "r", compression=method) as archive:
        archive.extractall(to_path)


# Extractor per archive suffix. Each one is called as extractor(from_path, to_path, compression).
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
# Opener per compression suffix. Each one is called as opener(path, mode) and used as a context manager.
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {".gz": gzip.open, ".xz": lzma.open}
# Single suffixes that stand for an (archive type, compression) pair.
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")}


def _verify_archive_type(archive_type: str) -> None:
    """Raise a ``RuntimeError`` if no extractor is registered for ``archive_type``."""
    if archive_type not in _ARCHIVE_EXTRACTORS:
        valid_types = "', '".join(_ARCHIVE_EXTRACTORS)
        raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")


def _verify_compression(compression: str) -> None:
    """Raise a ``RuntimeError`` if no opener is registered for ``compression``."""
    if compression not in _COMPRESSED_FILE_OPENERS:
        valid_types = "', '".join(_COMPRESSED_FILE_OPENERS)
        raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and the compression of a file from its suffixes.

    Args:
        file (str): Name of or path to the file.

    Returns:
        (tuple): Triple of the complete suffix, the archive type, and the compression. The archive type or the
        compression is ``None`` if the file is only compressed or only an archive, respectively.

    Raises:
        RuntimeError: If the file has no suffix, more than two suffixes, or a suffix that is neither a known
            archive type, a known compression, nor a known alias.
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    if len(suffixes) > 2:
        raise RuntimeError(
            "Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
        )
    if len(suffixes) == 2:
        # With exactly two suffixes the first one is taken as the archive type and the second one as the
        # compression.
        archive_type, compression = suffixes
        _verify_archive_type(archive_type)
        _verify_compression(compression)
        return "".join(suffixes), archive_type, compression

    (suffix,) = suffixes

    # A single suffix is, in this order, a known alias, an archive type, or a compression.
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None
    if suffix in _COMPRESSED_FILE_OPENERS:
        return suffix, None, suffix

    raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")


def _replace_last(text: str, old: str, new: str) -> str:
    """Return ``text`` with only the last occurrence of ``old`` replaced by ``new``."""
    head, separator, tail = text.rpartition(old)
    if not separator:
        return text
    return f"{head}{new}{tail}"


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.

    Raises:
        RuntimeError: If no compression can be detected from the file name.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        # Only the trailing suffix is swapped. The same characters may also occur earlier in the path, e.g. in
        # the name of a parent directory, and must be left alone.
        to_path = _replace_last(from_path, suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        # Copy in chunks so that the decompressed content never has to fit into memory at once.
        for chunk in iter(lambda: rfh.read(1024 * 1024), b""):
            wfh.write(chunk)

    if remove_finished:
        os.remove(from_path)

    return to_path


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression is automatically detected from the file name. If the file is
    compressed but not an archive the call is dispatched to :func:`_decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to. For a file that is compressed but not an archive,
        the path to the decompressed file.

    Raises:
        RuntimeError: If the archive type or compression cannot be detected from the file name.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        decompressed_name = _replace_last(os.path.basename(from_path), suffix, "")
        return _decompress(
            from_path,
            os.path.join(to_path, decompressed_name),
            remove_finished=remove_finished,
        )

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    # Archives have to honor remove_finished as well; only the compressed-file branch delegates it.
    if remove_finished:
        os.remove(from_path)

    return to_path