Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate extraction and decompression logic in datasets.utils.extract_archive #3443

Merged
merged 20 commits into from
Mar 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
import __main__
import random
import inspect

from numbers import Number
from torch._six import string_classes
Expand Down Expand Up @@ -401,3 +402,20 @@ def disable_console_output():
stack.enter_context(contextlib.redirect_stdout(devnull))
stack.enter_context(contextlib.redirect_stderr(devnull))
yield


def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
    """Normalize a mock ``call_args`` pair into a dict keyed by parameter name.

    Args:
        call_args: ``(args, kwargs)`` tuple as stored on ``unittest.mock.Mock.call_args``.
        *callable_or_arg_names: Either a single callable whose signature supplies
            the positional parameter names, or the parameter names themselves.

    Returns:
        dict: Every argument of the call, positional ones mapped to their names.
    """
    spec_source = callable_or_arg_names[0]
    if callable(spec_source):
        names = inspect.getfullargspec(spec_source).args
        # For a class the argspec describes __init__, so drop the leading "self".
        if isinstance(spec_source, type):
            names = names[1:]
    else:
        names = callable_or_arg_names

    positional, keyword = call_args
    # Positional arguments win over keyword ones, mirroring dict.update().
    return {**keyword, **dict(zip(names, positional))}
132 changes: 111 additions & 21 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import warnings
from torch._utils_internal import get_file_path_2
from urllib.error import URLError
import itertools
import lzma

from common_utils import get_tmp_dir
from common_utils import get_tmp_dir, call_args_to_kwargs_only


TEST_FILE = get_file_path_2(
Expand Down Expand Up @@ -100,6 +102,114 @@ def test_download_url_dispatch_download_from_google_drive(self, mock):

mock.assert_called_once_with(id, root, filename, md5)

def test_detect_file_type(self):
    """Archive type and compression are detected from well-known suffixes."""
    expected_by_file = {
        "foo.tar.xz": (".tar.xz", ".tar", ".xz"),
        "foo.tar": (".tar", ".tar", None),
        "foo.tar.gz": (".tar.gz", ".tar", ".gz"),
        "foo.tgz": (".tgz", ".tar", ".gz"),
        "foo.gz": (".gz", None, ".gz"),
        "foo.zip": (".zip", ".zip", None),
        "foo.xz": (".xz", None, ".xz"),
    }
    for file, expected in expected_by_file.items():
        with self.subTest(file=file):
            self.assertSequenceEqual(utils._detect_file_type(file), expected)

def test_detect_file_type_no_ext(self):
    """A name without any suffix cannot be classified."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo")

def test_detect_file_type_to_many_exts(self):
    """More than two suffixes are rejected."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo.bar.tar.gz")

def test_detect_file_type_unknown_archive_type(self):
    """An unrecognized archive suffix before a known compression is rejected."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo.bar.gz")

def test_detect_file_type_unknown_compression(self):
    """An unrecognized compression suffix after a known archive type is rejected."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo.tar.baz")

