From 563cfa2b4da9f23399f7e547ebf3283fc7831052 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 23 Feb 2021 19:20:08 +0100 Subject: [PATCH 01/16] generalize extract_archive --- torchvision/datasets/utils.py | 160 +++++++++++++++++++++++++++------- 1 file changed, 127 insertions(+), 33 deletions(-) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 1bd3d3c8053..12b1bd65969 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -4,9 +4,11 @@ import gzip import re import tarfile -from typing import Any, Callable, List, Iterable, Optional, TypeVar +from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple from urllib.parse import urlparse import zipfile +import lzma +import contextlib import torch from torch.utils.model_zoo import tqdm @@ -231,56 +233,148 @@ def _save_response_content( pbar.close() -def _is_tarxz(filename: str) -> bool: - return filename.endswith(".tar.xz") +def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: + with tarfile.open(from_path, f"r:{compression}" if compression else "r") as tar: + tar.extractall(to_path) -def _is_tar(filename: str) -> bool: - return filename.endswith(".tar") +_ZIP_COMPRESSION_MAP: Dict[str, int] = { + "xz": zipfile.ZIP_LZMA, +} -def _is_targz(filename: str) -> bool: - return filename.endswith(".tar.gz") +def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None: + with zipfile.ZipFile( + from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED + ) as zip: + zip.extractall(to_path) -def _is_tgz(filename: str) -> bool: - return filename.endswith(".tgz") +_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { + "tar": _extract_tar, + "zip": _extract_zip, +} +_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {"gz": gzip.open, "xz": lzma.open} +_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {"tgz": ("tar", "gz")} -def _is_gzip(filename: str) -> bool: - return filename.endswith(".gz") and not filename.endswith(".tar.gz") +def _verify_compression(compression: str) -> None: + if compression not in _COMPRESSED_FILE_OPENERS.keys(): + valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys()) + raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.") -def _is_zip(filename: str) -> bool: - return filename.endswith(".zip") +def _verify_archive_type(archive_type: str) -> None: + if archive_type not in _ARCHIVE_EXTRACTORS.keys(): + valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys()) + raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.") -def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None: +def _parse_ext(ext: str) -> Tuple[Optional[str], Optional[str]]: + exts = ext.split(".") + if len(exts) > 2: + raise RuntimeError + + if len(exts) == 2: + archive_type, compression = exts + _verify_archive_type(archive_type) + _verify_compression(compression) + return archive_type, compression + + with contextlib.suppress(KeyError): + return _FILE_TYPE_ALIASES[ext] + + partial_ext = exts[0] + + with contextlib.suppress(RuntimeError): + _verify_archive_type(partial_ext) + return partial_ext, None + + with contextlib.suppress(RuntimeError): + _verify_compression(partial_ext) + return None, partial_ext + + raise RuntimeError + + +def _determine_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: + try: + ext = file.split(".", 1)[1] + except IndexError as error: + raise RuntimeError from error + + return (ext, *_parse_ext(ext)) + + +def decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: + r"""Decompress a file. + + The compression is automatically detected from the file name. + + Args: + from_path (str): Path to the file to be decompressed. + to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used. + remove_finished (bool): If ``True``, remove the file after the extraction. + + Returns: + (str): Path to the decompressed file. + """ + ext, archive_type, compression = _determine_file_type(from_path) + if not compression: + raise RuntimeError + if to_path is None: - to_path = os.path.dirname(from_path) + to_path = from_path.replace(ext, archive_type if archive_type is not None else "") - if _is_tar(from_path): - with tarfile.open(from_path, 'r') as tar: - tar.extractall(path=to_path) - elif _is_targz(from_path) or _is_tgz(from_path): - with tarfile.open(from_path, 'r:gz') as tar: - tar.extractall(path=to_path) - elif _is_tarxz(from_path): - with tarfile.open(from_path, 'r:xz') as tar: - tar.extractall(path=to_path) - elif _is_gzip(from_path): - to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) - with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: - out_f.write(zip_f.read()) - elif _is_zip(from_path): - with zipfile.ZipFile(from_path, 'r') as z: - z.extractall(to_path) - else: - raise ValueError("Extraction of {} not supported".format(from_path)) + try: + compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] + except KeyError as error: + raise RuntimeError from error + + with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh: + wfh.write(rfh.read()) if remove_finished: os.remove(from_path) + return to_path + + +def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: + """Extract an archive. + + The archive type and a possible compression is automatically detected from the file name. If the file is compressed + but not an archive the call is dispatched to :func:`decompress`. + + Args: + from_path (str): Path to the file to be extracted. + to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is + used. + remove_finished (bool): If ``True``, remove the file after the extraction. + + Returns: + (str): Path to the directory the file was extracted to. + """ + if to_path is None: + to_path = os.path.dirname(from_path) + + ext, archive_type, compression = _determine_file_type(from_path) + if not archive_type: + return decompress( + from_path, + os.path.join(to_path, os.path.basename(from_path).replace(f".{ext}", "")), + remove_finished=remove_finished, + ) + + try: + extractor = _ARCHIVE_EXTRACTORS[archive_type] + except KeyError as error: + raise RuntimeError from error + + extractor(from_path, to_path, compression) + + return to_path + def download_and_extract_archive( url: str, From 7fafebb0f6b4c49bd72c4b5e0a0b4b8c432bce57 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Feb 2021 07:14:34 +0100 Subject: [PATCH 02/16] [test] re-enable extraction tests on windows --- test/test_datasets_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index f0edbaba08f..8e5e6b0b9b6 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -102,7 +102,6 @@ def test_download_url_dispatch_download_from_google_drive(self, mock): mock.assert_called_once_with(id, root, filename, md5) - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_zip(self): with get_tmp_dir() as temp_dir: with tempfile.NamedTemporaryFile(suffix='.zip') as f: @@ -114,7 +113,6 @@ def test_extract_zip(self): data = nf.read() self.assertEqual(data, 'this is the content') - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_tar(self): for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']): with get_tmp_dir() as temp_dir: @@ -130,7 +128,6 @@ def test_extract_tar(self): data = nf.read() self.assertEqual(data, 'this is the content') - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_tar_xz(self): for ext, mode in zip(['.tar.xz'], ['w:xz']): with get_tmp_dir() as temp_dir: @@ -146,7 +143,6 @@ def test_extract_tar_xz(self): data = nf.read() self.assertEqual(data, 'this is the content') - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_gzip(self): with get_tmp_dir() as temp_dir: with tempfile.NamedTemporaryFile(suffix='.gz') as f: From 17f9c8312c4d45642344ef493478059356d38e33 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Feb 2021 07:33:05 +0100 Subject: [PATCH 03/16] add tests for detect_file_type --- test/test_datasets_utils.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 8e5e6b0b9b6..0f2843c09e5 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -102,6 +102,39 @@ def test_download_url_dispatch_download_from_google_drive(self, mock): mock.assert_called_once_with(id, root, filename, md5) + def test_detect_file_type(self): + for file, expected in [ + ("foo.tar.xz", ("tar.xz", "tar", "xz")), + ("foo.tar", ("tar", "tar", None)), + ("foo.tar.gz", ("tar.gz", "tar", "gz")), + ("foo.tgz", ("tgz", "tar", "gz")), + ("foo.gz", ("gz", None, "gz")), + ("foo.zip", ("zip", "zip", None)), + ("foo.xz", ("xz", None, "xz")), + ]: + with self.subTest(file=file): + self.assertSequenceEqual(utils._detect_file_type(file), expected) + + def test_detect_file_type_no_ext(self): + with self.assertRaises(RuntimeError): + utils._detect_file_type("foo") + + def test_detect_file_type_to_many_exts(self): + with self.assertRaises(RuntimeError): + utils._detect_file_type("foo.bar.tar.gz") + + def test_detect_file_type_unknown_archive_type(self): + with self.assertRaises(RuntimeError): + utils._detect_file_type("foo.bar.gz") + + def test_detect_file_type_unknown_compression(self): + with self.assertRaises(RuntimeError): + utils._detect_file_type("foo.tar.baz") + + def test_detect_file_type_unknown_partial_ext(self): + with self.assertRaises(RuntimeError): + utils._detect_file_type("foo.bar") + def test_extract_zip(self): with get_tmp_dir() as temp_dir: with tempfile.NamedTemporaryFile(suffix='.zip') as f: From f783bcd52424013d4b983b0248b5a43e7c2a9a11 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Feb 2021 07:34:00 +0100 Subject: [PATCH 04/16] add error messages to detect_file_type --- torchvision/datasets/utils.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 12b1bd65969..d763eefc122 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -270,10 +270,18 @@ def _verify_archive_type(archive_type: str) -> None: raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.") +def _verify_compression(compression: str) -> None: + if compression not in _COMPRESSED_FILE_OPENERS.keys(): + valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys()) + raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.") + + def _parse_ext(ext: str) -> Tuple[Optional[str], Optional[str]]: exts = ext.split(".") if len(exts) > 2: - raise RuntimeError + raise RuntimeError( + "Archive type and compression detection only works for 1 or 2 extensions. " f"Got {len(exts)} instead." + ) if len(exts) == 2: archive_type, compression = exts @@ -294,14 +302,15 @@ def _parse_ext(ext: str) -> Tuple[Optional[str], Optional[str]]: _verify_compression(partial_ext) return None, partial_ext - raise RuntimeError + raise RuntimeError(f"Extension '{partial_ext}' is neither recognized as archive type nor as compression.") -def _determine_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: +def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: try: ext = file.split(".", 1)[1] except IndexError as error: - raise RuntimeError from error + msg = f"File '{file}' has no extensions that could be used to detect the archive type and compression." + raise RuntimeError(msg) from error return (ext, *_parse_ext(ext)) @@ -319,7 +328,7 @@ def decompress(from_path: str, to_path: Optional[str] = None, remove_finished: b Returns: (str): Path to the decompressed file. """ - ext, archive_type, compression = _determine_file_type(from_path) + ext, archive_type, compression = _detect_file_type(from_path) if not compression: raise RuntimeError @@ -358,7 +367,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish if to_path is None: to_path = os.path.dirname(from_path) - ext, archive_type, compression = _determine_file_type(from_path) + ext, archive_type, compression = _detect_file_type(from_path) if not archive_type: return decompress( from_path, From a22abbcd99ab54ddccfe2d48bf9cd3bcc9564fa4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 10 Mar 2021 16:29:46 +0100 Subject: [PATCH 05/16] Revert "[test] re-enable extraction tests on windows" This reverts commit 7fafebb0f6b4c49bd72c4b5e0a0b4b8c432bce57. --- test/test_datasets_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 0f2843c09e5..5f6e7e6cf52 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -135,6 +135,7 @@ def test_detect_file_type_unknown_partial_ext(self): with self.assertRaises(RuntimeError): utils._detect_file_type("foo.bar") + @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_zip(self): with get_tmp_dir() as temp_dir: with tempfile.NamedTemporaryFile(suffix='.zip') as f: @@ -146,6 +147,7 @@ def test_extract_zip(self): data = nf.read() self.assertEqual(data, 'this is the content') + @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_tar(self): for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']): with get_tmp_dir() as temp_dir: @@ -161,6 +163,7 @@ def test_extract_tar(self): data = nf.read() self.assertEqual(data, 'this is the content') + @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_tar_xz(self): for ext, mode in zip(['.tar.xz'], ['w:xz']): with get_tmp_dir() as temp_dir: @@ -176,6 +179,7 @@ def test_extract_tar_xz(self): data = nf.read() self.assertEqual(data, 'this is the content') + @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_gzip(self): with get_tmp_dir() as temp_dir: with tempfile.NamedTemporaryFile(suffix='.gz') as f: From ff296395f7b1180e53141f086ae2890930dcb936 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 12 Mar 2021 09:06:17 +0100 Subject: [PATCH 06/16] add utility functions for better mock call checking --- test/common_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/common_utils.py b/test/common_utils.py index 6f9dc9af932..7e16864d56c 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -10,6 +10,7 @@ import warnings import __main__ import random +import inspect from numbers import Number from torch._six import string_classes @@ -401,3 +402,20 @@ def disable_console_output(): stack.enter_context(contextlib.redirect_stdout(devnull)) stack.enter_context(contextlib.redirect_stderr(devnull)) yield + + +def call_args_to_kwargs_only(call_args, *callable_or_arg_names): + callable_or_arg_name = callable_or_arg_names[0] + if callable(callable_or_arg_name): + argspec = inspect.getfullargspec(callable_or_arg_name) + arg_names = argspec.args + if isinstance(callable_or_arg_name, type): + # remove self + arg_names.pop(0) + else: + arg_names = callable_or_arg_names + + args, kwargs = call_args + kwargs_only = kwargs.copy() + kwargs_only.update(dict(zip(arg_names, args))) + return kwargs_only From 1ac42a57e0371b4dc243b6a0111d0df17b094185 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 12 Mar 2021 09:07:35 +0100 Subject: [PATCH 07/16] add tests for decompress --- test/test_datasets_utils.py | 64 +++++++++++++++++++++++------------ torchvision/datasets/utils.py | 8 +---- 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index d4b5c6b9b30..4f971b08660 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -8,8 +8,10 @@ import warnings from torch._utils_internal import get_file_path_2 from urllib.error import URLError +import itertools +import lzma -from common_utils import get_tmp_dir +from common_utils import get_tmp_dir, call_args_to_kwargs_only TEST_FILE = get_file_path_2( @@ -133,6 +135,46 @@ def test_detect_file_type_unknown_partial_ext(self): with self.assertRaises(RuntimeError): utils._detect_file_type("foo.bar") + def test_decopress_gzip(self): + def create_compressed(root, content="this is the content"): + file = os.path.join(root, "file") + compressed = f"{file}.gz" + + with gzip.open(compressed, "wb") as fh: + fh.write(content.encode()) + + return compressed, file, content + + with get_tmp_dir() as temp_dir: + compressed, file, content = create_compressed(temp_dir) + + utils.decompress(compressed) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) + + def test_decopress_lzma(self): + def create_compressed(root, content="this is the content"): + file = os.path.join(root, "file") + compressed = f"{file}.xz" + + with lzma.open(compressed, "wb") as fh: + fh.write(content.encode()) + + return compressed, file, content + + with get_tmp_dir() as temp_dir: + compressed, file, content = create_compressed(temp_dir) + + utils.extract_archive(compressed, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) + def test_extract_zip(self): def create_archive(root, content="this is the content"): file = os.path.join(root, "dst.txt") @@ -203,26 +245,6 @@ def create_archive(root, ext, mode, content="this is the content"): with open(file, "r") as fh: self.assertEqual(fh.read(), content) - def test_extract_gzip(self): - def create_compressed(root, content="this is the content"): - file = os.path.join(root, "file") - compressed = f"{file}.gz" - - with gzip.GzipFile(compressed, "wb") as fh: - fh.write(content.encode()) - - return compressed, file, content - - with get_tmp_dir() as temp_dir: - compressed, file, content = create_compressed(temp_dir) - - utils.extract_archive(compressed, temp_dir) - - self.assertTrue(os.path.exists(file)) - - with open(file, "r") as fh: - self.assertEqual(fh.read(), content) - def test_verify_str_arg(self): self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",))) self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 3dfa3ae87a4..c63827f9e2c 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -267,12 +267,6 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No _FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {"tgz": ("tar", "gz")} -def _verify_compression(compression: str) -> None: - if compression not in _COMPRESSED_FILE_OPENERS.keys(): - valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys()) - raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.") - - def _verify_archive_type(archive_type: str) -> None: if archive_type not in _ARCHIVE_EXTRACTORS.keys(): valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys()) @@ -342,7 +336,7 @@ def decompress(from_path: str, to_path: Optional[str] = None, remove_finished: b raise RuntimeError if to_path is None: - to_path = from_path.replace(ext, archive_type if archive_type is not None else "") + to_path = from_path.replace(f".{ext}", f".{archive_type}" if archive_type is not None else "") try: compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] From b10036c130467c6640e179940e329783dfd52bba Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 12 Mar 2021 09:35:01 +0100 Subject: [PATCH 08/16] simplify logic by using pathlib --- test/test_datasets_utils.py | 26 +++++++++---- torchvision/datasets/utils.py | 71 +++++++++++++++++------------------ 2 files changed, 54 insertions(+), 43 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 4f971b08660..2216fcd2cd9 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -104,13 +104,13 @@ def test_download_url_dispatch_download_from_google_drive(self, mock): def test_detect_file_type(self): for file, expected in [ - ("foo.tar.xz", ("tar.xz", "tar", "xz")), - ("foo.tar", ("tar", "tar", None)), - ("foo.tar.gz", ("tar.gz", "tar", "gz")), - ("foo.tgz", ("tgz", "tar", "gz")), - ("foo.gz", ("gz", None, "gz")), - ("foo.zip", ("zip", "zip", None)), - ("foo.xz", ("xz", None, "xz")), + ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), + ("foo.tar", (".tar", ".tar", None)), + ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.tgz", (".tgz", ".tar", ".gz")), + ("foo.gz", (".gz", None, ".gz")), + ("foo.zip", (".zip", ".zip", None)), + ("foo.xz", (".xz", None, ".xz")), ]: with self.subTest(file=file): self.assertSequenceEqual(utils._detect_file_type(file), expected) @@ -175,6 +175,18 @@ def create_compressed(root, content="this is the content"): with open(file, "r") as fh: self.assertEqual(fh.read(), content) + def test_extract_archive_defer_to_decompress(self): + filename = "foo" + for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)): + with self.subTest(ext=ext, remove_finished=remove_finished): + with unittest.mock.patch("torchvision.datasets.utils.decompress") as mock: + file = f"{filename}{ext}" + utils.extract_archive(file, remove_finished=remove_finished) + + mock.assert_called_once() + self.assertEqual(call_args_to_kwargs_only(mock.call_args, utils.decompress), + dict(from_path=file, to_path=filename, remove_finished=remove_finished)) + def test_extract_zip(self): def create_archive(root, content="this is the content"): file = os.path.join(root, "dst.txt") diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index c63827f9e2c..a18df3c5422 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -12,6 +12,7 @@ import urllib import urllib.request import urllib.error +import pathlib import torch from torch.utils.model_zoo import tqdm @@ -243,12 +244,12 @@ def _save_response_content( def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: - with tarfile.open(from_path, f"r:{compression}" if compression else "r") as tar: + with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: tar.extractall(to_path) _ZIP_COMPRESSION_MAP: Dict[str, int] = { - "xz": zipfile.ZIP_LZMA, + ".xz": zipfile.ZIP_LZMA, } @@ -260,11 +261,11 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No _ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { - "tar": _extract_tar, - "zip": _extract_zip, + ".tar": _extract_tar, + ".zip": _extract_zip, } -_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {"gz": gzip.open, "xz": lzma.open} -_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {"tgz": ("tar", "gz")} +_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {".gz": gzip.open, ".xz": lzma.open} +_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")} def _verify_archive_type(archive_type: str) -> None: @@ -279,43 +280,41 @@ def _verify_compression(compression: str) -> None: raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.") -def _parse_ext(ext: str) -> Tuple[Optional[str], Optional[str]]: - exts = ext.split(".") - if len(exts) > 2: +def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: + path = pathlib.Path(file) + suffix = path.suffix + suffixes = pathlib.Path(file).suffixes + if not suffixes: raise RuntimeError( - "Archive type and compression detection only works for 1 or 2 extensions. " f"Got {len(exts)} instead." + f"File '{file}' has no suffixes that could be used to detect the archive type and compression." ) - - if len(exts) == 2: - archive_type, compression = exts + elif len(suffixes) > 2: + raise RuntimeError( + "Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead." + ) + elif len(suffixes) == 2: + # if we have exactly two suffixes we assume the first one is the archive type and the second on is the + # compression + archive_type, compression = suffixes _verify_archive_type(archive_type) _verify_compression(compression) - return archive_type, compression + return "".join(suffixes), archive_type, compression + # check if the suffix is a known alias with contextlib.suppress(KeyError): - return _FILE_TYPE_ALIASES[ext] - - partial_ext = exts[0] + return (suffix, *_FILE_TYPE_ALIASES[suffix]) + # check if the suffix is an archive type with contextlib.suppress(RuntimeError): - _verify_archive_type(partial_ext) - return partial_ext, None + _verify_archive_type(suffix) + return suffix, suffix, None + # check if the suffix is a compression with contextlib.suppress(RuntimeError): - _verify_compression(partial_ext) - return None, partial_ext - - raise RuntimeError(f"Extension '{partial_ext}' is neither recognized as archive type nor as compression.") - - -def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: - try: - ext = file.split(".", 1)[1] - except IndexError as error: - msg = f"File '{file}' has no extensions that could be used to detect the archive type and compression." - raise RuntimeError(msg) from error + _verify_compression(suffix) + return suffix, None, suffix - return (ext, *_parse_ext(ext)) + raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.") def decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: @@ -331,12 +330,12 @@ def decompress(from_path: str, to_path: Optional[str] = None, remove_finished: b Returns: (str): Path to the decompressed file. """ - ext, archive_type, compression = _detect_file_type(from_path) + suffix, archive_type, compression = _detect_file_type(from_path) if not compression: raise RuntimeError if to_path is None: - to_path = from_path.replace(f".{ext}", f".{archive_type}" if archive_type is not None else "") + to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") try: compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] @@ -370,11 +369,11 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish if to_path is None: to_path = os.path.dirname(from_path) - ext, archive_type, compression = _detect_file_type(from_path) + suffix, archive_type, compression = _detect_file_type(from_path) if not archive_type: return decompress( from_path, - os.path.join(to_path, os.path.basename(from_path).replace(f".{ext}", "")), + os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")), remove_finished=remove_finished, ) From e9510dfe613c6e0b12b16bbcce555dc6847e6f90 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 12 Mar 2021 09:38:51 +0100 Subject: [PATCH 09/16] lint --- test/test_datasets_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 2216fcd2cd9..4b206a7ede3 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -184,8 +184,10 @@ def test_extract_archive_defer_to_decompress(self): utils.extract_archive(file, remove_finished=remove_finished) mock.assert_called_once() - self.assertEqual(call_args_to_kwargs_only(mock.call_args, utils.decompress), - dict(from_path=file, to_path=filename, remove_finished=remove_finished)) + self.assertEqual( + call_args_to_kwargs_only(mock.call_args, utils.decompress), + dict(from_path=file, to_path=filename, remove_finished=remove_finished), + ) def test_extract_zip(self): def create_archive(root, content="this is the content"): From 8bf663099240f3a79b22f4dfb2efd21fa4c213f8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 15 Mar 2021 08:01:02 +0100 Subject: [PATCH 10/16] Apply suggestions from code review Co-authored-by: Francisco Massa --- test/test_datasets_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 4b206a7ede3..731e02a5be6 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -135,7 +135,7 @@ def test_detect_file_type_unknown_partial_ext(self): with self.assertRaises(RuntimeError): utils._detect_file_type("foo.bar") - def test_decopress_gzip(self): + def test_decompress_gzip(self): def create_compressed(root, content="this is the content"): file = os.path.join(root, "file") compressed = f"{file}.gz" @@ -155,7 +155,7 @@ def create_compressed(root, content="this is the content"): with open(file, "r") as fh: self.assertEqual(fh.read(), content) - def test_decopress_lzma(self): + def test_decompress_lzma(self): def create_compressed(root, content="this is the content"): file = os.path.join(root, "file") compressed = f"{file}.xz" From dc799a2f2c90fe480d9704bc410ae679d71a3088 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 15 Mar 2021 08:04:19 +0100 Subject: [PATCH 11/16] make decompress private --- test/test_datasets_utils.py | 4 ++-- torchvision/datasets/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 731e02a5be6..2018cb60f4e 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -148,7 +148,7 @@ def create_compressed(root, content="this is the content"): with get_tmp_dir() as temp_dir: compressed, file, content = create_compressed(temp_dir) - utils.decompress(compressed) + utils._decompress(compressed) self.assertTrue(os.path.exists(file)) @@ -185,7 +185,7 @@ def test_extract_archive_defer_to_decompress(self): mock.assert_called_once() self.assertEqual( - call_args_to_kwargs_only(mock.call_args, utils.decompress), + call_args_to_kwargs_only(mock.call_args, utils._decompress), dict(from_path=file, to_path=filename, remove_finished=remove_finished), ) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index a18df3c5422..e2cd46d1d12 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -317,7 +317,7 @@ def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.") -def decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: +def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: r"""Decompress a file. The compression is automatically detected from the file name. @@ -371,7 +371,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish suffix, archive_type, compression = _detect_file_type(from_path) if not archive_type: - return decompress( + return _decompress( from_path, os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")), remove_finished=remove_finished, From 26e4f83f42db5aaaaad163e60ad7aa2b7ed506a1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 15 Mar 2021 08:07:00 +0100 Subject: [PATCH 12/16] remove unnecessary checks --- torchvision/datasets/utils.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index e2cd46d1d12..fa0ee2ea81a 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -337,10 +337,8 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: if to_path is None: to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") - try: - compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] - except KeyError as error: - raise RuntimeError from error + # We don't need to check for a missing key here, since this was already done in _detect_file_type() + compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh: wfh.write(rfh.read()) @@ -377,10 +375,8 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish remove_finished=remove_finished, ) - try: - extractor = _ARCHIVE_EXTRACTORS[archive_type] - except KeyError as error: - raise RuntimeError from error + # We don't need to check for a missing key here, since this was already done in _detect_file_type() + extractor = _ARCHIVE_EXTRACTORS[archive_type] extractor(from_path, to_path, compression) From 2c9d0c6c7ce3275963c57546cdc3f0b2080c366c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 15 Mar 2021 08:09:37 +0100 Subject: [PATCH 13/16] add error message --- torchvision/datasets/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index fa0ee2ea81a..beb47fa6c9d 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -332,7 +332,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: """ suffix, archive_type, compression = _detect_file_type(from_path) if not compression: - raise RuntimeError + raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.") if to_path is None: to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") From 56b577061c1eada2682d7ce5b85824889c6205ec Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 15 Mar 2021 08:14:32 +0100 Subject: [PATCH 14/16] fix mocking --- test/test_datasets_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 2018cb60f4e..ee56b285df5 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -179,7 +179,7 @@ def test_extract_archive_defer_to_decompress(self): filename = "foo" for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)): with self.subTest(ext=ext, remove_finished=remove_finished): - with unittest.mock.patch("torchvision.datasets.utils.decompress") as mock: + with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock: file = f"{filename}{ext}" utils.extract_archive(file, remove_finished=remove_finished) From 15c559db871956c3788d927fe9e91bcaf72222d7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 15 Mar 2021 08:20:02 +0100 Subject: [PATCH 15/16] add remaining tests --- test/test_datasets_utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index ee56b285df5..2a27f5e50a2 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -175,6 +175,28 @@ def create_compressed(root, content="this is the content"): with open(file, "r") as fh: self.assertEqual(fh.read(), content) + def test_decompress_no_compression(self): + with self.assertRaises(RuntimeError): + utils._decompress("foo.tar") + + def test_decompress_remove_finished(self): + def create_compressed(root, content="this is the content"): + file = os.path.join(root, "file") + compressed = f"{file}.gz" + + with gzip.open(compressed, "wb") as fh: + fh.write(content.encode()) + + return compressed, file, content + + with get_tmp_dir() as temp_dir: + compressed, file, content = create_compressed(temp_dir) + + utils.extract_archive(compressed, temp_dir, remove_finished=True) + + self.assertFalse(os.path.exists(compressed)) + + def test_extract_archive_defer_to_decompress(self): filename = "foo" for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)): From f4cf171a1c0bca55753212b27fb7a917277111bc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 15 Mar 2021 08:21:06 +0100 Subject: [PATCH 16/16] lint --- test/test_datasets_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 2a27f5e50a2..be9299483fc 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -196,7 +196,6 @@ def create_compressed(root, content="this is the content"): self.assertFalse(os.path.exists(compressed)) - def test_extract_archive_defer_to_decompress(self): filename = "foo" for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)):