iter_archive on zipfiles with better compression type check (#3379)
* iter_archive on zipfiles with better compression type check

* Update src/datasets/utils/download_manager.py

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>

* tests for iter_archive on streaming_dl_manager

* function names corrections

* change host and other corrections

* Refactor code

* Fix bad merge

* Tests

* Style fixes

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: mariosasko <mariosasko777@gmail.com>
3 people authored Jan 24, 2023
1 parent 98c9b27 commit 697b6d6
Showing 5 changed files with 203 additions and 37 deletions.
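
In short, this commit teaches `dl_manager.iter_archive` to handle ZIP archives in addition to TAR archives, detecting the archive type from the file's magic number rather than its extension. A minimal sketch of the resulting behavior (the archive path here is hypothetical):

    from datasets import DownloadManager

    dl_manager = DownloadManager()
    # The archive type is sniffed from the first bytes of the file, so the same
    # call handles both .tar and .zip inputs, yielding (path, file_obj) pairs.
    for file_path, file_obj in dl_manager.iter_archive("./my_dataset.zip"):
        print(file_path)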
113 changes: 109 additions & 4 deletions src/datasets/download/download_manager.py

@@ -22,8 +22,10 @@
 import tarfile
 import time
 import warnings
+import zipfile
 from datetime import datetime
 from functools import partial
+from itertools import chain
 from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
 
 from .. import config
@@ -38,6 +40,40 @@
 logger = get_logger(__name__)
 
 
+BASE_KNOWN_EXTENSIONS = [
+    "txt",
+    "csv",
+    "json",
+    "jsonl",
+    "tsv",
+    "conll",
+    "conllu",
+    "orig",
+    "parquet",
+    "pkl",
+    "pickle",
+    "rel",
+    "xml",
+]
+MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL = {
+    bytes.fromhex("504B0304"): "zip",
+    bytes.fromhex("504B0506"): "zip",  # empty archive
+    bytes.fromhex("504B0708"): "zip",  # spanned archive
+    bytes.fromhex("425A68"): "bz2",
+    bytes.fromhex("1F8B"): "gzip",
+    bytes.fromhex("FD377A585A00"): "xz",
+    bytes.fromhex("04224D18"): "lz4",
+    bytes.fromhex("28B52FFD"): "zstd",
+}
+MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL = {
+    b"Rar!": "rar",
+}
+MAGIC_NUMBER_MAX_LENGTH = max(
+    len(magic_number)
+    for magic_number in chain(MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL, MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL)
+)
+
+
 class DownloadMode(enum.Enum):
     """`Enum` for how to treat pre-existing downloads and data.
@@ -69,6 +105,48 @@ def help_message(self):
         return "Use 'DownloadMode' instead."
 
 
+def _get_path_extension(path: str) -> str:
+    # Get extension: train.json.gz -> gz
+    extension = path.split(".")[-1]
+    # Remove query params ("dl=1", "raw=true"): gz?dl=1 -> gz
+    # Remove shards infos (".txt_1", ".txt-00000-of-00100"): txt_1 -> txt
+    for symb in "?-_":
+        extension = extension.split(symb)[0]
+    return extension
+
+
+def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
+    """read the magic number from a file-like object and return the compression protocol"""
+    # Check if the file object is seekable even before reading the magic number (to avoid https://bugs.python.org/issue26440)
+    try:
+        f.seek(0)
+    except (AttributeError, io.UnsupportedOperation):
+        return None
+    magic_number = f.read(MAGIC_NUMBER_MAX_LENGTH)
+    f.seek(0)
+    for i in range(MAGIC_NUMBER_MAX_LENGTH):
+        compression = MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL.get(magic_number[: MAGIC_NUMBER_MAX_LENGTH - i])
+        if compression is not None:
+            return compression
+        compression = MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL.get(magic_number[: MAGIC_NUMBER_MAX_LENGTH - i])
+        if compression is not None:
+            raise NotImplementedError(f"Compression protocol '{compression}' not implemented.")
+
+
+def _get_extraction_protocol(path: str) -> Optional[str]:
+    path = str(path)
+    extension = _get_path_extension(path)
+    # TODO(mariosasko): The below check will be useful once we can preserve the original extension in the new cache layout (use the `filename` parameter of `hf_hub_download`)
+    if (
+        extension in BASE_KNOWN_EXTENSIONS
+        or extension in ["tgz", "tar"]
+        or path.endswith((".tar.gz", ".tar.bz2", ".tar.xz"))
+    ):
+        return None
+    with open(path, "rb") as f:
+        return _get_extraction_protocol_with_magic_number(f)
+
+
 class _IterableFromGenerator(Iterable):
     """Utility class to create an iterable from a generator function, in order to reset the generator when needed."""
 
@@ -84,27 +162,54 @@ def __iter__(self):
 class ArchiveIterable(_IterableFromGenerator):
     """An iterable of (path, fileobj) from a TAR archive, used by `iter_archive`"""
 
-    @classmethod
-    def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
+    @staticmethod
+    def _iter_tar(f):
         stream = tarfile.open(fileobj=f, mode="r|*")
         for tarinfo in stream:
             file_path = tarinfo.name
             if not tarinfo.isreg():
                 continue
             if file_path is None:
                 continue
-            if os.path.basename(file_path).startswith(".") or os.path.basename(file_path).startswith("__"):
+            if os.path.basename(file_path).startswith((".", "__")):
                 # skipping hidden files
                 continue
             file_obj = stream.extractfile(tarinfo)
             yield file_path, file_obj
             stream.members = []
         del stream
 
+    @staticmethod
+    def _iter_zip(f):
+        zipf = zipfile.ZipFile(f)
+        for member in zipf.infolist():
+            file_path = member.filename
+            if member.is_dir():
+                continue
+            if file_path is None:
+                continue
+            if os.path.basename(file_path).startswith((".", "__")):
+                # skipping hidden files
+                continue
+            file_obj = zipf.open(member)
+            yield file_path, file_obj
+
+    @classmethod
+    def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
+        compression = _get_extraction_protocol_with_magic_number(f)
+        if compression == "zip":
+            yield from cls._iter_zip(f)
+        else:
+            yield from cls._iter_tar(f)
+
     @classmethod
     def _iter_from_path(cls, urlpath: str) -> Generator[Tuple, None, None]:
+        compression = _get_extraction_protocol(urlpath)
         with open(urlpath, "rb") as f:
-            yield from cls._iter_from_fileobj(f)
+            if compression == "zip":
+                yield from cls._iter_zip(f)
+            else:
+                yield from cls._iter_tar(f)
 
     @classmethod
     def from_buf(cls, fileobj) -> "ArchiveIterable":
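
The heart of the change is the magic-number lookup added above: read the first few bytes of the file, then try ever-shorter prefixes against the known signatures. A self-contained sketch of that longest-prefix detection, using simplified stand-ins for the module's MAGIC_NUMBER_* constants:

    import gzip
    import io
    from typing import Optional

    # Simplified stand-ins for the MAGIC_NUMBER_* tables defined in the diff.
    MAGIC = {bytes.fromhex("504B0304"): "zip", bytes.fromhex("1F8B"): "gzip"}
    MAX_LEN = max(len(m) for m in MAGIC)

    def sniff(f) -> Optional[str]:
        f.seek(0)
        head = f.read(MAX_LEN)  # read just enough bytes for the longest signature
        f.seek(0)  # rewind so the caller can still consume the file
        for i in range(MAX_LEN):
            protocol = MAGIC.get(head[: MAX_LEN - i])  # try the longest prefix first
            if protocol is not None:
                return protocol
        return None

    assert sniff(io.BytesIO(gzip.compress(b"hello"))) == "gzip"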
91 changes: 68 additions & 23 deletions src/datasets/download/streaming_download_manager.py

@@ -6,6 +6,7 @@
 import tarfile
 import time
 import xml.dom.minidom
+import zipfile
 from asyncio import TimeoutError
 from io import BytesIO
 from itertools import chain
@@ -151,7 +152,7 @@ def xexists(urlpath: str, use_auth_token: Optional[Union[str, bool]] = None):
         `bool`
     """
 
-    main_hop, *rest_hops = str(urlpath).split("::")
+    main_hop, *rest_hops = _as_str(urlpath).split("::")
     if is_local_path(main_hop):
         return os.path.exists(main_hop)
     else:
@@ -381,13 +382,28 @@ def read_with_retries(*args, **kwargs):
     file_obj.read = read_with_retries
 
 
+def _get_path_extension(path: str) -> str:
+    # Get extension: https://foo.bar/train.json.gz -> gz
+    extension = path.split(".")[-1]
+    # Remove query params ("dl=1", "raw=true"): gz?dl=1 -> gz
+    # Remove shards infos (".txt_1", ".txt-00000-of-00100"): txt_1 -> txt
+    for symb in "?-_":
+        extension = extension.split(symb)[0]
+    return extension
+
+
 def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
     """read the magic number from a file-like object and return the compression protocol"""
+    # Check if the file object is seekable even before reading the magic number (to avoid https://bugs.python.org/issue26440)
+    try:
+        f.seek(0)
+    except (AttributeError, io.UnsupportedOperation):
+        return None
     magic_number = f.read(MAGIC_NUMBER_MAX_LENGTH)
     f.seek(0)
     for i in range(MAGIC_NUMBER_MAX_LENGTH):
         compression = MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL.get(magic_number[: MAGIC_NUMBER_MAX_LENGTH - i])
-        if compression is not None:  # TODO(QL): raise an error for .tar.gz files as in _get_extraction_protocol
+        if compression is not None:
            return compression
         compression = MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL.get(magic_number[: MAGIC_NUMBER_MAX_LENGTH - i])
         if compression is not None:
@@ -396,28 +412,17 @@ def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
 
 def _get_extraction_protocol(urlpath: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]:
     # get inner file: zip://train-00000.json.gz::https://foo.bar/data.zip -> zip://train-00000.json.gz
     urlpath = str(urlpath)
     path = urlpath.split("::")[0]
-    # Get extension: https://foo.bar/train.json.gz -> gz
-    extension = path.split(".")[-1]
-    # Remove query params ("dl=1", "raw=true"): gz?dl=1 -> gz
-    # Remove shards infos (".txt_1", ".txt-00000-of-00100"): txt_1 -> txt
-    for symb in "?-_":
-        extension = extension.split(symb)[0]
-    if extension in BASE_KNOWN_EXTENSIONS:
+    extension = _get_path_extension(path)
+    if (
+        extension in BASE_KNOWN_EXTENSIONS
+        or extension in ["tgz", "tar"]
+        or path.endswith((".tar.gz", ".tar.bz2", ".tar.xz"))
+    ):
         return None
-    elif extension in ["tgz", "tar"] or path.endswith(".tar.gz"):
-        raise NotImplementedError(
-            f"Extraction protocol for TAR archives like '{urlpath}' is not implemented in streaming mode. "
-            f"Please use `dl_manager.iter_archive` instead.\n\n"
-            f"Example usage:\n\n"
-            f"\turl = dl_manager.download(url)\n"
-            f"\ttar_archive_iterator = dl_manager.iter_archive(url)\n\n"
-            f"\tfor filename, file in tar_archive_iterator:\n"
-            f"\t\t..."
-        )
     elif extension in COMPRESSION_EXTENSION_TO_PROTOCOL:
         return COMPRESSION_EXTENSION_TO_PROTOCOL[extension]
 
     if is_remote_url(urlpath):
         # get headers and cookies for authentication on the HF Hub and for Google Drive
         urlpath, kwargs = _prepare_http_url_kwargs(urlpath, use_auth_token=use_auth_token)
@@ -849,8 +854,8 @@ def __iter__(self):
 class ArchiveIterable(_IterableFromGenerator):
     """An iterable of (path, fileobj) from a TAR archive, used by `iter_archive`"""
 
-    @classmethod
-    def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
+    @staticmethod
+    def _iter_tar(f):
         stream = tarfile.open(fileobj=f, mode="r|*")
         for tarinfo in stream:
             file_path = tarinfo.name
@@ -866,12 +871,39 @@ def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
             stream.members = []
         del stream
 
+    @staticmethod
+    def _iter_zip(f):
+        zipf = zipfile.ZipFile(f)
+        for member in zipf.infolist():
+            file_path = member.filename
+            if member.is_dir():
+                continue
+            if file_path is None:
+                continue
+            if os.path.basename(file_path).startswith((".", "__")):
+                # skipping hidden files
+                continue
+            file_obj = zipf.open(member)
+            yield file_path, file_obj
+
+    @classmethod
+    def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
+        compression = _get_extraction_protocol_with_magic_number(f)
+        if compression == "zip":
+            yield from cls._iter_zip(f)
+        else:
+            yield from cls._iter_tar(f)
+
     @classmethod
     def _iter_from_urlpath(
         cls, urlpath: str, use_auth_token: Optional[Union[str, bool]] = None
     ) -> Generator[Tuple, None, None]:
+        compression = _get_extraction_protocol(urlpath, use_auth_token=use_auth_token)
         with xopen(urlpath, "rb", use_auth_token=use_auth_token) as f:
-            yield from cls._iter_from_fileobj(f)
+            if compression == "zip":
+                yield from cls._iter_zip(f)
+            else:
+                yield from cls._iter_tar(f)
 
     @classmethod
     def from_buf(cls, fileobj) -> "ArchiveIterable":
@@ -995,6 +1027,19 @@ def extract(self, url_or_urls):
     def _extract(self, urlpath: str) -> str:
         urlpath = str(urlpath)
         protocol = _get_extraction_protocol(urlpath, use_auth_token=self.download_config.use_auth_token)
+        # get inner file: zip://train-00000.json.gz::https://foo.bar/data.zip -> zip://train-00000.json.gz
+        path = urlpath.split("::")[0]
+        extension = _get_path_extension(path)
+        if extension in ["tgz", "tar"] or path.endswith((".tar.gz", ".tar.bz2", ".tar.xz")):
+            raise NotImplementedError(
+                f"Extraction protocol for TAR archives like '{urlpath}' is not implemented in streaming mode. "
+                f"Please use `dl_manager.iter_archive` instead.\n\n"
+                f"Example usage:\n\n"
+                f"\turl = dl_manager.download(url)\n"
+                f"\ttar_archive_iterator = dl_manager.iter_archive(url)\n\n"
+                f"\tfor filename, file in tar_archive_iterator:\n"
+                f"\t\t..."
+            )
         if protocol is None:
             # no extraction
             return urlpath
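
In streaming mode the same detection runs on files opened through `xopen`, so archives are iterated without being extracted to disk. A minimal local sketch (a throwaway zip is built first so the example is self-contained):

    import zipfile

    from datasets.download.streaming_download_manager import StreamingDownloadManager

    # Build a small archive so the sketch runs without a network connection.
    with zipfile.ZipFile("demo.zip", "w") as zf:
        zf.writestr("train.jsonl", '{"text": "hello"}\n')

    dl_manager = StreamingDownloadManager()
    # The protocol is sniffed from the magic number, so the same call handles
    # local paths and (with authentication) remote URLs; members are yielded lazily.
    for file_path, file_obj in dl_manager.iter_archive("demo.zip"):
        print(file_path, file_obj.readline())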
8 changes: 8 additions & 0 deletions tests/fixtures/files.py

@@ -408,6 +408,14 @@ def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):
     return path
 
 
+@pytest.fixture(scope="session")
+def zip_nested_jsonl_path(zip_jsonl_path, jsonl_path, jsonl2_path, tmp_path_factory):
+    path = tmp_path_factory.mktemp("data") / "dataset_nested.jsonl.zip"
+    with zipfile.ZipFile(path, "w") as f:
+        f.write(zip_jsonl_path, arcname=os.path.join("nested", os.path.basename(zip_jsonl_path)))
+    return path
+
+
 @pytest.fixture(scope="session")
 def zip_jsonl_with_dir_path(jsonl_path, jsonl2_path, tmp_path_factory):
     path = tmp_path_factory.mktemp("data") / "dataset_with_dir.jsonl.zip"
12 changes: 8 additions & 4 deletions tests/test_download_manager.py

@@ -122,16 +122,20 @@ def _test_jsonl(path, file):
     assert num_items == 4
 
 
-def test_iter_archive_path(tar_jsonl_path):
+@pytest.mark.parametrize("archive_jsonl", ["tar_jsonl_path", "zip_jsonl_path"])
+def test_iter_archive_path(archive_jsonl, request):
+    archive_jsonl_path = request.getfixturevalue(archive_jsonl)
     dl_manager = DownloadManager()
-    for num_jsonl, (path, file) in enumerate(dl_manager.iter_archive(tar_jsonl_path), start=1):
+    for num_jsonl, (path, file) in enumerate(dl_manager.iter_archive(archive_jsonl_path), start=1):
         _test_jsonl(path, file)
     assert num_jsonl == 2
 
 
-def test_iter_archive_file(tar_nested_jsonl_path):
+@pytest.mark.parametrize("archive_nested_jsonl", ["tar_nested_jsonl_path", "zip_nested_jsonl_path"])
+def test_iter_archive_file(archive_nested_jsonl, request):
+    archive_nested_jsonl_path = request.getfixturevalue(archive_nested_jsonl)
     dl_manager = DownloadManager()
-    for num_tar, (path, file) in enumerate(dl_manager.iter_archive(tar_nested_jsonl_path), start=1):
+    for num_tar, (path, file) in enumerate(dl_manager.iter_archive(archive_nested_jsonl_path), start=1):
         for num_jsonl, (subpath, subfile) in enumerate(dl_manager.iter_archive(file), start=1):
             _test_jsonl(subpath, subfile)
     assert num_tar == 1
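
The parametrization above relies on a pytest indirection worth spelling out: the test is parametrized over fixture names as plain strings, and `request.getfixturevalue` resolves each name into the fixture's value at run time. A self-contained illustration of the pattern (the fixture and test here are hypothetical):

    import pytest

    @pytest.fixture
    def greeting():
        return "hello"

    @pytest.mark.parametrize("fixture_name", ["greeting"])
    def test_resolves_fixture(fixture_name, request):
        # Resolve the fixture by name, exactly as the archive tests above do.
        value = request.getfixturevalue(fixture_name)
        assert value == "hello"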
16 changes: 10 additions & 6 deletions tests/test_streaming_download_manager.py

@@ -826,9 +826,9 @@ def test_streaming_dl_manager_get_extraction_protocol_gg_drive(urlpath, expected_protocol):
         "https://foo.bar/train.tar",
     ],
 )
-def test_streaming_dl_manager_get_extraction_protocol_throws(urlpath):
+def test_streaming_dl_manager_extract_throws(urlpath):
     with pytest.raises(NotImplementedError):
-        _ = _get_extraction_protocol(urlpath)
+        _ = StreamingDownloadManager().extract(urlpath)
 
 
 @slow  # otherwise it spams Google Drive and the CI gets banned
@@ -873,9 +873,11 @@ def _test_jsonl(path, file):
     assert num_items == 4
 
 
-def test_iter_archive_path(tar_jsonl_path):
+@pytest.mark.parametrize("archive_jsonl", ["tar_jsonl_path", "zip_jsonl_path"])
+def test_iter_archive_path(archive_jsonl, request):
+    archive_jsonl_path = request.getfixturevalue(archive_jsonl)
     dl_manager = StreamingDownloadManager()
-    archive_iterable = dl_manager.iter_archive(tar_jsonl_path)
+    archive_iterable = dl_manager.iter_archive(archive_jsonl_path)
     num_jsonl = 0
     for num_jsonl, (path, file) in enumerate(archive_iterable, start=1):
         _test_jsonl(path, file)
@@ -887,9 +889,11 @@ def test_iter_archive_path(tar_jsonl_path):
     assert num_jsonl == 2
 
 
-def test_iter_archive_file(tar_nested_jsonl_path):
+@pytest.mark.parametrize("archive_nested_jsonl", ["tar_nested_jsonl_path", "zip_nested_jsonl_path"])
+def test_iter_archive_file(archive_nested_jsonl, request):
+    archive_nested_jsonl_path = request.getfixturevalue(archive_nested_jsonl)
     dl_manager = StreamingDownloadManager()
-    files_iterable = dl_manager.iter_archive(tar_nested_jsonl_path)
+    files_iterable = dl_manager.iter_archive(archive_nested_jsonl_path)
     num_tar, num_jsonl = 0, 0
     for num_tar, (path, file) in enumerate(files_iterable, start=1):
         for num_jsonl, (subpath, subfile) in enumerate(dl_manager.iter_archive(file), start=1):
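
Together these tests exercise nested iteration: the inner `iter_archive` call receives a file object for an archive member rather than a path, which is why `_iter_from_fileobj` now sniffs the compression type itself. A sketch of that usage (the archive name is hypothetical, and the outer archive is assumed to contain further .zip or .tar members):

    from datasets.download.streaming_download_manager import StreamingDownloadManager

    dl_manager = StreamingDownloadManager()
    for path, file in dl_manager.iter_archive("outer.zip"):
        # The inner call receives a file object, not a path; the archive type
        # is detected from its magic number.
        for subpath, subfile in dl_manager.iter_archive(file):
            print(subpath)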
