From 697b6d6e718ebec777061a803582864e31486dba Mon Sep 17 00:00:00 2001 From: EL MEHDI AGUNAOU <56029953+Mehdi2402@users.noreply.github.com> Date: Tue, 24 Jan 2023 13:53:07 +0100 Subject: [PATCH] iter_archive on zipfiles with better compression type check (#3379) * iter_archive on zipfiles with better compression type check * Update src/datasets/utils/download_manager.py Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * tests for iter_archive on streaming_dl_manager * function names corrections * change host and other corrections * Refactor code * Fix bad merge * Tests * Style fixes Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Co-authored-by: mariosasko --- src/datasets/download/download_manager.py | 113 +++++++++++++++++- .../download/streaming_download_manager.py | 91 ++++++++++---- tests/fixtures/files.py | 8 ++ tests/test_download_manager.py | 12 +- tests/test_streaming_download_manager.py | 16 ++- 5 files changed, 203 insertions(+), 37 deletions(-) diff --git a/src/datasets/download/download_manager.py b/src/datasets/download/download_manager.py index 791ebd8d740..b63ae0cbe13 100644 --- a/src/datasets/download/download_manager.py +++ b/src/datasets/download/download_manager.py @@ -22,8 +22,10 @@ import tarfile import time import warnings +import zipfile from datetime import datetime from functools import partial +from itertools import chain from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union from .. import config @@ -38,6 +40,40 @@ logger = get_logger(__name__) +BASE_KNOWN_EXTENSIONS = [ + "txt", + "csv", + "json", + "jsonl", + "tsv", + "conll", + "conllu", + "orig", + "parquet", + "pkl", + "pickle", + "rel", + "xml", +] +MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL = { + bytes.fromhex("504B0304"): "zip", + bytes.fromhex("504B0506"): "zip", # empty archive + bytes.fromhex("504B0708"): "zip", # spanned archive + bytes.fromhex("425A68"): "bz2", + bytes.fromhex("1F8B"): "gzip", + bytes.fromhex("FD377A585A00"): "xz", + bytes.fromhex("04224D18"): "lz4", + bytes.fromhex("28B52FFD"): "zstd", +} +MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL = { + b"Rar!": "rar", +} +MAGIC_NUMBER_MAX_LENGTH = max( + len(magic_number) + for magic_number in chain(MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL, MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL) +) + + class DownloadMode(enum.Enum): """`Enum` for how to treat pre-existing downloads and data. @@ -69,6 +105,48 @@ def help_message(self): return "Use 'DownloadMode' instead." 
+def _get_path_extension(path: str) -> str:
+    # Get extension: train.json.gz -> gz
+    extension = path.split(".")[-1]
+    # Remove query params ("dl=1", "raw=true"): gz?dl=1 -> gz
+    # Remove shards infos (".txt_1", ".txt-00000-of-00100"): txt_1 -> txt
+    for symb in "?-_":
+        extension = extension.split(symb)[0]
+    return extension
+
+
+def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
+    """read the magic number from a file-like object and return the compression protocol"""
+    # Check if the file object is seekable even before reading the magic number (to avoid https://bugs.python.org/issue26440)
+    try:
+        f.seek(0)
+    except (AttributeError, io.UnsupportedOperation):
+        return None
+    magic_number = f.read(MAGIC_NUMBER_MAX_LENGTH)
+    f.seek(0)
+    for i in range(MAGIC_NUMBER_MAX_LENGTH):
+        compression = MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL.get(magic_number[: MAGIC_NUMBER_MAX_LENGTH - i])
+        if compression is not None:
+            return compression
+        compression = MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL.get(magic_number[: MAGIC_NUMBER_MAX_LENGTH - i])
+        if compression is not None:
+            raise NotImplementedError(f"Compression protocol '{compression}' not implemented.")
+
+
+def _get_extraction_protocol(path: str) -> Optional[str]:
+    path = str(path)
+    extension = _get_path_extension(path)
+    # TODO(mariosasko): The below check will be useful once we can preserve the original extension in the new cache layout (use the `filename` parameter of `hf_hub_download`)
+    if (
+        extension in BASE_KNOWN_EXTENSIONS
+        or extension in ["tgz", "tar"]
+        or path.endswith((".tar.gz", ".tar.bz2", ".tar.xz"))
+    ):
+        return None
+    with open(path, "rb") as f:
+        return _get_extraction_protocol_with_magic_number(f)
+
+
 class _IterableFromGenerator(Iterable):
     """Utility class to create an iterable from a generator function, in order to reset the generator when needed."""
 
@@ -84,8 +162,8 @@ def __iter__(self):
 class ArchiveIterable(_IterableFromGenerator):
     """An iterable of (path, fileobj) from a TAR archive, used by `iter_archive`"""
 
-    @classmethod
-    def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
+    @staticmethod
+    def _iter_tar(f):
         stream = tarfile.open(fileobj=f, mode="r|*")
         for tarinfo in stream:
             file_path = tarinfo.name
@@ -93,7 +171,7 @@ def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
                 continue
             if file_path is None:
                 continue
-            if os.path.basename(file_path).startswith(".") or os.path.basename(file_path).startswith("__"):
+            if os.path.basename(file_path).startswith((".", "__")):
                 # skipping hidden files
                 continue
             file_obj = stream.extractfile(tarinfo)
@@ -101,10 +179,37 @@ def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
             stream.members = []
         del stream
 
+    @staticmethod
+    def _iter_zip(f):
+        zipf = zipfile.ZipFile(f)
+        for member in zipf.infolist():
+            file_path = member.filename
+            if member.is_dir():
+                continue
+            if file_path is None:
+                continue
+            if os.path.basename(file_path).startswith((".", "__")):
+                # skipping hidden files
+                continue
+            file_obj = zipf.open(member)
+            yield file_path, file_obj
+
+    @classmethod
+    def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
+        compression = _get_extraction_protocol_with_magic_number(f)
+        if compression == "zip":
+            yield from cls._iter_zip(f)
+        else:
+            yield from cls._iter_tar(f)
+
     @classmethod
     def _iter_from_path(cls, urlpath: str) -> Generator[Tuple, None, None]:
+        compression = _get_extraction_protocol(urlpath)
         with open(urlpath, "rb") as f:
-            yield from cls._iter_from_fileobj(f)
+            if compression == "zip":
+                yield from cls._iter_zip(f)
+            else:
+                yield from cls._iter_tar(f)
 
     @classmethod
     def from_buf(cls, fileobj) -> "ArchiveIterable":
diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py
index 5ba7cb34fd3..af6766bebab 100644
--- a/src/datasets/download/streaming_download_manager.py
+++ b/src/datasets/download/streaming_download_manager.py
@@ -6,6 +6,7 @@
 import tarfile
 import time
 import xml.dom.minidom
+import zipfile
 from asyncio import TimeoutError
 from io import BytesIO
 from itertools import chain
@@ -151,7 +152,7 @@ def xexists(urlpath: str, use_auth_token: Optional[Union[str, bool]] = None):
         `bool`
     """
-    main_hop, *rest_hops = str(urlpath).split("::")
+    main_hop, *rest_hops = _as_str(urlpath).split("::")
     if is_local_path(main_hop):
         return os.path.exists(main_hop)
     else:
@@ -381,13 +382,28 @@ def read_with_retries(*args, **kwargs):
     file_obj.read = read_with_retries
 
 
+def _get_path_extension(path: str) -> str:
+    # Get extension: https://foo.bar/train.json.gz -> gz
+    extension = path.split(".")[-1]
+    # Remove query params ("dl=1", "raw=true"): gz?dl=1 -> gz
+    # Remove shards infos (".txt_1", ".txt-00000-of-00100"): txt_1 -> txt
+    for symb in "?-_":
+        extension = extension.split(symb)[0]
+    return extension
+
+
 def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
     """read the magic number from a file-like object and return the compression protocol"""
+    # Check if the file object is seekable even before reading the magic number (to avoid https://bugs.python.org/issue26440)
+    try:
+        f.seek(0)
+    except (AttributeError, io.UnsupportedOperation):
+        return None
     magic_number = f.read(MAGIC_NUMBER_MAX_LENGTH)
     f.seek(0)
     for i in range(MAGIC_NUMBER_MAX_LENGTH):
         compression = MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL.get(magic_number[: MAGIC_NUMBER_MAX_LENGTH - i])
-        if compression is not None:  # TODO(QL): raise an error for .tar.gz files as in _get_extraction_protocol
+        if compression is not None:
             return compression
         compression = MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL.get(magic_number[: MAGIC_NUMBER_MAX_LENGTH - i])
         if compression is not None:
@@ -396,28 +412,17 @@ def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
 
 def _get_extraction_protocol(urlpath: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]:
     # get inner file: zip://train-00000.json.gz::https://foo.bar/data.zip -> zip://train-00000.json.gz
+    urlpath = str(urlpath)
     path = urlpath.split("::")[0]
-    # Get extension: https://foo.bar/train.json.gz -> gz
-    extension = path.split(".")[-1]
-    # Remove query params ("dl=1", "raw=true"): gz?dl=1 -> gz
-    # Remove shards infos (".txt_1", ".txt-00000-of-00100"): txt_1 -> txt
-    for symb in "?-_":
-        extension = extension.split(symb)[0]
-    if extension in BASE_KNOWN_EXTENSIONS:
+    extension = _get_path_extension(path)
+    if (
+        extension in BASE_KNOWN_EXTENSIONS
+        or extension in ["tgz", "tar"]
+        or path.endswith((".tar.gz", ".tar.bz2", ".tar.xz"))
+    ):
         return None
-    elif extension in ["tgz", "tar"] or path.endswith(".tar.gz"):
-        raise NotImplementedError(
-            f"Extraction protocol for TAR archives like '{urlpath}' is not implemented in streaming mode. "
-            f"Please use `dl_manager.iter_archive` instead.\n\n"
-            f"Example usage:\n\n"
-            f"\turl = dl_manager.download(url)\n"
-            f"\ttar_archive_iterator = dl_manager.iter_archive(url)\n\n"
-            f"\tfor filename, file in tar_archive_iterator:\n"
-            f"\t\t..."
- ) elif extension in COMPRESSION_EXTENSION_TO_PROTOCOL: return COMPRESSION_EXTENSION_TO_PROTOCOL[extension] - if is_remote_url(urlpath): # get headers and cookies for authentication on the HF Hub and for Google Drive urlpath, kwargs = _prepare_http_url_kwargs(urlpath, use_auth_token=use_auth_token) @@ -849,8 +854,8 @@ def __iter__(self): class ArchiveIterable(_IterableFromGenerator): """An iterable of (path, fileobj) from a TAR archive, used by `iter_archive`""" - @classmethod - def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]: + @staticmethod + def _iter_tar(f): stream = tarfile.open(fileobj=f, mode="r|*") for tarinfo in stream: file_path = tarinfo.name @@ -866,12 +871,39 @@ def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]: stream.members = [] del stream + @staticmethod + def _iter_zip(f): + zipf = zipfile.ZipFile(f) + for member in zipf.infolist(): + file_path = member.filename + if member.is_dir(): + continue + if file_path is None: + continue + if os.path.basename(file_path).startswith((".", "__")): + # skipping hidden files + continue + file_obj = zipf.open(member) + yield file_path, file_obj + + @classmethod + def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]: + compression = _get_extraction_protocol_with_magic_number(f) + if compression == "zip": + yield from cls._iter_zip(f) + else: + yield from cls._iter_tar(f) + @classmethod def _iter_from_urlpath( cls, urlpath: str, use_auth_token: Optional[Union[str, bool]] = None ) -> Generator[Tuple, None, None]: + compression = _get_extraction_protocol(urlpath, use_auth_token=use_auth_token) with xopen(urlpath, "rb", use_auth_token=use_auth_token) as f: - yield from cls._iter_from_fileobj(f) + if compression == "zip": + yield from cls._iter_zip(f) + else: + yield from cls._iter_tar(f) @classmethod def from_buf(cls, fileobj) -> "ArchiveIterable": @@ -995,6 +1027,19 @@ def extract(self, url_or_urls): def _extract(self, urlpath: str) -> str: urlpath = str(urlpath) protocol = _get_extraction_protocol(urlpath, use_auth_token=self.download_config.use_auth_token) + # get inner file: zip://train-00000.json.gz::https://foo.bar/data.zip -> zip://train-00000.json.gz + path = urlpath.split("::")[0] + extension = _get_path_extension(path) + if extension in ["tgz", "tar"] or path.endswith((".tar.gz", ".tar.bz2", ".tar.xz")): + raise NotImplementedError( + f"Extraction protocol for TAR archives like '{urlpath}' is not implemented in streaming mode. " + f"Please use `dl_manager.iter_archive` instead.\n\n" + f"Example usage:\n\n" + f"\turl = dl_manager.download(url)\n" + f"\ttar_archive_iterator = dl_manager.iter_archive(url)\n\n" + f"\tfor filename, file in tar_archive_iterator:\n" + f"\t\t..." 
+ ) if protocol is None: # no extraction return urlpath diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index b1dd1f8785c..f38db20a68d 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -408,6 +408,14 @@ def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory): return path +@pytest.fixture(scope="session") +def zip_nested_jsonl_path(zip_jsonl_path, jsonl_path, jsonl2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset_nested.jsonl.zip" + with zipfile.ZipFile(path, "w") as f: + f.write(zip_jsonl_path, arcname=os.path.join("nested", os.path.basename(zip_jsonl_path))) + return path + + @pytest.fixture(scope="session") def zip_jsonl_with_dir_path(jsonl_path, jsonl2_path, tmp_path_factory): path = tmp_path_factory.mktemp("data") / "dataset_with_dir.jsonl.zip" diff --git a/tests/test_download_manager.py b/tests/test_download_manager.py index 0dfc0a2098e..3673d43ce98 100644 --- a/tests/test_download_manager.py +++ b/tests/test_download_manager.py @@ -122,16 +122,20 @@ def _test_jsonl(path, file): assert num_items == 4 -def test_iter_archive_path(tar_jsonl_path): +@pytest.mark.parametrize("archive_jsonl", ["tar_jsonl_path", "zip_jsonl_path"]) +def test_iter_archive_path(archive_jsonl, request): + archive_jsonl_path = request.getfixturevalue(archive_jsonl) dl_manager = DownloadManager() - for num_jsonl, (path, file) in enumerate(dl_manager.iter_archive(tar_jsonl_path), start=1): + for num_jsonl, (path, file) in enumerate(dl_manager.iter_archive(archive_jsonl_path), start=1): _test_jsonl(path, file) assert num_jsonl == 2 -def test_iter_archive_file(tar_nested_jsonl_path): +@pytest.mark.parametrize("archive_nested_jsonl", ["tar_nested_jsonl_path", "zip_nested_jsonl_path"]) +def test_iter_archive_file(archive_nested_jsonl, request): + archive_nested_jsonl_path = request.getfixturevalue(archive_nested_jsonl) dl_manager = DownloadManager() - for num_tar, (path, file) in enumerate(dl_manager.iter_archive(tar_nested_jsonl_path), start=1): + for num_tar, (path, file) in enumerate(dl_manager.iter_archive(archive_nested_jsonl_path), start=1): for num_jsonl, (subpath, subfile) in enumerate(dl_manager.iter_archive(file), start=1): _test_jsonl(subpath, subfile) assert num_tar == 1 diff --git a/tests/test_streaming_download_manager.py b/tests/test_streaming_download_manager.py index 820291494db..1ec1216f282 100644 --- a/tests/test_streaming_download_manager.py +++ b/tests/test_streaming_download_manager.py @@ -826,9 +826,9 @@ def test_streaming_dl_manager_get_extraction_protocol_gg_drive(urlpath, expected "https://foo.bar/train.tar", ], ) -def test_streaming_dl_manager_get_extraction_protocol_throws(urlpath): +def test_streaming_dl_manager_extract_throws(urlpath): with pytest.raises(NotImplementedError): - _ = _get_extraction_protocol(urlpath) + _ = StreamingDownloadManager().extract(urlpath) @slow # otherwise it spams Google Drive and the CI gets banned @@ -873,9 +873,11 @@ def _test_jsonl(path, file): assert num_items == 4 -def test_iter_archive_path(tar_jsonl_path): +@pytest.mark.parametrize("archive_jsonl", ["tar_jsonl_path", "zip_jsonl_path"]) +def test_iter_archive_path(archive_jsonl, request): + archive_jsonl_path = request.getfixturevalue(archive_jsonl) dl_manager = StreamingDownloadManager() - archive_iterable = dl_manager.iter_archive(tar_jsonl_path) + archive_iterable = dl_manager.iter_archive(archive_jsonl_path) num_jsonl = 0 for num_jsonl, (path, file) in enumerate(archive_iterable, start=1): _test_jsonl(path, file) @@ -887,9 +889,11 @@ 
def test_iter_archive_path(tar_jsonl_path): assert num_jsonl == 2 -def test_iter_archive_file(tar_nested_jsonl_path): +@pytest.mark.parametrize("archive_nested_jsonl", ["tar_nested_jsonl_path", "zip_nested_jsonl_path"]) +def test_iter_archive_file(archive_nested_jsonl, request): + archive_nested_jsonl_path = request.getfixturevalue(archive_nested_jsonl) dl_manager = StreamingDownloadManager() - files_iterable = dl_manager.iter_archive(tar_nested_jsonl_path) + files_iterable = dl_manager.iter_archive(archive_nested_jsonl_path) num_tar, num_jsonl = 0, 0 for num_tar, (path, file) in enumerate(files_iterable, start=1): for num_jsonl, (subpath, subfile) in enumerate(dl_manager.iter_archive(file), start=1):
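
Note on the detection logic added to both download managers above: `_get_extraction_protocol_with_magic_number` reads the first MAGIC_NUMBER_MAX_LENGTH bytes and matches progressively shorter prefixes against the signature tables, so longer signatures (xz's six bytes) take precedence over shorter ones (gzip's two). A minimal standalone sketch of the same idea — the table is abbreviated and the function name is illustrative, not the patch's code:

    import io
    from typing import Optional

    # Abbreviated table; the patch's MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL lists more formats
    MAGIC_NUMBERS = {
        bytes.fromhex("504B0304"): "zip",
        bytes.fromhex("1F8B"): "gzip",
        bytes.fromhex("FD377A585A00"): "xz",
    }
    MAX_LEN = max(len(magic) for magic in MAGIC_NUMBERS)

    def sniff_compression(f) -> Optional[str]:
        # Bail out on non-seekable streams, as the patch does (https://bugs.python.org/issue26440)
        try:
            f.seek(0)
        except (AttributeError, io.UnsupportedOperation):
            return None
        header = f.read(MAX_LEN)
        f.seek(0)  # leave the stream where the caller expects it
        # Try the longest signature first, then progressively shorter prefixes
        for length in range(MAX_LEN, 0, -1):
            protocol = MAGIC_NUMBERS.get(header[:length])
            if protocol is not None:
                return protocol
        return None

    print(sniff_compression(io.BytesIO(b"\x1f\x8b\x08\x00rest-of-stream")))  # -> gzip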
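
Usage sketch of the user-facing effect: `iter_archive` now dispatches on the archive's magic number instead of assuming TAR, so ZIP archives stream member files the same way. The archive path below is a placeholder:

    from datasets import DownloadManager

    dl_manager = DownloadManager()
    # A TAR archive at the same path would iterate identically
    for file_path, file_obj in dl_manager.iter_archive("path/to/archive.zip"):
        print(file_path, file_obj.read(16))

The same loop works in streaming mode with StreamingDownloadManager, which is exactly what the parametrized tests above exercise for both the TAR and ZIP fixtures.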