def test_detect_file_type_unknown_partial_ext(self):
    """A single suffix that is neither archive type nor compression is rejected."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo.bar")

def test_decompress_gzip(self):
    """_decompress inflates a .gz file next to the source file."""
    content = "this is the content"
    with get_tmp_dir() as temp_dir:
        file = os.path.join(temp_dir, "file")
        compressed = f"{file}.gz"
        with gzip.open(compressed, "wb") as fh:
            fh.write(content.encode())

        utils._decompress(compressed)

        self.assertTrue(os.path.exists(file))
        with open(file, "r") as fh:
            self.assertEqual(fh.read(), content)

def test_decompress_lzma(self):
    """_decompress inflates a .xz file next to the source file."""

    def create_compressed(root, content="this is the content"):
        file = os.path.join(root, "file")
        compressed = f"{file}.xz"

        with lzma.open(compressed, "wb") as fh:
            fh.write(content.encode())

        return compressed, file, content

    with get_tmp_dir() as temp_dir:
        compressed, file, content = create_compressed(temp_dir)

        # Exercise _decompress directly, like the sibling gzip test does; going
        # through extract_archive would test the dispatch, not the decompression.
        utils._decompress(compressed)

        self.assertTrue(os.path.exists(file))

        with open(file, "r") as fh:
            self.assertEqual(fh.read(), content)

def test_decompress_no_compression(self):
    """_decompress refuses files whose suffix carries no compression."""
    self.assertRaises(RuntimeError, utils._decompress, "foo.tar")

def test_decompress_remove_finished(self):
    """_decompress(remove_finished=True) deletes the compressed source file."""

    def create_compressed(root, content="this is the content"):
        file = os.path.join(root, "file")
        compressed = f"{file}.gz"

        with gzip.open(compressed, "wb") as fh:
            fh.write(content.encode())

        return compressed, file, content

    with get_tmp_dir() as temp_dir:
        compressed, file, content = create_compressed(temp_dir)

        # Call _decompress directly: this test targets its remove_finished flag,
        # not extract_archive's dispatching.
        utils._decompress(compressed, remove_finished=True)

        self.assertFalse(os.path.exists(compressed))

def test_extract_archive_defer_to_decompress(self):
    """extract_archive delegates compressed-only files to _decompress."""
    filename = "foo"
    for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)):
        with self.subTest(ext=ext, remove_finished=remove_finished):
            with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock:
                # Build the name from `filename` so the expected `to_path` below
                # (the name with the compression suffix stripped) stays in sync.
                file = f"{filename}{ext}"
                utils.extract_archive(file, remove_finished=remove_finished)

                mock.assert_called_once()
                self.assertEqual(
                    call_args_to_kwargs_only(mock.call_args, utils._decompress),
                    dict(from_path=file, to_path=filename, remove_finished=remove_finished),
                )

def test_extract_zip(self):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
Expand Down Expand Up @@ -170,26 +280,6 @@ def create_archive(root, ext, mode, content="this is the content"):
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_extract_gzip(self):
    """extract_archive inflates a bare .gz file into the target directory."""
    content = "this is the content"
    with get_tmp_dir() as temp_dir:
        file = os.path.join(temp_dir, "file")
        compressed = f"{file}.gz"
        with gzip.GzipFile(compressed, "wb") as fh:
            fh.write(content.encode())

        utils.extract_archive(compressed, temp_dir)

        self.assertTrue(os.path.exists(file))
        with open(file, "r") as fh:
            self.assertEqual(fh.read(), content)

def test_verify_str_arg(self):
self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
Expand Down
158 changes: 125 additions & 33 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import gzip
import re
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple
from urllib.parse import urlparse
import zipfile
import lzma
import contextlib
import urllib
import urllib.request
import urllib.error
import pathlib

import torch
from torch.utils.model_zoo import tqdm
Expand Down Expand Up @@ -242,56 +245,145 @@ def _save_response_content(
pbar.close()


def _is_tarxz(filename: str) -> bool:
    """Return True for xz-compressed tarballs (``*.tar.xz``)."""
    # Comparing the 7-character tail is equivalent to endswith(".tar.xz").
    return filename[-7:] == ".tar.xz"
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
    """Extract a tarball to ``to_path``.

    ``compression`` is a suffix such as ``".gz"``/``".xz"`` or ``None``; tarfile
    expects it in the mode string without the leading dot (e.g. ``"r:gz"``).
    """
    mode = f"r:{compression[1:]}" if compression else "r"
    with tarfile.open(from_path, mode) as archive:
        archive.extractall(to_path)


def _is_tar(filename: str) -> bool:
    """Return True for uncompressed tarballs (``*.tar``)."""
    # Comparing the 4-character tail is equivalent to endswith(".tar").
    return filename[-4:] == ".tar"
# Maps a compression suffix to the ``zipfile`` constant used when reading a
# compressed ZIP archive; archives without a compression suffix are opened
# with ZIP_STORED instead (see _extract_zip).
_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".xz": zipfile.ZIP_LZMA,
}


def _is_targz(filename: str) -> bool:
    """Return True for gzip-compressed tarballs (``*.tar.gz``)."""
    # Comparing the 7-character tail is equivalent to endswith(".tar.gz").
    return filename[-7:] == ".tar.gz"
def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
    """Extract a ZIP archive to ``to_path``.

    Args:
        from_path: Path of the ZIP archive.
        to_path: Directory the members are extracted into.
        compression: Optional compression suffix (e.g. ``".xz"``); ``None`` means
            the members are stored uncompressed (``ZIP_STORED``).
    """
    # Named `zf` rather than `zip` so the builtin of the same name is not shadowed.
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zf:
        zf.extractall(to_path)


def _is_tgz(filename: str) -> bool:
    """Return True for the ``*.tgz`` shorthand of gzip-compressed tarballs."""
    # Comparing the 4-character tail is equivalent to endswith(".tgz").
    return filename[-4:] == ".tgz"
# Dispatch table: archive suffix -> extractor(from_path, to_path, compression).
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
# Compression suffix -> opener used by _decompress to read the compressed stream.
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {".gz": gzip.open, ".xz": lzma.open}
# Single suffixes that are shorthand for an (archive type, compression) pair.
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")}


def _is_gzip(filename: str) -> bool:
    """Return True for gzip-compressed files that are not tarballs."""
    # A gzip-compressed tarball is handled by the tar path, not here.
    if filename.endswith(".tar.gz"):
        return False
    return filename.endswith(".gz")
def _verify_archive_type(archive_type: str) -> None:
    """Raise a RuntimeError if ``archive_type`` has no registered extractor."""
    # Membership and iteration work on the dict directly; `.keys()` is redundant.
    if archive_type not in _ARCHIVE_EXTRACTORS:
        valid_types = "', '".join(_ARCHIVE_EXTRACTORS)
        raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")


def _is_zip(filename: str) -> bool:
    """Return True for ZIP archives (``*.zip``)."""
    # Comparing the 4-character tail is equivalent to endswith(".zip").
    return filename[-4:] == ".zip"
def _verify_compression(compression: str) -> None:
    """Raise a RuntimeError if ``compression`` has no registered file opener."""
    # Membership and iteration work on the dict directly; `.keys()` is redundant.
    if compression not in _COMPRESSED_FILE_OPENERS:
        valid_types = "', '".join(_COMPRESSED_FILE_OPENERS)
        raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file from its name.

    Args:
        file: File name or path.

    Returns:
        Tuple of the complete suffix (e.g. ``".tar.gz"``), the archive type
        (e.g. ``".tar"``) or ``None`` if the file is only compressed, and the
        compression (e.g. ``".gz"``) or ``None`` if the archive is uncompressed.

    Raises:
        RuntimeError: If the suffixes do not identify a known archive type
            and/or compression.
    """
    # Parse the path once; the last suffix is only needed on the 1-suffix path.
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    # NOTE(review): every dot in the name yields a suffix, so downloads with a
    # period in their stem (e.g. 'landcover.ai.v1.zip') are rejected here --
    # supporting them needs a right-to-left detection. TODO confirm with callers.
    elif len(suffixes) > 2:
        raise RuntimeError(
            "Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
        )
    elif len(suffixes) == 2:
        # if we have exactly two suffixes we assume the first one is the archive type and the second one is the
        # compression
        archive_type, compression = suffixes
        _verify_archive_type(archive_type)
        _verify_compression(compression)
        return "".join(suffixes), archive_type, compression

    suffix = suffixes[0]

    # check if the suffix is a known alias
    with contextlib.suppress(KeyError):
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    with contextlib.suppress(RuntimeError):
        _verify_archive_type(suffix)
        return suffix, suffix, None

    # check if the suffix is a compression
    with contextlib.suppress(RuntimeError):
        _verify_compression(suffix)
        return suffix, None, suffix

    raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        # Strip only the compression part of the suffix: "foo.tar.gz" -> "foo.tar",
        # "foo.gz" -> "foo".
        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return to_path


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression is automatically detected from the file name. If the file is compressed
    but not an archive the call is dispatched to :func:`decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)

    if not archive_type:
        # Compressed-only file: hand off to _decompress, writing the inflated file
        # (its name with the compression suffix stripped) into the target directory.
        stripped_name = os.path.basename(from_path).replace(suffix, "")
        return _decompress(from_path, os.path.join(to_path, stripped_name), remove_finished=remove_finished)

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    _ARCHIVE_EXTRACTORS[archive_type](from_path, to_path, compression)

    return to_path


def download_and_extract_archive(
url: str,
Expand Down