From beef55e13b84a4344416c32f3fc3e85aef660201 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 17 Nov 2022 19:34:49 +0100 Subject: [PATCH 01/25] add num_shards, num_proc, storage_options to save_to_disk --- src/datasets/arrow_dataset.py | 223 +++++++++++++++++++++++++--------- 1 file changed, 165 insertions(+), 58 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 9d33301070c..cf60571a1a6 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -20,10 +20,12 @@ import itertools import json import os +import posixpath import re import shutil import sys import tempfile +import time import warnings import weakref from collections import Counter, UserDict @@ -61,7 +63,7 @@ from . import config from .arrow_reader import ArrowReader -from .arrow_writer import ArrowWriter, OptimizedTypedSequence +from .arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter from .download.download_config import DownloadConfig from .download.streaming_download_manager import xgetsize from .features import Audio, ClassLabel, Features, Image, Sequence, Value @@ -108,7 +110,7 @@ from .utils.hub import hf_hub_url from .utils.info_utils import is_small_dataset from .utils.metadata import DatasetMetadata -from .utils.py_utils import asdict, convert_file_size_to_int, unique_values +from .utils.py_utils import asdict, convert_file_size_to_int, iflatmap_unordered, unique_values from .utils.stratify import stratified_shuffle_split_generate_indices from .utils.tf_utils import minimal_tf_collate_fn from .utils.typing import PathLike @@ -1179,38 +1181,37 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Here `del` is used to del the pyarrow tables. This properly closes the files used for memory mapped tables self.__del__() - def save_to_disk(self, dataset_path: str, fs=None): + def save_to_disk( + self, + dataset_path: str, + fs="deprecated", + num_shards: Optional[int] = None, + num_proc: Optional[int] = None, + storage_options: Optional[dict] = None, + ): """ - Saves a dataset to a dataset directory, or in a filesystem using either :class:`~filesystems.S3FileSystem` or + Saves a dataset to a dataset directory, or in a filesystem using either :class:`~s3fs.S3FileSystem` or any implementation of ``fsspec.spec.AbstractFileSystem``. For :class:`Image` and :class:`Audio` data: - If your images and audio files are local files, then the resulting arrow file will store paths to these files. - If you want to include the bytes or your images or audio files instead, you must `read()` those files first. - This can be done by storing the "bytes" instead of the "path" of the images or audio files: - - ```python - >>> def read_image_file(example): - ... with open(example["image"].filename, "rb") as f: - ... return {"image": {"bytes": f.read()}} - >>> ds = ds.map(read_image_file) - >>> ds.save_to_disk("path/to/dataset/dir") - ``` - - ```python - >>> def read_audio_file(example): - ... with open(example["audio"]["path"], "rb") as f: - ... return {"audio": {"bytes": f.read()}} - >>> ds = ds.map(read_audio_file) - >>> ds.save_to_disk("path/to/dataset/dir") - ``` + All the Image() and Audio() data are stored in the arrow files. + If you want to store paths or urls, please use the Value("string") type. Args: dataset_path (:obj:`str`): Path (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`) of the dataset directory where the dataset will be saved to. 
- fs (:class:`~filesystems.S3FileSystem`, ``fsspec.spec.AbstractFileSystem``, optional, defaults ``None``): - Instance of the remote filesystem used to download the files from. + num_shards (:obj:`Union[str, int]`, optional): Number of shards to write. + Default to the same value as `num_proc` if specified. + + + num_proc (:obj:`int`, optional, default `None`): Number of processes when downloading and generating the dataset locally. + Multiprocessing is disabled by default. + + + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the file-system backend, if any. + + Example: @@ -1218,24 +1219,35 @@ def save_to_disk(self, dataset_path: str, fs=None): >>> saved_ds = ds.save_to_disk("path/to/dataset/directory") ``` """ + num_proc: int = num_proc if num_proc is not None else 1 + num_shards: int = num_shards if num_shards is not None else num_proc + if fs != "deprecated": + warnings.warn( + "'fs' was is deprecated in favor of 'storage_options' in version 2.8.0 and will be removed in 3.0.0.\n" + "You can remove this warning by passing 'storage_options=fs.storage_options' instead.", + FutureWarning, + ) + storage_options = fs.storage_options + + fs_token_paths = fsspec.get_fs_token_paths(dataset_path, storage_options=storage_options) + fs: fsspec.AbstractFileSystem = fs_token_paths[0] + is_local = not is_remote_filesystem(fs) + path_join = os.path.join if is_local else posixpath.join + if self.list_indexes(): raise ValueError("please remove all the indexes using `dataset.drop_index` before saving a dataset") - dataset = self.flatten_indices() if self._indices is not None else self + dataset = self.flatten_indices(num_proc=num_proc) if self._indices is not None else self - if is_remote_filesystem(fs): - dataset_path = extract_path_from_uri(dataset_path) - else: - fs = fsspec.filesystem("file") - cache_files_paths = [Path(cache_filename["filename"]) for cache_filename in self.cache_files] + if is_local: + Path(dataset_path).resolve().mkdir(parents=True, exist_ok=True) + parent_cache_files_paths = set( + Path(cache_filename["filename"]).resolve().parent for cache_filename in self.cache_files + ) # Check that the dataset doesn't overwrite iself. It can cause a permission error on Windows and a segfault on linux. - if Path(dataset_path, config.DATASET_ARROW_FILENAME) in cache_files_paths: + if Path(dataset_path).resolve() in parent_cache_files_paths: raise PermissionError( - f"Tried to overwrite {Path(dataset_path, config.DATASET_ARROW_FILENAME)} but a dataset can't overwrite itself." - ) - if Path(dataset_path, config.DATASET_INDICES_FILENAME) in cache_files_paths: - raise PermissionError( - f"Tried to overwrite {Path(dataset_path, config.DATASET_INDICES_FILENAME)} but a dataset can't overwrite itself." + f"Tried to overwrite {Path(dataset_path).resolve()} but a dataset can't overwrite itself." 
) # Get json serializable state @@ -1246,15 +1258,13 @@ def save_to_disk(self, dataset_path: str, fs=None): "_format_columns", "_format_kwargs", "_format_type", - "_indexes", "_output_all_columns", ] } - - split = dataset.__dict__["_split"] - state["_split"] = str(split) if split is not None else split - - state["_data_files"] = [{"filename": config.DATASET_ARROW_FILENAME}] + state["_split"] = str(dataset.split) if dataset.split is not None else dataset.split + state["_data_files"] = [ + {"filename": f"data-{shard_idx:05d}-of-{num_shards:05d}.arrow"} for shard_idx in range(num_shards) + ] for k in state["_format_kwargs"].keys(): try: json.dumps(state["_format_kwargs"][k]) @@ -1262,27 +1272,106 @@ def save_to_disk(self, dataset_path: str, fs=None): raise TypeError( str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't." ) from None - # Get json serializable dataset info dataset_info = asdict(dataset._info) - # Save dataset + state + info - fs.makedirs(dataset_path, exist_ok=True) - with fs.open(Path(dataset_path, config.DATASET_ARROW_FILENAME).as_posix(), "wb") as dataset_file: - with ArrowWriter(stream=dataset_file) as writer: - writer.write_table(dataset._data.table) - writer.finalize() + shards_done = 0 + pbar = logging.tqdm( + disable=not logging.is_progress_bar_enabled(), + unit=" examples", + total=len(dataset), + leave=False, + desc=f"Saving the dataset ({shards_done}/{num_shards} shards)", + ) + args_per_job = ( + { + "job_id": shard_idx, + "shard": dataset.shard(num_shards=num_shards, index=shard_idx, contiguous=True), + "fpath": path_join(dataset_path, f"data-{shard_idx:05d}-of-{num_shards:05d}.arrow"), + "storage_options": storage_options, + } + for shard_idx in range(num_shards) + ) + shard_lengths = [None] * num_shards + shard_sizes = [None] * num_shards + if num_proc > 1: + with Pool(num_proc) as pool: + for job_id, done, content in iflatmap_unordered(pool, Dataset._save_to_disk_single, args_per_job): + if done: + shards_done += 1 + pbar.set_description(f"Saving dataset ({shards_done}/{num_shards} shards)") + logger.debug(f"Finished writing shard number {job_id} of {num_shards}.") + shard_lengths[job_id], shard_sizes[job_id] = content + else: + pbar.update(content) + else: + for args in args_per_job: + for job_id, done, content in Dataset._save_to_disk_single(args): + if done: + shards_done += 1 + pbar.set_description(f"Saving dataset ({shards_done}/{num_shards} shards)") + logger.debug(f"Finished writing shard number {job_id} of {num_shards}.") + shard_lengths[job_id], shard_sizes[job_id] = content + else: + pbar.update(content) with fs.open( - Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "w", encoding="utf-8" + path_join(dataset_path, config.DATASET_STATE_JSON_FILENAME), "w", encoding="utf-8" ) as state_file: json.dump(state, state_file, indent=2, sort_keys=True) with fs.open( - Path(dataset_path, config.DATASET_INFO_FILENAME).as_posix(), "w", encoding="utf-8" + path_join(dataset_path, config.DATASET_INFO_FILENAME), "w", encoding="utf-8" ) as dataset_info_file: # Sort only the first level of keys, or we might shuffle fields of nested features if we use sort_keys=True sorted_keys_dataset_info = {key: dataset_info[key] for key in sorted(dataset_info)} json.dump(sorted_keys_dataset_info, dataset_info_file, indent=2) - logger.info(f"Dataset saved in {dataset_path}") + + @staticmethod + def _save_to_disk_single(arg): + job_id: Dataset = arg["job_id"] + shard: Dataset = arg["shard"] + fpath: str = arg["fpath"] + storage_options: 
Optional[dict] = arg["storage_options"] + refresh_rate = 0.05 # 20 progress updates per sec + batch_size = config.DEFAULT_MAX_BATCH_SIZE + + if shard._indices is not None: + raise ValueError( + "`_save_to_disk_single` only support shards with flattened indices. " + "Please call ds.flatten_indices() before saving to disk." + ) + + num_examples_progress_update = 0 + writer = ArrowWriter( + features=shard.features, + path=fpath, + storage_options=storage_options, + embed_local_files=True, + ) + try: + _time = time.time() + if config.PYARROW_VERSION.major >= 8: + for pa_table in table_iter(shard.data.table, batch_size=batch_size): + writer.write_table(pa_table) + num_examples_progress_update += len(pa_table) + if time.time() > _time + refresh_rate: + _time = time.time() + yield job_id, False, num_examples_progress_update + num_examples_progress_update = 0 + else: + for i in range(0, shard.num_rows, batch_size): + pa_table = shard.data.slice(i, batch_size) + writer.write_table(pa_table) + num_examples_progress_update += len(pa_table) + if time.time() > _time + refresh_rate: + _time = time.time() + yield job_id, False, num_examples_progress_update + num_examples_progress_update = 0 + finally: + yield job_id, False, num_examples_progress_update + num_examples, num_bytes = writer.finalize() + writer.close() + + yield job_id, True, (num_examples, num_bytes) @staticmethod def _build_local_temp_path(uri_or_path: str) -> Path: @@ -1302,21 +1391,27 @@ def _build_local_temp_path(uri_or_path: str) -> Path: return Path(tmp_dir, src_dataset_path.relative_to(src_dataset_path.anchor)) @staticmethod - def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] = None) -> "Dataset": + def load_from_disk( + dataset_path: str, + fs="deprecated", + keep_in_memory: Optional[bool] = None, + storage_options: Optional[dict] = None, + ) -> "Dataset": """ - Loads a dataset that was previously saved using :meth:`save_to_disk` from a dataset directory, or from a + Loads an Arrow dataset that was previously saved using :meth:`save_to_disk` from a dataset directory, or from a filesystem using either :class:`~filesystems.S3FileSystem` or any implementation of ``fsspec.spec.AbstractFileSystem``. Args: dataset_path (:obj:`str`): Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3//my-bucket/dataset/train"`) of the dataset directory where the dataset will be loaded from. - fs (:class:`~filesystems.S3FileSystem`, ``fsspec.spec.AbstractFileSystem``, optional, default ``None``): - Instance of the remote filesystem used to download the files from. keep_in_memory (:obj:`bool`, default ``None``): Whether to copy the dataset in-memory. If `None`, the dataset will not be copied in-memory unless explicitly enabled by setting `datasets.config.IN_MEMORY_MAX_SIZE` to nonzero. See more details in the :ref:`load_dataset_enhancing_performance` section. + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the file-system backend, if any. 
+ + Returns: :class:`Dataset` or :class:`DatasetDict`: @@ -1329,8 +1424,17 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] = >>> ds = load_from_disk("path/to/dataset/directory") ``` """ + if fs != "deprecated": + warnings.warn( + "'fs' was is deprecated in favor of 'storage_options' in version 2.8.0 and will be removed in 3.0.0.\n" + "You can remove this warning by passing 'storage_options=fs.storage_options' instead.", + FutureWarning, + ) + storage_options = fs.storage_options + + fs_token_paths = fsspec.get_fs_token_paths(dataset_path, storage_options=storage_options) + fs: fsspec.AbstractFileSystem = fs_token_paths[0] # copies file from filesystem if it is remote filesystem to local filesystem and modifies dataset_path to temp directory containing local copies - fs = fsspec.filesystem("file") if fs is None else fs dataset_dict_json_path = Path(dataset_path, config.DATASETDICT_JSON_FILENAME).as_posix() dataset_info_path = Path(dataset_path, config.DATASET_INFO_FILENAME).as_posix() if not fs.isfile(dataset_info_path) and fs.isfile(dataset_dict_json_path): @@ -3145,6 +3249,7 @@ def flatten_indices( writer_batch_size: Optional[int] = 1000, features: Optional[Features] = None, disable_nullable: bool = False, + num_proc: Optional[int] = None, new_fingerprint: Optional[str] = None, ) -> "Dataset": """Create and cache a new Dataset by flattening the indices mapping. @@ -3159,6 +3264,7 @@ def flatten_indices( features (`Optional[datasets.Features]`, default `None`): Use a specific Features to store the cache file instead of the automatically generated one. disable_nullable (:obj:`bool`, default `False`): Allow null values in the table. + num_proc (:obj:`int`, optional, default `None`): Max number of processes when generating cache. Already cached shards are loaded sequentially new_fingerprint (:obj:`str`, optional, default `None`): The new fingerprint of the dataset after transform. If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments """ @@ -3172,6 +3278,7 @@ def flatten_indices( disable_nullable=disable_nullable, new_fingerprint=new_fingerprint, desc="Flattening the indices", + num_proc=num_proc, ) def _new_dataset_with_indices( From d1b7fb5c8b71b098be67fe94de73ed1365167345 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 17 Nov 2022 19:34:57 +0100 Subject: [PATCH 02/25] minor --- src/datasets/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 6ac6ad75570..7e8f08657e5 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -643,7 +643,7 @@ def download_and_prepare( Multiprocessing is disabled by default. - storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any. + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the file-system backend, if any. **download_and_prepare_kwargs (additional keyword arguments): Keyword arguments. 
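For readers skimming the diffs above, here is a minimal usage sketch of the `save_to_disk` signature introduced in PATCH 01; the `s3://` URI and the s3fs credential keys are placeholders for illustration, not part of the patch:

```python
from datasets import Dataset, load_from_disk

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(1_000)]})

# Write 8 Arrow shards using 4 worker processes; in PATCH 01, num_shards defaults to num_proc when omitted.
ds.save_to_disk("path/to/dataset/train", num_shards=8, num_proc=4)

# The filesystem is now resolved from the URI via fsspec, and `storage_options`
# replaces the deprecated `fs=` argument (the keys below are s3fs-specific placeholders).
storage_options = {"key": "<aws_access_key_id>", "secret": "<aws_secret_access_key>"}
ds.save_to_disk("s3://my-bucket/dataset/train", storage_options=storage_options)
reloaded = load_from_disk("s3://my-bucket/dataset/train", storage_options=storage_options)
```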
From 2d59bb64aef996cdca77ff6b7acdb870b616ad2c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 18 Nov 2022 19:45:32 +0100 Subject: [PATCH 03/25] add tests --- tests/test_arrow_dataset.py | 63 +++++++++++++++++++++++++------------ tests/test_dataset_dict.py | 22 ------------- 2 files changed, 43 insertions(+), 42 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index ba1b5767c7d..ab18175c1f8 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -53,7 +53,6 @@ assert_arrow_memory_increases, require_jax, require_pil, - require_s3, require_sqlalchemy, require_tf, require_torch, @@ -266,6 +265,7 @@ def test_dummy_dataset_serialize(self, in_memory): self.assertDictEqual(dset.features, Features({"filename": Value("string")})) self.assertEqual(dset[0]["filename"], "my_name-train_0") self.assertEqual(dset["filename"][0], "my_name-train_0") + expected = dset.to_dict() with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset: dataset_path = os.path.join(tmp_dir, "my_dataset") # abs path @@ -302,6 +302,36 @@ def test_dummy_dataset_serialize(self, in_memory): self.assertDictEqual(dset[0]["nested"], {"a": 1, "c": 100, "x": 10}) self.assertDictEqual(dset["nested"][0], {"a": 1, "c": 100, "x": 10}) + with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset: + with assert_arrow_memory_doesnt_increase(): + dset.save_to_disk(dataset_path, num_shards=4) + + with Dataset.load_from_disk(dataset_path) as dset: + self.assertEqual(len(dset), 10) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset.to_dict(), expected) + self.assertEqual(len(dset.cache_files), 4) + + with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset: + with assert_arrow_memory_doesnt_increase(): + dset.save_to_disk(dataset_path, num_proc=2) + + with Dataset.load_from_disk(dataset_path) as dset: + self.assertEqual(len(dset), 10) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset.to_dict(), expected) + self.assertEqual(len(dset.cache_files), 2) + + with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset: + with assert_arrow_memory_doesnt_increase(): + dset.save_to_disk(dataset_path, num_shards=7, num_proc=2) + + with Dataset.load_from_disk(dataset_path) as dset: + self.assertEqual(len(dset), 10) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset.to_dict(), expected) + self.assertEqual(len(dset.cache_files), 7) + def test_dummy_dataset_load_from_disk(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: @@ -3526,25 +3556,18 @@ def test_pickle_dataset_after_transforming_the_table(in_memory, method_and_param assert dataset._data.table == reloaded_dataset._data.table -@pytest.mark.skipif( - os.name in ["nt", "posix"] and (os.getenv("CIRCLECI") == "true" or os.getenv("GITHUB_ACTIONS") == "true"), - reason='On Windows CircleCI or GitHub Actions, it raises botocore.exceptions.EndpointConnectionError: Could not connect to the endpoint URL: "http://127.0.0.1:5555/test"', -) # TODO: find what's wrong with CircleCI / GitHub Actions -@require_s3 -@pytest.mark.integration -def test_dummy_dataset_serialize_s3(s3, dataset, s3_test_bucket_name): - mock_bucket = s3_test_bucket_name - dataset_path = f"s3://{mock_bucket}/my_dataset" - features = dataset.features - dataset.save_to_disk(dataset_path, s3) - dataset = 
dataset.load_from_disk(dataset_path, s3) - assert os.path.isfile(dataset.cache_files[0]["filename"]) - - assert len(dataset) == 10 - assert len(dataset.shuffle()) == 10 - assert dataset.features == features - assert dataset[0]["id"] == 0 - assert dataset["id"][0] == 0 +def test_dummy_dataset_serialize_fs(dataset, mock_fsspec, tmp_path_factory): + dataset_path = "mock://my_dataset" + storage_options = { + "local_root_dir": tmp_path_factory.mktemp("test_dummy_dataset_serialize_fs"), + "auto_mkdir": True + } + dataset.save_to_disk(dataset_path, storage_options=storage_options) + reloaded = dataset.load_from_disk(dataset_path, storage_options=storage_options) + assert os.path.isfile(reloaded.cache_files[0]["filename"]) + assert len(reloaded) == len(dataset) + assert reloaded.features == dataset.features + assert reloaded.to_dict() == dataset.to_dict() @pytest.mark.parametrize( diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py index 175b5785e71..d2d1a0631f5 100644 --- a/tests/test_dataset_dict.py +++ b/tests/test_dataset_dict.py @@ -15,7 +15,6 @@ from .utils import ( assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, - require_s3, require_tf, require_torch, ) @@ -665,24 +664,3 @@ def test_datasetdict_from_text_split(split, text_path, tmp_path): dataset = DatasetDict.from_text(path, cache_dir=cache_dir) _check_text_datasetdict(dataset, expected_features, splits=list(path.keys())) assert all(dataset[split].split == split for split in path.keys()) - - -@pytest.mark.skipif( - os.name in ["nt", "posix"] and (os.getenv("CIRCLECI") == "true" or os.getenv("GITHUB_ACTIONS") == "true"), - reason='On Windows CircleCI or GitHub Actions, it raises botocore.exceptions.EndpointConnectionError: Could not connect to the endpoint URL: "http://127.0.0.1:5555/test"', -) # TODO: find what's wrong with CircleCI / GitHub Actions -@require_s3 -@pytest.mark.integration -def test_dummy_dataset_serialize_s3(s3, dataset, s3_test_bucket_name): - dsets = DatasetDict({"train": dataset, "test": dataset.select(range(2))}) - mock_bucket = s3_test_bucket_name - dataset_path = f"s3://{mock_bucket}/datasets/dict" - column_names = dsets["train"].column_names - lengths = [len(dset) for dset in dsets.values()] - dataset.save_to_disk(dataset_path, s3) - dataset = dataset.load_from_disk(dataset_path, s3) - - assert sorted(dsets) == ["test", "train"] - assert [len(dset) for dset in dsets.values()] == lengths - assert dsets["train"].column_names == column_names - assert dsets["test"].column_names == column_names From 2e270dc83a3ce0dbb81f20d57bebc8fdddc62b58 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 18 Nov 2022 19:46:01 +0100 Subject: [PATCH 04/25] remove old s3fs integreation tests --- setup.py | 14 ++------ tests/conftest.py | 2 +- tests/fixtures/fsspec.py | 2 +- tests/fixtures/s3.py | 74 ---------------------------------------- tests/test_filesystem.py | 25 ++------------ tests/utils.py | 16 --------- 6 files changed, 6 insertions(+), 127 deletions(-) delete mode 100644 tests/fixtures/s3.py diff --git a/setup.py b/setup.py index ad0f0af82ae..e4349dd34ec 100644 --- a/setup.py +++ b/setup.py @@ -85,7 +85,7 @@ "importlib_metadata;python_version<'3.8'", # to save datasets locally or on any filesystem # minimum 2021.11.1 so that BlockSizeError is fixed: see https://github.com/fsspec/filesystem_spec/pull/830 - "fsspec[http]>=2021.11.1", # aligned s3fs with this + "fsspec[http]>=2021.11.1", # for data streaming via http "aiohttp", # To get datasets from the Datasets Hub on huggingface.co @@ 
-122,13 +122,8 @@ # optional dependencies "apache-beam>=2.26.0", "elasticsearch<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElastictSearch() - "aiobotocore>=2.0.1", # required by s3fs>=2021.11.1 - "boto3>=1.19.8", # to be compatible with aiobotocore>=2.0.1 - both have strong dependencies on botocore - "botocore>=1.22.8", # to be compatible with aiobotocore and boto3 "faiss-cpu>=1.6.4", - "fsspec[s3]", "lz4", - "moto[s3,server]==2.0.4", "py7zr", "rarfile>=4.0", "s3fs>=2021.11.1", # aligned with fsspec[http]>=2021.11.1 @@ -181,12 +176,7 @@ ], "tensorflow_gpu": ["tensorflow-gpu>=2.2.0,!=2.6.0,!=2.6.1"], "torch": ["torch"], - "s3": [ - "fsspec", - "boto3", - "botocore", - "s3fs", - ], + "s3": ["s3fs"], "streaming": [], # for backward compatibility "dev": TESTS_REQUIRE + QUALITY_REQUIRE, "tests": TESTS_REQUIRE, diff --git a/tests/conftest.py b/tests/conftest.py index beb32b7119e..c6985a6ec78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ # Import fixture modules as plugins -pytest_plugins = ["tests.fixtures.files", "tests.fixtures.hub", "tests.fixtures.s3", "tests.fixtures.fsspec"] +pytest_plugins = ["tests.fixtures.files", "tests.fixtures.hub", "tests.fixtures.fsspec"] def pytest_collection_modifyitems(config, items): diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py index 7a301116ea8..f6dcb1c4103 100644 --- a/tests/fixtures/fsspec.py +++ b/tests/fixtures/fsspec.py @@ -81,4 +81,4 @@ def mock_fsspec(monkeypatch): @pytest.fixture def mockfs(tmp_path_factory, mock_fsspec): local_fs_dir = tmp_path_factory.mktemp("mockfs") - return MockFileSystem(local_root_dir=local_fs_dir) + return MockFileSystem(local_root_dir=local_fs_dir, auto_mkdir=True) diff --git a/tests/fixtures/s3.py b/tests/fixtures/s3.py deleted file mode 100644 index bbd385bb28d..00000000000 --- a/tests/fixtures/s3.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import time - -import pytest -import requests - - -# From: https://github.com/dask/s3fs/blob/ffe3a5293524869df56e74973af0d2c204ae9cbf/s3fs/tests/test_s3fs.py#L25-L141 - -S3_TEST_BUCKET_NAME = "test" -s3_port = 5555 -s3_endpoint_uri = f"http://127.0.0.1:{s3_port}/" - -S3_FAKE_ENV_VARS = { - "AWS_ACCESS_KEY_ID": "fake_access_key", - "AWS_SECRET_ACCESS_KEY": "fake_secret_key", - "AWS_SECURITY_TOKEN": "fake_secrurity_token", - "AWS_SESSION_TOKEN": "fake_session_token", -} - - -@pytest.fixture(scope="session") -def s3_test_bucket_name(): - return S3_TEST_BUCKET_NAME - - -@pytest.fixture() -def s3_base(): - # writable local S3 system - import shlex - import subprocess - - # Mocked AWS Credentials for moto. 
- old_environ = os.environ.copy() - os.environ.update(S3_FAKE_ENV_VARS) - - proc = subprocess.Popen(shlex.split(f"moto_server s3 -p {s3_port}")) - - timeout = 5 - while timeout > 0: - try: - r = requests.get(s3_endpoint_uri) - if r.ok: - break - except: # noqa - pass - timeout -= 0.1 - time.sleep(0.1) - yield - proc.terminate() - proc.wait() - os.environ.clear() - os.environ.update(old_environ) - - -def get_boto3_client(): - from botocore.session import Session - - # NB: we use the sync botocore client for setup - session = Session() - return session.create_client("s3", endpoint_url=s3_endpoint_uri) - - -@pytest.fixture() -def s3(s3_base, s3_test_bucket_name): - client = get_boto3_client() - client.create_bucket(Bucket=s3_test_bucket_name, ACL="public-read") - - from s3fs.core import S3FileSystem - - S3FileSystem.clear_instance_cache() - s3 = S3FileSystem(anon=False, client_kwargs={"endpoint_url": s3_endpoint_uri}) - s3.invalidate_cache() - yield s3 diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py index faafe79dd54..72fc56da6e0 100644 --- a/tests/test_filesystem.py +++ b/tests/test_filesystem.py @@ -1,9 +1,7 @@ import os -import boto3 import fsspec import pytest -from moto import mock_s3 from datasets.filesystems import ( COMPRESSION_FILESYSTEMS, @@ -17,29 +15,10 @@ from .utils import require_lz4, require_zstandard -@pytest.fixture(scope="function") -def aws_credentials(): - """Mocked AWS Credentials for moto.""" - os.environ["AWS_ACCESS_KEY_ID"] = "fake_access_key" - os.environ["AWS_SECRET_ACCESS_KEY"] = "fake_secret_key" - os.environ["AWS_SECURITY_TOKEN"] = "fake_secrurity_token" - os.environ["AWS_SESSION_TOKEN"] = "fake_session_token" - - -@pytest.fixture(scope="function") -def s3(aws_credentials): - with mock_s3(): - yield boto3.client("s3", region_name="us-east-1") - - -def test_extract_path_from_uri(s3): - - mock_bucket = "moto-mock-s3-bucket" - # We need to create the bucket since this is all in Moto's 'virtual' AWS account - s3.create_bucket(Bucket=mock_bucket) +def test_extract_path_from_uri(): + mock_bucket = "mock-s3-bucket" dataset_path = f"s3://{mock_bucket}" - dataset_path = extract_path_from_uri(dataset_path) assert dataset_path.startswith("s3://") is False diff --git a/tests/utils.py b/tests/utils.py index 212a8ed84f4..397b2ed261d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -201,22 +201,6 @@ def require_transformers(test_case): return test_case -def require_s3(test_case): - """ - Decorator marking a test that requires s3fs and moto to mock s3. - - These tests are skipped when they aren't installed. - - """ - try: - import moto # noqa F401 - import s3fs # noqa F401 - except ImportError: - return unittest.skip("test requires s3fs and moto")(test_case) - else: - return test_case - - def require_spacy(test_case): """ Decorator marking a test that requires spacy. 
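The moto/boto3-based S3 tests removed above are superseded by the fsspec `mockfs` fixture defined in `tests/fixtures/fsspec.py`. As a rough illustration of that pattern, condensed from the new tests in PATCH 03 with a made-up test name, a round-trip through the mock filesystem looks roughly like this:

```python
from datasets import Dataset, load_from_disk


def test_save_and_load_on_mock_filesystem(mockfs):
    # `mockfs` registers a "mock://" fsspec protocol backed by a local temporary
    # directory (see the `local_root_dir` / `auto_mkdir` options in the fixture).
    ds = Dataset.from_dict({"a": list(range(10))})
    ds.save_to_disk("mock://my_dataset", storage_options=mockfs.storage_options)
    assert mockfs.isdir("mock://my_dataset")
    reloaded = load_from_disk("mock://my_dataset", storage_options=mockfs.storage_options)
    assert reloaded.to_dict() == ds.to_dict()
```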
From 532ae18e483fc4fe0304847858c276861ee5433c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 18 Nov 2022 19:46:29 +0100 Subject: [PATCH 05/25] style --- src/datasets/arrow_dataset.py | 4 +--- tests/test_arrow_dataset.py | 2 +- tests/test_dataset_dict.py | 7 +------ 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index cf60571a1a6..3649047fdde 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1314,9 +1314,7 @@ def save_to_disk( shard_lengths[job_id], shard_sizes[job_id] = content else: pbar.update(content) - with fs.open( - path_join(dataset_path, config.DATASET_STATE_JSON_FILENAME), "w", encoding="utf-8" - ) as state_file: + with fs.open(path_join(dataset_path, config.DATASET_STATE_JSON_FILENAME), "w", encoding="utf-8") as state_file: json.dump(state, state_file, indent=2, sort_keys=True) with fs.open( path_join(dataset_path, config.DATASET_INFO_FILENAME), "w", encoding="utf-8" diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index ab18175c1f8..6283a2bcdf8 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3560,7 +3560,7 @@ def test_dummy_dataset_serialize_fs(dataset, mock_fsspec, tmp_path_factory): dataset_path = "mock://my_dataset" storage_options = { "local_root_dir": tmp_path_factory.mktemp("test_dummy_dataset_serialize_fs"), - "auto_mkdir": True + "auto_mkdir": True, } dataset.save_to_disk(dataset_path, storage_options=storage_options) reloaded = dataset.load_from_disk(dataset_path, storage_options=storage_options) diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py index d2d1a0631f5..dddc3f99894 100644 --- a/tests/test_dataset_dict.py +++ b/tests/test_dataset_dict.py @@ -12,12 +12,7 @@ from datasets.features import ClassLabel, Features, Sequence, Value from datasets.splits import NamedSplit -from .utils import ( - assert_arrow_memory_doesnt_increase, - assert_arrow_memory_increases, - require_tf, - require_torch, -) +from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_tf, require_torch class DatasetDictTest(TestCase): From f548e018399ccb65dac0bf259002014a433ecb6b Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 18 Nov 2022 19:53:16 +0100 Subject: [PATCH 06/25] style --- src/datasets/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 3649047fdde..5901c3f7656 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -63,7 +63,7 @@ from . 
import config from .arrow_reader import ArrowReader -from .arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter +from .arrow_writer import ArrowWriter, OptimizedTypedSequence from .download.download_config import DownloadConfig from .download.streaming_download_manager import xgetsize from .features import Audio, ClassLabel, Features, Image, Sequence, Value From dcd636371f2ec2d7c79e2c28808149287381d76f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 21 Nov 2022 17:33:36 +0100 Subject: [PATCH 07/25] Update DatasetDict.save_to_disk --- src/datasets/arrow_dataset.py | 6 +-- src/datasets/dataset_dict.py | 87 ++++++++++++++++++++++------------- 2 files changed, 58 insertions(+), 35 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 5901c3f7656..211c0a204b0 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1183,7 +1183,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def save_to_disk( self, - dataset_path: str, + dataset_path: PathLike, fs="deprecated", num_shards: Optional[int] = None, num_proc: Optional[int] = None, @@ -1199,9 +1199,9 @@ def save_to_disk( If you want to store paths or urls, please use the Value("string") type. Args: - dataset_path (:obj:`str`): Path (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`) + dataset_path (``PathLike``): Path (e.g. `path/to/dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`) of the dataset directory where the dataset will be saved to. - num_shards (:obj:`Union[str, int]`, optional): Number of shards to write. + num_shards (:obj:`int`, optional): Number of shards to write. Default to the same value as `num_proc` if specified. diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index a6f97d9f5cc..252c554a105 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -2,6 +2,7 @@ import copy import json import os +import posixpath import re import warnings from io import BytesIO @@ -1032,56 +1033,78 @@ def shuffle( } ) - def save_to_disk(self, dataset_dict_path: str, fs=None): + def save_to_disk( + self, + dataset_dict_path: PathLike, + fs="deprecated", + num_shards: Optional[Dict[str, int]] = None, + num_proc: Optional[int] = None, + storage_options: Optional[dict] = None, + ): """ Saves a dataset dict to a filesystem using either :class:`~filesystems.S3FileSystem` or ``fsspec.spec.AbstractFileSystem``. For :class:`Image` and :class:`Audio` data: - If your images and audio files are local files, then the resulting arrow file will store paths to these files. - If you want to include the bytes or your images or audio files instead, you must `read()` those files first. - This can be done by storing the "bytes" instead of the "path" of the images or audio files: - - ```python - >>> def read_image_file(example): - ... with open(example["image"].filename, "rb") as f: - ... return {"image": {"bytes": f.read()}} - >>> ds = ds.map(read_image_file) - >>> ds.save_to_disk("path/to/dataset/dir") - ``` - - ```python - >>> def read_audio_file(example): - ... with open(example["audio"]["path"], "rb") as f: - ... return {"audio": {"bytes": f.read()}} - >>> ds = ds.map(read_audio_file) - >>> ds.save_to_disk("path/to/dataset/dir") - ``` + All the Image() and Audio() data are stored in the arrow files. + If you want to store paths or urls, please use the Value("string") type. Args: - dataset_dict_path (``str``): Path (e.g. 
`dataset/train`) or remote URI + dataset_dict_path (``PathLike``): Path (e.g. `path/to/dataset/`) or remote URI (e.g. `s3://my-bucket/dataset/train`) of the dataset dict directory where the dataset dict will be saved to. - fs (:class:`~filesystems.S3FileSystem`, ``fsspec.spec.AbstractFileSystem``, optional, defaults ``None``): - Instance of the remote filesystem used to download the files from. + num_shards (:obj:`Dict[str, int]`, optional): Number of shards to write. + You need to provide the number of shards for each dataset in the dataset dictionary. + Default to the same value as `num_proc` if specified. + + + num_proc (:obj:`int`, optional, default `None`): Number of processes when downloading and generating the dataset locally. + Multiprocessing is disabled by default. + + + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the file-system backend, if any. + + + """ - if is_remote_filesystem(fs): - dest_dataset_dict_path = extract_path_from_uri(dataset_dict_path) - else: - fs = fsspec.filesystem("file") - dest_dataset_dict_path = dataset_dict_path - os.makedirs(dest_dataset_dict_path, exist_ok=True) + if fs != "deprecated": + warnings.warn( + "'fs' was is deprecated in favor of 'storage_options' in version 2.8.0 and will be removed in 3.0.0.\n" + "You can remove this warning by passing 'storage_options=fs.storage_options' instead.", + FutureWarning, + ) + storage_options = fs.storage_options + + fs_token_paths = fsspec.get_fs_token_paths(dataset_dict_path, storage_options=storage_options) + fs: fsspec.AbstractFileSystem = fs_token_paths[0] + is_local = not is_remote_filesystem(fs) + path_join = os.path.join if is_local else posixpath.join + + if num_shards is None: + num_shards = {k: None for k in self} + elif not isinstance(num_shards, dict): + raise ValueError( + "Please provide one `num_shards` per dataset in the dataset dictionary, e.g. {{'train': 128, 'test': 4}}" + ) + + if is_local: + Path(dataset_dict_path).resolve().mkdir(parents=True, exist_ok=True) json.dump( {"splits": list(self)}, - fs.open(Path(dest_dataset_dict_path, config.DATASETDICT_JSON_FILENAME).as_posix(), "w", encoding="utf-8"), + fs.open(path_join(dataset_dict_path, config.DATASETDICT_JSON_FILENAME), "w", encoding="utf-8"), ) for k, dataset in self.items(): - dataset.save_to_disk(Path(dest_dataset_dict_path, k).as_posix(), fs) + dataset.save_to_disk( + path_join(dataset_dict_path, k), + num_shards=num_shards.get(k), + num_proc=num_proc, + storage_options=storage_options, + ) @staticmethod - def load_from_disk(dataset_dict_path: str, fs=None, keep_in_memory: Optional[bool] = None) -> "DatasetDict": + def load_from_disk(dataset_dict_path: PathLike, fs=None, keep_in_memory: Optional[bool] = None) -> "DatasetDict": """ Load a dataset that was previously saved using :meth:`save_to_disk` from a filesystem using either :class:`~filesystems.S3FileSystem` or ``fsspec.spec.AbstractFileSystem``. 
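A short sketch of the per-split control added to `DatasetDict.save_to_disk` above, mirroring the shard counts exercised by the new tests; the output path is a placeholder:

```python
from datasets import Dataset, DatasetDict, load_from_disk

dsets = DatasetDict(
    {
        "train": Dataset.from_dict({"a": list(range(30))}),
        "test": Dataset.from_dict({"a": list(range(10))}),
    }
)

# `num_shards` must be a dict with one entry per split; `num_proc` parallelizes the writes.
dsets.save_to_disk("path/to/dataset", num_shards={"train": 3, "test": 2}, num_proc=2)

reloaded = load_from_disk("path/to/dataset")
assert len(reloaded["train"].cache_files) == 3  # one cache file per shard
assert len(reloaded["test"].cache_files) == 2
```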
From 26a3e15f538331c0755b8b97c05a8d0671baefd8 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 21 Nov 2022 18:09:43 +0100 Subject: [PATCH 08/25] test dataset dict --- tests/test_dataset_dict.py | 42 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py index dddc3f99894..fb9b7b9508a 100644 --- a/tests/test_dataset_dict.py +++ b/tests/test_dataset_dict.py @@ -388,6 +388,30 @@ def test_serialization(self): self.assertListEqual(reloaded_dsets["train"].column_names, ["filename"]) del dsets, reloaded_dsets + dsets = self._create_dummy_dataset_dict() + dsets.save_to_disk(tmp_dir, num_shards={"train": 3, "test": 2}) + reloaded_dsets = DatasetDict.load_from_disk(tmp_dir) + self.assertListEqual(sorted(reloaded_dsets), ["test", "train"]) + self.assertEqual(len(reloaded_dsets["train"]), 30) + self.assertListEqual(reloaded_dsets["train"].column_names, ["filename"]) + self.assertEqual(len(reloaded_dsets["train"].cache_files), 3) + self.assertEqual(len(reloaded_dsets["test"]), 30) + self.assertListEqual(reloaded_dsets["test"].column_names, ["filename"]) + self.assertEqual(len(reloaded_dsets["test"].cache_files), 2) + del reloaded_dsets + + dsets = self._create_dummy_dataset_dict() + dsets.save_to_disk(tmp_dir, num_proc=2) + reloaded_dsets = DatasetDict.load_from_disk(tmp_dir) + self.assertListEqual(sorted(reloaded_dsets), ["test", "train"]) + self.assertEqual(len(reloaded_dsets["train"]), 30) + self.assertListEqual(reloaded_dsets["train"].column_names, ["filename"]) + self.assertEqual(len(reloaded_dsets["train"].cache_files), 2) + self.assertEqual(len(reloaded_dsets["test"]), 30) + self.assertListEqual(reloaded_dsets["test"].column_names, ["filename"]) + self.assertEqual(len(reloaded_dsets["test"].cache_files), 2) + del reloaded_dsets + def test_load_from_disk(self): with tempfile.TemporaryDirectory() as tmp_dir: dsets = self._create_dummy_dataset_dict() @@ -441,6 +465,24 @@ def test_align_labels_with_mapping(self): self.assertListEqual(test_expected_label_names, test_aligned_label_names) +def test_dummy_datasetdict_serialize_fs(mockfs): + dataset_dict = DatasetDict( + { + "train": Dataset.from_dict({"a": range(30)}), + "test": Dataset.from_dict({"a": range(10)}), + } + ) + dataset_path = "mock://my_dataset" + dataset_dict.save_to_disk(dataset_path, storage_options=mockfs.storage_options) + assert mockfs.isdir(dataset_path) + assert mockfs.glob(dataset_path + "/*") + reloaded = dataset_dict.load_from_disk(dataset_path, storage_options=mockfs.storage_options) + assert list(reloaded) == list(dataset_dict) + for k in dataset_dict: + assert reloaded[k].features == dataset_dict[k].features + assert reloaded[k].to_dict() == dataset_dict[k].to_dict() + + def _check_csv_datasetdict(dataset_dict, expected_features, splits=("train",)): assert isinstance(dataset_dict, DatasetDict) for split in splits: From 291a883952b8767e141df8ad6a83990c4fa42c1f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 21 Nov 2022 18:09:51 +0100 Subject: [PATCH 09/25] update dataset dict load_from_disk --- src/datasets/dataset_dict.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 252c554a105..c9ac0eff820 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1051,7 +1051,7 @@ def save_to_disk( If you want to store paths or urls, please use the Value("string") type. 
Args: - dataset_dict_path (``PathLike``): Path (e.g. `path/to/dataset/`) or remote URI + dataset_dict_path (``PathLike``): Path (e.g. `path/to/dataset`) or remote URI (e.g. `s3://my-bucket/dataset/train`) of the dataset dict directory where the dataset dict will be saved to. num_shards (:obj:`Dict[str, int]`, optional): Number of shards to write. @@ -1104,21 +1104,27 @@ def save_to_disk( ) @staticmethod - def load_from_disk(dataset_dict_path: PathLike, fs=None, keep_in_memory: Optional[bool] = None) -> "DatasetDict": + def load_from_disk( + dataset_dict_path: PathLike, + fs="deprecated", + keep_in_memory: Optional[bool] = None, + storage_options: Optional[dict] = None, + ) -> "DatasetDict": """ Load a dataset that was previously saved using :meth:`save_to_disk` from a filesystem using either :class:`~filesystems.S3FileSystem` or ``fsspec.spec.AbstractFileSystem``. Args: - dataset_dict_path (:obj:`str`): Path (e.g. ``"dataset/train"``) or remote URI (e.g. + dataset_dict_path (``PathLike``): Path (e.g. ``"path/to/dataset"``) or remote URI (e.g. ``"s3//my-bucket/dataset/train"``) of the dataset dict directory where the dataset dict will be loaded from. - fs (:class:`~filesystems.S3FileSystem` or ``fsspec.spec.AbstractFileSystem``, optional, default ``None``): - Instance of the remote filesystem used to download the files from. keep_in_memory (:obj:`bool`, default ``None``): Whether to copy the dataset in-memory. If `None`, the dataset will not be copied in-memory unless explicitly enabled by setting `datasets.config.IN_MEMORY_MAX_SIZE` to nonzero. See more details in the :ref:`load_dataset_enhancing_performance` section. + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the file-system backend, if any. + + Returns: :class:`DatasetDict` @@ -1129,6 +1135,17 @@ def load_from_disk(dataset_dict_path: PathLike, fs=None, keep_in_memory: Optiona >>> ds = load_from_disk('path/to/dataset/directory') ``` """ + if fs != "deprecated": + warnings.warn( + "'fs' was is deprecated in favor of 'storage_options' in version 2.8.0 and will be removed in 3.0.0.\n" + "You can remove this warning by passing 'storage_options=fs.storage_options' instead.", + FutureWarning, + ) + storage_options = fs.storage_options + + fs_token_paths = fsspec.get_fs_token_paths(dataset_dict_path, storage_options=storage_options) + fs: fsspec.AbstractFileSystem = fs_token_paths[0] + dataset_dict = DatasetDict() if is_remote_filesystem(fs): dest_dataset_dict_path = extract_path_from_uri(dataset_dict_path) @@ -1147,7 +1164,9 @@ def load_from_disk(dataset_dict_path: PathLike, fs=None, keep_in_memory: Optiona if is_remote_filesystem(fs) else Path(dest_dataset_dict_path, k).as_posix() ) - dataset_dict[k] = Dataset.load_from_disk(dataset_dict_split_path, fs, keep_in_memory=keep_in_memory) + dataset_dict[k] = Dataset.load_from_disk( + dataset_dict_split_path, keep_in_memory=keep_in_memory, storage_options=storage_options + ) return dataset_dict @staticmethod From c55028b4202ce737bad0d53ad64d30b3cbbcda3f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 21 Nov 2022 18:09:58 +0100 Subject: [PATCH 10/25] minor --- src/datasets/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 211c0a204b0..5cd802d0b04 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1401,7 +1401,7 @@ def load_from_disk( ``fsspec.spec.AbstractFileSystem``. Args: - dataset_path (:obj:`str`): Path (e.g. 
`"dataset/train"`) or remote URI (e.g. + dataset_path (:obj:`str`): Path (e.g. `"path/to/dataset/train"`) or remote URI (e.g. `"s3//my-bucket/dataset/train"`) of the dataset directory where the dataset will be loaded from. keep_in_memory (:obj:`bool`, default ``None``): Whether to copy the dataset in-memory. If `None`, the dataset will not be copied in-memory unless explicitly enabled by setting From f122f6dc5fc67b5dca64ce2daa00c4e4effcc04d Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 21 Nov 2022 18:10:05 +0100 Subject: [PATCH 11/25] update test --- tests/test_arrow_dataset.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 6283a2bcdf8..9cd218ca24e 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3556,15 +3556,12 @@ def test_pickle_dataset_after_transforming_the_table(in_memory, method_and_param assert dataset._data.table == reloaded_dataset._data.table -def test_dummy_dataset_serialize_fs(dataset, mock_fsspec, tmp_path_factory): +def test_dummy_dataset_serialize_fs(dataset, mockfs): dataset_path = "mock://my_dataset" - storage_options = { - "local_root_dir": tmp_path_factory.mktemp("test_dummy_dataset_serialize_fs"), - "auto_mkdir": True, - } - dataset.save_to_disk(dataset_path, storage_options=storage_options) - reloaded = dataset.load_from_disk(dataset_path, storage_options=storage_options) - assert os.path.isfile(reloaded.cache_files[0]["filename"]) + dataset.save_to_disk(dataset_path, storage_options=mockfs.storage_options) + assert mockfs.isdir(dataset_path) + assert mockfs.glob(dataset_path + "/*") + reloaded = dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options) assert len(reloaded) == len(dataset) assert reloaded.features == dataset.features assert reloaded.to_dict() == dataset.to_dict() From 8305f8cc0f414397b5a5da80cbca0411505902cc Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 21 Nov 2022 19:02:48 +0100 Subject: [PATCH 12/25] update docs --- docs/source/filesystems.mdx | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index b92b8e6f243..c219099acef 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -19,10 +19,10 @@ Here are examples for S3, Google Cloud Storage and Azure Blob Storage. ### Amazon S3 -1. Install the S3 dependency with 🤗 Datasets: +1. Install the S3 FileSystem implementation: ``` ->>> pip install datasets[s3] +>>> pip install s3fs ``` 2. 
Define your credentials @@ -89,7 +89,7 @@ Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever ```py >>> storage_options = {"anon": True} # for anonymous connection # or use your credentials ->>> storage_options = {"account_name": ACCOUNT_NAME, "account_key": ACCOUNT_KEY) # gen 2 filesystem +>>> storage_options = {"account_name": ACCOUNT_NAME, "account_key": ACCOUNT_KEY} # gen 2 filesystem # or use your credentials with the gen 1 filesystem >>> storage_options={"tenant_id": TENANT_ID, "client_id": CLIENT_ID, "client_secret": CLIENT_SECRET} ``` @@ -173,11 +173,11 @@ After you have processed your dataset, you can save it to your cloud storage wit ```py # saves encoded_dataset to amazon s3 ->>> encoded_dataset.save_to_disk("s3://my-private-datasets/imdb/train", fs=fs) +>>> encoded_dataset.save_to_disk("s3://my-private-datasets/imdb/train", storage_options=storage_options) # saves encoded_dataset to google cloud storage ->>> encoded_dataset.save_to_disk("gcs://my-private-datasets/imdb/train", fs=fs) +>>> encoded_dataset.save_to_disk("gcs://my-private-datasets/imdb/train", storage_options=storage_options) # saves encoded_dataset to microsoft azure blob/datalake ->>> encoded_dataset.save_to_disk("adl://my-private-datasets/imdb/train", fs=fs) +>>> encoded_dataset.save_to_disk("adl://my-private-datasets/imdb/train", storage_options=storage_options) ``` @@ -202,7 +202,7 @@ When you are ready to use your dataset again, reload it with [`Dataset.load_from ```py >>> from datasets import load_from_disk # load encoded_dataset from cloud storage ->>> dataset = load_from_disk("s3://a-public-datasets/imdb/train", fs=fs) +>>> dataset = load_from_disk("s3://a-public-datasets/imdb/train", storage_options=storage_options) >>> print(len(dataset)) 25000 ``` From d1d8ef8088a84fa12b51412212ad7c67d68f7937 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 22 Nov 2022 19:02:56 +0100 Subject: [PATCH 13/25] backport to_reader to pyarrow < 8 --- src/datasets/arrow_dataset.py | 28 +++++++--------------- src/datasets/table.py | 44 ++++++++++++++++++++--------------- tests/test_table.py | 17 +++++++------- 3 files changed, 42 insertions(+), 47 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 5cd802d0b04..5be506d02f0 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1347,23 +1347,13 @@ def _save_to_disk_single(arg): ) try: _time = time.time() - if config.PYARROW_VERSION.major >= 8: - for pa_table in table_iter(shard.data.table, batch_size=batch_size): - writer.write_table(pa_table) - num_examples_progress_update += len(pa_table) - if time.time() > _time + refresh_rate: - _time = time.time() - yield job_id, False, num_examples_progress_update - num_examples_progress_update = 0 - else: - for i in range(0, shard.num_rows, batch_size): - pa_table = shard.data.slice(i, batch_size) - writer.write_table(pa_table) - num_examples_progress_update += len(pa_table) - if time.time() > _time + refresh_rate: - _time = time.time() - yield job_id, False, num_examples_progress_update - num_examples_progress_update = 0 + for pa_table in table_iter(shard.data, batch_size=batch_size): + writer.write_table(pa_table) + num_examples_progress_update += len(pa_table) + if time.time() > _time + refresh_rate: + _time = time.time() + yield job_id, False, num_examples_progress_update + num_examples_progress_update = 0 finally: yield job_id, False, num_examples_progress_update num_examples, num_bytes = writer.finalize() @@ -2046,7 +2036,7 @@ 
def _iter_batches(self, batch_size: int, decoded: bool = True, drop_last_batch: If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the selected format. """ - if self._indices is None and config.PYARROW_VERSION.major >= 8: + if self._indices is None: # Fast iteration # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch) format_kwargs = self._format_kwargs if self._format_kwargs is not None else {} @@ -2073,7 +2063,7 @@ def _iter(self, decoded: bool = True): If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the selected format. """ - if self._indices is None and config.PYARROW_VERSION.major >= 8: + if self._indices: # Fast iteration # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch) format_kwargs = self._format_kwargs if self._format_kwargs is not None else {} diff --git a/src/datasets/table.py b/src/datasets/table.py index 95450ead3d3..0cc84f25d99 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -4,7 +4,7 @@ import warnings from functools import partial from itertools import groupby -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple, TypeVar, Union import numpy as np import pyarrow as pa @@ -100,8 +100,10 @@ def _interpolation_search(arr: List[int], x: int) -> int: class IndexedTableMixin: def __init__(self, table: pa.Table): - self._schema = table.schema - self._batches = [recordbatch for recordbatch in table.to_batches() if len(recordbatch) > 0] + self._schema: pa.Schema = table.schema + self._batches: List[pa.RecordBatch] = [ + recordbatch for recordbatch in table.to_batches() if len(recordbatch) > 0 + ] self._offsets: np.ndarray = np.cumsum([0] + [len(b) for b in self._batches], dtype=np.int64) def fast_gather(self, indices: Union[List[int], np.ndarray]) -> pa.Table: @@ -145,6 +147,20 @@ def fast_slice(self, offset=0, length=None) -> pa.Table: return pa.Table.from_batches(batches, schema=self._schema) +class _RecordBatchReader: + def __init__(self, table: "Table", max_chunksize: Optional[int] = None): + self.table = table + self.max_chunksize = max_chunksize + + def __iter__(self): + for batch in self.table._batches: + if self.max_chunksize is None or len(batch) <= self.max_chunksize: + yield batch + else: + for offset in range(0, len(batch), self.max_chunksize): + yield batch.slice(offset, self.max_chunksize) + + class Table(IndexedTableMixin): """ Wraps a pyarrow Table by using composition. @@ -330,7 +346,7 @@ def to_pandas(self, *args, **kwargs): def to_string(self, *args, **kwargs): return self.table.to_string(*args, **kwargs) - def to_reader(self, *args, **kwargs): + def to_reader(self, max_chunksize: Optional[int] = None): """ Convert the Table to a RecordBatchReader. @@ -342,17 +358,11 @@ def to_reader(self, *args, **kwargs): on the chunk layout of individual columns. Returns: - :obj:`pyarrow.RecordBatchReader` - - - - pyarrow >= 8.0.0 needs to be installed to use this method. 
- - + :obj:`pyarrow.RecordBatchReader` if pyarrow>=8.0.0, otherwise a :obj:`pyarrow.RecordBatch` iterable """ if config.PYARROW_VERSION.major < 8: - raise NotImplementedError("`pyarrow>=8.0.0` is required to use this method") - return self.table.to_reader(*args, **kwargs) + return _RecordBatchReader(self, max_chunksize=max_chunksize) + return self.table.to_reader(max_chunksize=max_chunksize) def field(self, *args, **kwargs): """ @@ -2192,21 +2202,17 @@ def _visit(array, feature): _visit(table[name], feature) -def table_iter(pa_table: pa.Table, batch_size: int, drop_last_batch=False): +def table_iter(table: Table, batch_size: int, drop_last_batch=False) -> Iterator[pa.Table]: """Iterate over sub-tables of size `batch_size`. - Requires pyarrow>=8.0.0 - Args: table (:obj:`pyarrow.Table`): PyArrow table to iterate over batch_size (:obj:`int`): size of each sub-table to yield drop_last_batch (:obj:`bool`, default `False`): Drop the last batch if it is smaller than `batch_size` """ - if config.PYARROW_VERSION.major < 8: - raise RuntimeError(f"pyarrow>=8.0.0 is needed to use table_iter but you have {config.PYARROW_VERSION}") chunks_buffer = [] chunks_buffer_size = 0 - for chunk in pa_table.to_reader(max_chunksize=batch_size): + for chunk in table.to_reader(max_chunksize=batch_size): if len(chunk) == 0: continue elif chunks_buffer_size + len(chunk) < batch_size: diff --git a/tests/test_table.py b/tests/test_table.py index 0ff87a3c723..8f767c38e10 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1130,21 +1130,20 @@ def test_embed_table_storage(image_file): assert isinstance(embedded_images_table.to_pydict()["image"][0]["bytes"], bytes) -@pytest.mark.skipif(datasets.config.PYARROW_VERSION.major < 8, reason="only available on pyarrow>=8") @pytest.mark.parametrize( - "pa_table", + "table", [ - pa.table({"foo": range(10)}), - pa.concat_tables([pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})]), - pa.concat_tables([pa.table({"foo": [i]}) for i in range(10)]), + InMemoryTable(pa.table({"foo": range(10)})), + InMemoryTable(pa.concat_tables([pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})])), + InMemoryTable(pa.concat_tables([pa.table({"foo": [i]}) for i in range(10)])), ], ) @pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20]) @pytest.mark.parametrize("drop_last_batch", [False, True]) -def test_table_iter(pa_table, batch_size, drop_last_batch): - num_rows = len(pa_table) if not drop_last_batch else len(pa_table) // batch_size * batch_size +def test_table_iter(table, batch_size, drop_last_batch): + num_rows = len(table) if not drop_last_batch else len(table) // batch_size * batch_size num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size - subtables = list(table_iter(pa_table, batch_size=batch_size, drop_last_batch=drop_last_batch)) + subtables = list(table_iter(table, batch_size=batch_size, drop_last_batch=drop_last_batch)) assert len(subtables) == num_batches if drop_last_batch: assert all(len(subtable) == batch_size for subtable in subtables) @@ -1153,4 +1152,4 @@ def test_table_iter(pa_table, batch_size, drop_last_batch): assert len(subtables[-1]) <= batch_size if num_rows > 0: reloaded = pa.concat_tables(subtables) - assert pa_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() + assert table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() From 705779276ad8363383d9f85e2227d61c5508d923 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 22 Nov 2022 19:04:43 +0100 Subject: [PATCH 
14/25] typo --- src/datasets/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 5be506d02f0..bb47f60f7b3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -2063,7 +2063,7 @@ def _iter(self, decoded: bool = True): If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the selected format. """ - if self._indices: + if self._indices is None: # Fast iteration # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch) format_kwargs = self._format_kwargs if self._format_kwargs is not None else {} From 5e737c024307e074c168a6e58b26e6f1d032e9c3 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 7 Dec 2022 16:36:23 +0100 Subject: [PATCH 15/25] support both max_shard_size and num_shards --- src/datasets/arrow_dataset.py | 100 ++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 30 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index bb47f60f7b3..1ccafd3ad34 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1185,6 +1185,7 @@ def save_to_disk( self, dataset_path: PathLike, fs="deprecated", + max_shard_size: Optional[Union[str, int]] = None, num_shards: Optional[int] = None, num_proc: Optional[int] = None, storage_options: Optional[dict] = None, @@ -1201,6 +1202,9 @@ def save_to_disk( Args: dataset_path (``PathLike``): Path (e.g. `path/to/dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`) of the dataset directory where the dataset will be saved to. + max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): + The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). num_shards (:obj:`int`, optional): Number of shards to write. Default to the same value as `num_proc` if specified. 
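The next two hunks make `max_shard_size` and `num_shards` mutually exclusive and, when only a size is given, derive the shard count from the estimated dataset size. A minimal standalone sketch of that rule (the helper name `pick_num_shards` is made up for this sketch, and the 500MB default is hard-coded as plain bytes instead of being read from `config.MAX_SHARD_SIZE`):

```python
from typing import Optional

def pick_num_shards(dataset_nbytes: int, max_shard_size: int = 500 * 2**20, num_proc: Optional[int] = None) -> int:
    # One shard per `max_shard_size` bytes plus one for the remainder,
    # but never fewer shards than there are writer processes.
    num_shards = int(dataset_nbytes / max_shard_size) + 1
    return max(num_shards, num_proc or 1)

# e.g. a 1200 MiB dataset with the 500MB default and num_proc=4 -> max(3, 4) = 4 shards
print(pick_num_shards(1200 * 2**20, num_proc=4))  # 4
```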
@@ -1219,8 +1223,8 @@ def save_to_disk( >>> saved_ds = ds.save_to_disk("path/to/dataset/directory") ``` """ - num_proc: int = num_proc if num_proc is not None else 1 - num_shards: int = num_shards if num_shards is not None else num_proc + if max_shard_size is not None and num_shards is not None: + raise ValueError("Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both.") if fs != "deprecated": warnings.warn( "'fs' was is deprecated in favor of 'storage_options' in version 2.8.0 and will be removed in 3.0.0.\n" @@ -1229,6 +1233,15 @@ def save_to_disk( ) storage_options = fs.storage_options + if num_shards is None: + dataset_nbytes = self._estimate_nbytes() + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) + num_shards = int(dataset_nbytes / max_shard_size) + 1 + num_shards = max(num_shards, num_proc or 1) + + num_proc = num_proc if num_proc is not None else 1 + num_shards = num_shards if num_shards is not None else num_proc + fs_token_paths = fsspec.get_fs_token_paths(dataset_path, storage_options=storage_options) fs: fsspec.AbstractFileSystem = fs_token_paths[0] is_local = not is_remote_filesystem(fs) @@ -4374,6 +4387,38 @@ def to_sql( return SqlDatasetWriter(self, name, con, batch_size=batch_size, **sql_writer_kwargs).write() + def _estimate_nbytes(self) -> int: + dataset_nbytes = self.data.nbytes + + # Find decodable columns, because if there are any, we need to + # adjust the dataset size computation (needed for sharding) to account for possible external files + decodable_columns = [ + k for k, v in self.features.items() if require_decoding(v, ignore_decode_attribute=True) + ] + + if decodable_columns: + # Approximate the space needed to store the bytes from the external files by analyzing the first 1000 examples + extra_nbytes = 0 + + def extra_nbytes_visitor(array, feature): + nonlocal extra_nbytes + if isinstance(feature, (Audio, Image)): + for x in array.to_pylist(): + if x is not None and x["bytes"] is None and x["path"] is not None: + size = xgetsize(x["path"]) + extra_nbytes += size + extra_nbytes -= array.field("path").nbytes + + table = self.with_format("arrow")[:1000] + table_visitor(table, extra_nbytes_visitor) + + extra_nbytes = extra_nbytes * len(self.data) / len(table) + dataset_nbytes = dataset_nbytes + extra_nbytes + + if self._indices is not None: + dataset_nbytes = dataset_nbytes * len(self._indices) / len(self.data) + return dataset_nbytes + def _push_parquet_shards_to_hub( self, repo_id: str, @@ -4382,6 +4427,7 @@ def _push_parquet_shards_to_hub( token: Optional[str] = None, branch: Optional[str] = None, max_shard_size: Optional[Union[int, str]] = None, + num_shards: Optional[int] = None, embed_external_files: bool = True, ) -> Tuple[str, str, int, int]: """Pushes the dataset to the hub. @@ -4407,6 +4453,10 @@ def _push_parquet_shards_to_hub( max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + num_shards (:obj:`int`, optional): Number of shards to write. + Default to the same value as `num_proc` if specified. + + embed_external_files (:obj:`bool`, default ``True``): Whether to embed file bytes in the shards. 
In particular, this will do the following before the push for the fields of type: @@ -4427,7 +4477,8 @@ def _push_parquet_shards_to_hub( >>> dataset.push_to_hub("/", split="evaluation") ``` """ - max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) + if max_shard_size is not None and num_shards is not None: + raise ValueError("Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both.") api = HfApi(endpoint=config.HF_ENDPOINT) token = token if token is not None else HfFolder.get_token() @@ -4465,40 +4516,20 @@ def _push_parquet_shards_to_hub( ) # Find decodable columns, because if there are any, we need to: - # (1) adjust the dataset size computation (needed for sharding) to account for possible external files - # (2) embed the bytes from the files in the shards + # embed the bytes from the files in the shards decodable_columns = ( [k for k, v in self.features.items() if require_decoding(v, ignore_decode_attribute=True)] if embed_external_files else [] ) - dataset_nbytes = self.data.nbytes - - if decodable_columns: - # Approximate the space needed to store the bytes from the external files by analyzing the first 1000 examples - extra_nbytes = 0 - - def extra_nbytes_visitor(array, feature): - nonlocal extra_nbytes - if isinstance(feature, (Audio, Image)): - for x in array.to_pylist(): - if x is not None and x["bytes"] is None and x["path"] is not None: - size = xgetsize(x["path"]) - extra_nbytes += size - extra_nbytes -= array.field("path").nbytes - - table = self.with_format("arrow")[:1000] - table_visitor(table, extra_nbytes_visitor) - - extra_nbytes = extra_nbytes * len(self.data) / len(table) - dataset_nbytes = dataset_nbytes + extra_nbytes + dataset_nbytes = self._estimate_nbytes() - if self._indices is not None: - dataset_nbytes = dataset_nbytes * len(self._indices) / len(self.data) + if num_shards is None: + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) + num_shards = int(dataset_nbytes / max_shard_size) + 1 + num_shards = max(num_shards, 1) - num_shards = int(dataset_nbytes / max_shard_size) + 1 - num_shards = max(num_shards, 1) shards = (self.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards)) if decodable_columns: @@ -4596,6 +4627,7 @@ def push_to_hub( token: Optional[str] = None, branch: Optional[str] = None, max_shard_size: Optional[Union[int, str]] = None, + num_shards: Optional[int] = None, shard_size: Optional[int] = "deprecated", embed_external_files: bool = True, ): @@ -4626,8 +4658,12 @@ def push_to_hub( max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + num_shards (:obj:`int`, optional): Number of shards to write. + Default to the same value as `num_proc` if specified. + + shard_size (Optional :obj:`int`): - Deprecated: 'shard_size' was renamed to 'max_shard_size' in version 2.1.1 and will be removed in 2.4.0. + Deprecated: 'shard_size' was renamed to 'max_shard_size' in version 2.1.1 and will be removed in 3.0.0. embed_external_files (:obj:`bool`, default ``True``): Whether to embed file bytes in the shards. 
In particular, this will do the following before the push for the fields of type: @@ -4647,6 +4683,9 @@ def push_to_hub( ) max_shard_size = shard_size + if max_shard_size is not None and num_shards is not None: + raise ValueError("Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both.") + repo_id, split, uploaded_size, dataset_nbytes, repo_files, deleted_size = self._push_parquet_shards_to_hub( repo_id=repo_id, split=split, @@ -4654,6 +4693,7 @@ def push_to_hub( token=token, branch=branch, max_shard_size=max_shard_size, + num_shards=num_shards, embed_external_files=embed_external_files, ) organization, dataset_name = repo_id.split("/") From 598b9daf940989991bebb6dbe99479c85ec4aae1 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 7 Dec 2022 16:36:46 +0100 Subject: [PATCH 16/25] style --- src/datasets/arrow_dataset.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 1ccafd3ad34..e6a34b713e4 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1224,7 +1224,9 @@ def save_to_disk( ``` """ if max_shard_size is not None and num_shards is not None: - raise ValueError("Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both.") + raise ValueError( + "Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both." + ) if fs != "deprecated": warnings.warn( "'fs' was is deprecated in favor of 'storage_options' in version 2.8.0 and will be removed in 3.0.0.\n" @@ -4392,9 +4394,7 @@ def _estimate_nbytes(self) -> int: # Find decodable columns, because if there are any, we need to # adjust the dataset size computation (needed for sharding) to account for possible external files - decodable_columns = [ - k for k, v in self.features.items() if require_decoding(v, ignore_decode_attribute=True) - ] + decodable_columns = [k for k, v in self.features.items() if require_decoding(v, ignore_decode_attribute=True)] if decodable_columns: # Approximate the space needed to store the bytes from the external files by analyzing the first 1000 examples @@ -4478,7 +4478,9 @@ def _push_parquet_shards_to_hub( ``` """ if max_shard_size is not None and num_shards is not None: - raise ValueError("Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both.") + raise ValueError( + "Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both." + ) api = HfApi(endpoint=config.HF_ENDPOINT) token = token if token is not None else HfFolder.get_token() @@ -4684,7 +4686,9 @@ def push_to_hub( max_shard_size = shard_size if max_shard_size is not None and num_shards is not None: - raise ValueError("Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both.") + raise ValueError( + "Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both." 
+ ) repo_id, split, uploaded_size, dataset_nbytes, repo_files, deleted_size = self._push_parquet_shards_to_hub( repo_id=repo_id, From 24e24bf3fedc8f8c19713bcd6f9d242e2690ecf0 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 7 Dec 2022 16:48:15 +0100 Subject: [PATCH 17/25] docstrings --- src/datasets/arrow_dataset.py | 15 ++++++++++----- src/datasets/dataset_dict.py | 34 ++++++++++++++++++++++++++++++++-- src/datasets/utils/py_utils.py | 2 +- 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index e6a34b713e4..3e23cb7fbd6 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1204,7 +1204,7 @@ def save_to_disk( of the dataset directory where the dataset will be saved to. max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit - (like `"5MB"`). + (like `"50MB"`). num_shards (:obj:`int`, optional): Number of shards to write. Default to the same value as `num_proc` if specified. @@ -1220,7 +1220,9 @@ def save_to_disk( Example: ```py - >>> saved_ds = ds.save_to_disk("path/to/dataset/directory") + >>> ds.save_to_disk("path/to/dataset/directory") + >>> ds.save_to_disk("path/to/dataset/directory", max_shard_size="1GB") + >>> ds.save_to_disk("path/to/dataset/directory", num_shards=1024) ``` """ if max_shard_size is not None and num_shards is not None: @@ -4452,7 +4454,7 @@ def _push_parquet_shards_to_hub( in your repository, which defaults to `"main"`. max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit - (like `"5MB"`). + (like `"50MB"`). num_shards (:obj:`int`, optional): Number of shards to write. Default to the same value as `num_proc` if specified. @@ -4659,7 +4661,7 @@ def push_to_hub( in your repository, which defaults to `"main"`. max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit - (like `"5MB"`). + (like `"50MB"`). num_shards (:obj:`int`, optional): Number of shards to write. Default to the same value as `num_proc` if specified. @@ -4675,7 +4677,10 @@ def push_to_hub( Example: ```python - >>> dataset.push_to_hub("/", split="evaluation") + >>> dataset.push_to_hub("/") + >>> dataset.push_to_hub("/", split="validation") + >>> dataset.push_to_hub("/", max_shard_size="1GB") + >>> dataset.push_to_hub("/", num_shards=1024) ``` """ if shard_size != "deprecated": diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index c9ac0eff820..362eb9380a7 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1037,7 +1037,8 @@ def save_to_disk( self, dataset_dict_path: PathLike, fs="deprecated", - num_shards: Optional[Dict[str, int]] = None, + max_shard_size: Optional[Union[str, int]] = None, + num_shards: Optional[Union[int, Dict[str, int]]] = None, num_proc: Optional[int] = None, storage_options: Optional[dict] = None, ): @@ -1054,9 +1055,13 @@ def save_to_disk( dataset_dict_path (``PathLike``): Path (e.g. `path/to/dataset`) or remote URI (e.g. `s3://my-bucket/dataset/train`) of the dataset dict directory where the dataset dict will be saved to. 
+ max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): + The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit + (like `"50MB"`). num_shards (:obj:`Dict[str, int]`, optional): Number of shards to write. You need to provide the number of shards for each dataset in the dataset dictionary. Default to the same value as `num_proc` if specified. + Use a dictionary to define a different num_shards for each split. num_proc (:obj:`int`, optional, default `None`): Number of processes when downloading and generating the dataset locally. @@ -1067,6 +1072,13 @@ def save_to_disk( + Example: + + ```python + >>> dataset_dict.save_to_disk("path/to/dataset/directory") + >>> dataset_dict.save_to_disk("path/to/dataset/directory", max_shard_size="1GB") + >>> dataset_dict.save_to_disk("path/to/dataset/directory", num_shards={"train": 1024, "test": 8}) + ``` """ if fs != "deprecated": warnings.warn( @@ -1099,6 +1111,7 @@ def save_to_disk( dataset.save_to_disk( path_join(dataset_dict_path, k), num_shards=num_shards.get(k), + max_shard_size=max_shard_size, num_proc=num_proc, storage_options=storage_options, ) @@ -1336,7 +1349,8 @@ def push_to_hub( token: Optional[str] = None, branch: Optional[None] = None, max_shard_size: Optional[Union[int, str]] = None, - shard_size: Optional[int] = "deprecated", + num_shards: Optional[int] = None, + shard_size: Optional[Union[int, Dict[str, int]]] = "deprecated", embed_external_files: bool = True, ): """Pushes the ``DatasetDict`` to the hub as a Parquet dataset. @@ -1365,6 +1379,11 @@ def push_to_hub( max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"500MB"` or `"1GB"`). + num_shards (`Union[int, Dict[str, int]]`, optional): Number of shards to write. + Default to the same value as `num_proc` if specified. + Use a dictionary to define a different num_shards for each split. + + shard_size (Optional :obj:`int`): Deprecated: 'shard_size' was renamed to 'max_shard_size' in version 2.1.1 and will be removed in 2.4.0. embed_external_files (:obj:`bool`, default ``True``): @@ -1377,6 +1396,9 @@ def push_to_hub( ```python >>> dataset_dict.push_to_hub("/") + >>> dataset_dict.push_to_hub("/", private=True) + >>> dataset_dict.push_to_hub("/", max_shard_size="1GB") + >>> dataset_dict.push_to_hub("/", num_shards={"train": 1024, "test": 8}) ``` """ if shard_size != "deprecated": @@ -1386,6 +1408,13 @@ def push_to_hub( ) max_shard_size = shard_size + if num_shards is None: + num_shards = {k: None for k in self} + elif not isinstance(num_shards, dict): + raise ValueError( + "Please provide one `num_shards` per dataset in the dataset dictionary, e.g. 
{{'train': 128, 'test': 4}}" + ) + self._check_values_type() self._check_values_features() total_uploaded_size = 0 @@ -1407,6 +1436,7 @@ def push_to_hub( token=token, branch=branch, max_shard_size=max_shard_size, + num_shards=num_shards.get(split), embed_external_files=embed_external_files, ) total_uploaded_size += uploaded_size diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index ee33b358abf..0c628cb6d9b 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -90,7 +90,7 @@ def size_str(size_in_bytes): def convert_file_size_to_int(size: Union[int, str]) -> int: """ - Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes). + Converts a size expressed as a string with digits an unit (like `"50MB"`) to an integer (in bytes). Args: size (`int` or `str`): The size to convert. Will be directly returned if an `int`. From 75347aa78e1be471b48b02b4cf540603ca2d0e7c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 9 Dec 2022 15:27:34 +0100 Subject: [PATCH 18/25] test _estimate_nbytes --- tests/test_arrow_dataset.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 9cd218ca24e..51f11e68294 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -332,6 +332,17 @@ def test_dummy_dataset_serialize(self, in_memory): self.assertDictEqual(dset.to_dict(), expected) self.assertEqual(len(dset.cache_files), 7) + with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset: + with assert_arrow_memory_doesnt_increase(): + max_shard_size = dset._estimate_nbytes() // 2 + 1 + dset.save_to_disk(dataset_path, max_shard_size=max_shard_size) + + with Dataset.load_from_disk(dataset_path) as dset: + self.assertEqual(len(dset), 10) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset.to_dict(), expected) + self.assertEqual(len(dset.cache_files), 2) + def test_dummy_dataset_load_from_disk(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: @@ -4110,3 +4121,20 @@ def test_train_test_split_startify(self): assert len(d1["train"]["text"]) + len(d1["test"]["text"]) == y.size assert len(d1["train"]["text"]) == train_size assert len(d1["test"]["text"]) == test_size + + +def test_dataset_estimate_nbytes(): + ds = Dataset.from_dict({"a": ["0" * 100] * 100}) + assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than full dataset size" + + ds = Dataset.from_dict({"a": ["0" * 100] * 100}).select([0]) + assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than one chunk" + + ds = Dataset.from_dict({"a": ["0" * 100] * 100}) + ds = concatenate_datasets([ds] * 100) + assert 0.9 * ds._estimate_nbytes() < 100 * 100 * 100, "must be smaller than full dataset size" + assert 1.1 * ds._estimate_nbytes() > 100 * 100 * 100, "must be bigger than full dataset size" + + ds = Dataset.from_dict({"a": ["0" * 100] * 100}) + ds = concatenate_datasets([ds] * 100).select([0]) + assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than one chunk" From fc39b8399740efb86435a4414030f3c7cebd240d Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 12 Dec 2022 19:12:33 +0100 Subject: [PATCH 19/25] add test for num_shards --- tests/test_upstream_hub.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_upstream_hub.py b/tests/test_upstream_hub.py index 8ab5ff885e2..b7b8f7bed0a 
100644 --- a/tests/test_upstream_hub.py +++ b/tests/test_upstream_hub.py @@ -184,6 +184,34 @@ def test_push_dataset_dict_to_hub_multiple_files_with_max_shard_size(self, tempo ) ) + def test_push_dataset_dict_to_hub_multiple_files_with_num_shards(self, temporary_repo): + ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) + + local_ds = DatasetDict({"train": ds}) + + with temporary_repo(f"{CI_HUB_USER}/test-{int(time.time() * 10e3)}") as ds_name: + local_ds.push_to_hub(ds_name, token=self._token, num_shards={"train": 2}) + hub_ds = load_dataset(ds_name, download_mode="force_redownload") + + assert local_ds.column_names == hub_ds.column_names + assert list(local_ds["train"].features.keys()) == list(hub_ds["train"].features.keys()) + assert local_ds["train"].features == hub_ds["train"].features + + # Ensure that there are two files on the repository that have the correct name + files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token)) + assert all( + fnmatch.fnmatch(file, expected_file) + for file, expected_file in zip( + files, + [ + ".gitattributes", + "README.md", + "data/train-00000-of-00002-*.parquet", + "data/train-00001-of-00002-*.parquet", + ], + ) + ) + def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo): ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) ds2 = Dataset.from_dict({"x": list(range(100)), "y": list(range(100))}) From d004f5866e065c1211367ac393abf28802e9e01f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 13 Dec 2022 12:32:39 +0100 Subject: [PATCH 20/25] style --- src/datasets/arrow_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 97001dd7a27..9d88d81795d 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4724,9 +4724,9 @@ def _push_parquet_shards_to_hub( The git branch on which to push the dataset. This defaults to the default branch as specified in your repository, which defaults to `"main"`. max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): - The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a + The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a a unit (like `"5MB"`). - num_shards (`int`, *optional*): + num_shards (`int`, *optional*): Number of shards to write. Default to the same value as `num_proc` if specified. 
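Not part of the patch: a rough, self-contained sketch of what the `num_shards` test above exercises. A split pushed with `num_shards={"train": 2}` is cut into contiguous shards, one Parquet file per shard, named `train-00000-of-00002`, `train-00001-of-00002`, and so on (the trailing fingerprint segment of the real filenames is omitted here).

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))})
num_shards = 2

# Contiguous sharding keeps row order: shard 0 gets the first half, shard 1 the second.
shards = [ds.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards)]
assert sum(len(shard) for shard in shards) == len(ds)

print([f"train-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)])
# ['train-00000-of-00002', 'train-00001-of-00002']
```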
From c1db7bd19471beb2cc91320ced1974e7694b3c23 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 14 Dec 2022 18:01:06 +0100 Subject: [PATCH 21/25] mario's comment --- src/datasets/arrow_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 16c7221c32d..3d8f9d9aece 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1407,7 +1407,7 @@ def save_to_disk( for job_id, done, content in iflatmap_unordered(pool, Dataset._save_to_disk_single, args_per_job): if done: shards_done += 1 - pbar.set_description(f"Saving dataset ({shards_done}/{num_shards} shards)") + pbar.set_description(f"Saving the dataset ({shards_done}/{num_shards} shards)") logger.debug(f"Finished writing shard number {job_id} of {num_shards}.") shard_lengths[job_id], shard_sizes[job_id] = content else: @@ -1417,7 +1417,7 @@ def save_to_disk( for job_id, done, content in Dataset._save_to_disk_single(args): if done: shards_done += 1 - pbar.set_description(f"Saving dataset ({shards_done}/{num_shards} shards)") + pbar.set_description(f"Saving the dataset ({shards_done}/{num_shards} shards)") logger.debug(f"Finished writing shard number {job_id} of {num_shards}.") shard_lengths[job_id], shard_sizes[job_id] = content else: From f3562d201b638d0cc9cd47635dfb86cb4927cd75 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 14 Dec 2022 18:06:14 +0100 Subject: [PATCH 22/25] add config.PBAR_REFRESH_TIME_INTERVAL --- src/datasets/arrow_dataset.py | 3 +-- src/datasets/builder.py | 6 ++---- src/datasets/config.py | 3 +++ 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 3d8f9d9aece..7eb59643f78 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1437,7 +1437,6 @@ def _save_to_disk_single(arg): shard: Dataset = arg["shard"] fpath: str = arg["fpath"] storage_options: Optional[dict] = arg["storage_options"] - refresh_rate = 0.05 # 20 progress updates per sec batch_size = config.DEFAULT_MAX_BATCH_SIZE if shard._indices is not None: @@ -1458,7 +1457,7 @@ def _save_to_disk_single(arg): for pa_table in table_iter(shard.data, batch_size=batch_size): writer.write_table(pa_table) num_examples_progress_update += len(pa_table) - if time.time() > _time + refresh_rate: + if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL: _time = time.time() yield job_id, False, num_examples_progress_update num_examples_progress_update = 0 diff --git a/src/datasets/builder.py b/src/datasets/builder.py index d336465cd5f..8af1fd953b7 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1542,7 +1542,6 @@ def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[in split_info: SplitInfo = arg["split_info"] check_duplicate_keys: bool = arg["check_duplicate_keys"] job_id: int = arg["job_id"] - refresh_rate = 0.05 # 20 progress updates per sec generator = self._generate_examples(**gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter @@ -1584,7 +1583,7 @@ def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[in example = self.info.features.encode_example(record) if self.info.features is not None else record writer.write(example, key) num_examples_progress_update += 1 - if time.time() > _time + refresh_rate: + if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL: _time = time.time() yield job_id, False, num_examples_progress_update 
num_examples_progress_update = 0 @@ -1795,7 +1794,6 @@ def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[in file_format: str = arg["file_format"] max_shard_size: int = arg["max_shard_size"] job_id: int = arg["job_id"] - refresh_rate = 0.05 # 20 progress updates per sec generator = self._generate_tables(**gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter @@ -1830,7 +1828,7 @@ def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[in ) writer.write_table(table) num_examples_progress_update += len(table) - if time.time() > _time + refresh_rate: + if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL: _time = time.time() yield job_id, False, num_examples_progress_update num_examples_progress_update = 0 diff --git a/src/datasets/config.py b/src/datasets/config.py index bbfdddae7a0..443dec643e9 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -210,3 +210,6 @@ DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 200 GLOBBED_DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 10 ARCHIVED_DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 200 + +# Progress bars +PBAR_REFRESH_TIME_INTERVAL = 0.05 # 20 progress updates per sec From c2b38fa1a4b36b66dde24e858a39f8dc3748f491 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 14 Dec 2022 18:11:48 +0100 Subject: [PATCH 23/25] fix docstrings --- src/datasets/arrow_dataset.py | 9 +++------ src/datasets/dataset_dict.py | 16 +++++++--------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 7eb59643f78..a9f05778f2a 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1294,8 +1294,7 @@ def save_to_disk( The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"50MB"`). num_shards (`int`, *optional*): - Number of shards to write. - Default to the same value as `num_proc` if specified. + Number of shards to write. By default the number of shards depends on `max_shard_size`. num_proc (`int`, *optional*): @@ -4679,8 +4678,7 @@ def _push_parquet_shards_to_hub( The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a a unit (like `"5MB"`). num_shards (`int`, *optional*): - Number of shards to write. - Default to the same value as `num_proc` if specified. + Number of shards to write. By default the number of shards depends on `max_shard_size`. embed_external_files (`bool`, default ``True``): @@ -4886,8 +4884,7 @@ def push_to_hub( max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). - num_shards (`int`, *optional*): Number of shards to write. - Default to the same value as `num_proc` if specified. + num_shards (`int`, *optional*): Number of shards to write. By default the number of shards depends on `max_shard_size`. 
shard_size (`int`, *optional*): diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 3f928be8304..0881ab5b706 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1102,7 +1102,7 @@ def save_to_disk( dataset_dict_path: PathLike, fs="deprecated", max_shard_size: Optional[Union[str, int]] = None, - num_shards: Optional[Union[int, Dict[str, int]]] = None, + num_shards: Optional[Dict[str, int]] = None, num_proc: Optional[int] = None, storage_options: Optional[dict] = None, ): @@ -1134,9 +1134,8 @@ def save_to_disk( The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"50MB"`). num_shards (`Dict[str, int]`, *optional*): - Number of shards to write. + Number of shards to write. By default the number of shards depends on `max_shard_size`. You need to provide the number of shards for each dataset in the dataset dictionary. - Default to the same value as `num_proc` if specified. Use a dictionary to define a different num_shards for each split. @@ -1460,8 +1459,8 @@ def push_to_hub( token: Optional[str] = None, branch: Optional[None] = None, max_shard_size: Optional[Union[int, str]] = None, - num_shards: Optional[int] = None, - shard_size: Optional[Union[int, Dict[str, int]]] = "deprecated", + num_shards: Optional[Dict[str, int]] = None, + shard_size: Optional[Union[int, str]] = "deprecated", embed_external_files: bool = True, ): """Pushes the [`DatasetDict`] to the hub as a Parquet dataset. @@ -1490,13 +1489,12 @@ def push_to_hub( max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"500MB"` or `"1GB"`). - num_shards (`Union[int, Dict[str, int]]`, *optional*): - Number of shards to write. - Default to the same value as `num_proc` if specified. + num_shards (`Dict[str, int]`, *optional*): + Number of shards to write. By default the number of shards depends on `max_shard_size`. Use a dictionary to define a different num_shards for each split. 
- shard_size (`int`, *optional*): + shard_size (`int` or `str`, *optional*): From ce667323eda7a46180111229f29846b3e25a1d4c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 14 Dec 2022 18:30:06 +0100 Subject: [PATCH 24/25] use kwargs_iterable in iflatmap_unordered --- src/datasets/arrow_dataset.py | 16 +++++++-------- src/datasets/builder.py | 37 +++++++++++++++++----------------- src/datasets/utils/py_utils.py | 14 +++++++------ tests/test_py_utils.py | 12 ++++++++--- 4 files changed, 43 insertions(+), 36 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index a9f05778f2a..967f07eced0 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1390,7 +1390,7 @@ def save_to_disk( leave=False, desc=f"Saving the dataset ({shards_done}/{num_shards} shards)", ) - args_per_job = ( + kwargs_per_job = ( { "job_id": shard_idx, "shard": dataset.shard(num_shards=num_shards, index=shard_idx, contiguous=True), @@ -1403,7 +1403,9 @@ def save_to_disk( shard_sizes = [None] * num_shards if num_proc > 1: with Pool(num_proc) as pool: - for job_id, done, content in iflatmap_unordered(pool, Dataset._save_to_disk_single, args_per_job): + for job_id, done, content in iflatmap_unordered( + pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job + ): if done: shards_done += 1 pbar.set_description(f"Saving the dataset ({shards_done}/{num_shards} shards)") @@ -1412,8 +1414,8 @@ def save_to_disk( else: pbar.update(content) else: - for args in args_per_job: - for job_id, done, content in Dataset._save_to_disk_single(args): + for kwargs in kwargs_per_job: + for job_id, done, content in Dataset._save_to_disk_single(**kwargs): if done: shards_done += 1 pbar.set_description(f"Saving the dataset ({shards_done}/{num_shards} shards)") @@ -1431,11 +1433,7 @@ def save_to_disk( json.dump(sorted_keys_dataset_info, dataset_info_file, indent=2) @staticmethod - def _save_to_disk_single(arg): - job_id: Dataset = arg["job_id"] - shard: Dataset = arg["shard"] - fpath: str = arg["fpath"] - storage_options: Optional[dict] = arg["storage_options"] + def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_options: Optional[dict]): batch_size = config.DEFAULT_MAX_BATCH_SIZE if shard._indices is not None: diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 8af1fd953b7..ea509eca0b7 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1447,7 +1447,7 @@ def _prepare_split( gen_kwargs = split_generator.gen_kwargs job_id = 0 for job_id, done, content in self._prepare_split_single( - {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} + gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args ): if done: result = content @@ -1459,13 +1459,13 @@ def _prepare_split( [item] for item in result ] else: - args_per_job = [ + kwargs_per_job = [ {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} for job_id, gen_kwargs in enumerate( _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc) ) ] - num_jobs = len(args_per_job) + num_jobs = len(kwargs_per_job) examples_per_job = [None] * num_jobs bytes_per_job = [None] * num_jobs @@ -1474,7 +1474,9 @@ def _prepare_split( shard_lengths_per_job = [None] * num_jobs with Pool(num_proc) as pool: - for job_id, done, content in iflatmap_unordered(pool, self._prepare_split_single, args_per_job): + for job_id, done, content in iflatmap_unordered( + pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job + ): if done: # the content 
is the result of the job ( @@ -1534,14 +1536,16 @@ def _rename_shard(shard_and_job: Tuple[int]): if self.info.features is None: self.info.features = features - def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[int, tuple]]]: - gen_kwargs: dict = arg["gen_kwargs"] - fpath: str = arg["fpath"] - file_format: str = arg["file_format"] - max_shard_size: int = arg["max_shard_size"] - split_info: SplitInfo = arg["split_info"] - check_duplicate_keys: bool = arg["check_duplicate_keys"] - job_id: int = arg["job_id"] + def _prepare_split_single( + self, + gen_kwargs: dict, + fpath: str, + file_format: str, + max_shard_size: int, + split_info: SplitInfo, + check_duplicate_keys: bool, + job_id: int, + ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]: generator = self._generate_examples(**gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter @@ -1788,12 +1792,9 @@ def _rename_shard(shard_id_and_job: Tuple[int]): if self.info.features is None: self.info.features = features - def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[int, tuple]]]: - gen_kwargs: dict = arg["gen_kwargs"] - fpath: str = arg["fpath"] - file_format: str = arg["file_format"] - max_shard_size: int = arg["max_shard_size"] - job_id: int = arg["job_id"] + def _prepare_split_single( + self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int + ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]: generator = self._generate_tables(**gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index 0c628cb6d9b..992ed5c0092 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -1335,25 +1335,27 @@ def copyfunc(func): return result -X = TypeVar("X") Y = TypeVar("Y") -def _write_generator_to_queue(queue: queue.Queue, func: Callable[[X], Iterable[Y]], arg: X) -> int: - for i, result in enumerate(func(arg)): +def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y]], kwargs: dict) -> int: + for i, result in enumerate(func(**kwargs)): queue.put(result) return i def iflatmap_unordered( pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool], - func: Callable[[X], Iterable[Y]], - iterable: Iterable[X], + func: Callable[..., Iterable[Y]], + *, + kwargs_iterable: Iterable[dict], ) -> Iterable[Y]: manager_cls = Manager if isinstance(pool, multiprocessing.pool.Pool) else multiprocess.Manager with manager_cls() as manager: queue = manager.Queue() - async_results = [pool.apply_async(_write_generator_to_queue, (queue, func, arg)) for arg in iterable] + async_results = [ + pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable + ] while True: try: yield queue.get(timeout=0.05) diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py index f0b391fbf5c..57091b22bfd 100644 --- a/tests/test_py_utils.py +++ b/tests/test_py_utils.py @@ -240,6 +240,10 @@ def test_asdict(): asdict([1, A(x=10, y="foo")]) +def _split_text(text: str): + return text.split() + + def _2seconds_generator_of_2items_with_timing(content): yield (time.time(), content) time.sleep(2) @@ -249,14 +253,14 @@ def _2seconds_generator_of_2items_with_timing(content): def test_iflatmap_unordered(): with Pool(2) as pool: - out = list(iflatmap_unordered(pool, str.split, ["hello there"] * 10)) + out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] 
* 10)) assert out.count("hello") == 10 assert out.count("there") == 10 assert len(out) == 20 # check multiprocess from pathos (uses dill for pickling) with multiprocess.Pool(2) as pool: - out = list(iflatmap_unordered(pool, str.split, ["hello there"] * 10)) + out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10)) assert out.count("hello") == 10 assert out.count("there") == 10 assert len(out) == 20 @@ -264,7 +268,9 @@ def test_iflatmap_unordered(): # check that we get items as fast as possible with Pool(2) as pool: out = [] - for yield_time, content in iflatmap_unordered(pool, _2seconds_generator_of_2items_with_timing, ["a", "b"]): + for yield_time, content in iflatmap_unordered( + pool, _2seconds_generator_of_2items_with_timing, kwargs_iterable=[{"content": "a"}, {"content": "b"}] + ): assert yield_time < time.time() + 0.1, "we should each item directly after it was yielded" out.append(content) assert out.count("a") == 2 From 44e515689f399c0dcdfca2a524fc647bd73dbe2a Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 14 Dec 2022 18:44:37 +0100 Subject: [PATCH 25/25] fix tests --- src/datasets/builder.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index ea509eca0b7..325f0247113 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1697,7 +1697,6 @@ def _prepare_split( "fpath": fpath, "file_format": file_format, "max_shard_size": max_shard_size, - "split_info": split_info, } if num_proc is None or num_proc == 1: @@ -1705,7 +1704,7 @@ def _prepare_split( gen_kwargs = split_generator.gen_kwargs job_id = 0 for job_id, done, content in self._prepare_split_single( - {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} + gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args ): if done: result = content @@ -1717,13 +1716,13 @@ def _prepare_split( [item] for item in result ] else: - args_per_job = [ + kwargs_per_job = [ {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} for job_id, gen_kwargs in enumerate( _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc) ) ] - num_jobs = len(args_per_job) + num_jobs = len(kwargs_per_job) examples_per_job = [None] * num_jobs bytes_per_job = [None] * num_jobs @@ -1732,7 +1731,9 @@ def _prepare_split( shard_lengths_per_job = [None] * num_jobs with Pool(num_proc) as pool: - for job_id, done, content in iflatmap_unordered(pool, self._prepare_split_single, args_per_job): + for job_id, done, content in iflatmap_unordered( + pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job + ): if done: # the content is the result of the job (
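For readers unfamiliar with the pattern the last two patches refactor, here is a toy, dependency-free version of the `iflatmap_unordered` idea; the names (`toy_iflatmap_unordered`, `_drain_generator_into_queue`) and the exact shutdown logic are simplified for illustration, so this is a sketch of the mechanism rather than the library's implementation. Each job drains its generator into a shared queue, so the parent can react (update a progress bar, record shard sizes) as soon as any job yields something, in whatever order results arrive.

```python
import multiprocessing
import multiprocessing.pool
from queue import Empty
from typing import Callable, Iterable, Iterator, TypeVar

Y = TypeVar("Y")

def _drain_generator_into_queue(queue, func: Callable[..., Iterable[Y]], kwargs: dict) -> int:
    # Run one job and push every yielded item onto the shared queue.
    i = -1
    for i, item in enumerate(func(**kwargs)):
        queue.put(item)
    return i  # index of the last yielded item

def toy_iflatmap_unordered(pool: multiprocessing.pool.Pool, func, *, kwargs_iterable: Iterable[dict]) -> Iterator[Y]:
    with multiprocessing.Manager() as manager:
        queue = manager.Queue()
        async_results = [
            pool.apply_async(_drain_generator_into_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable
        ]
        while True:
            try:
                yield queue.get(timeout=0.05)
            except Empty:
                # Stop once every job has finished and nothing is left to drain.
                if all(result.ready() for result in async_results) and queue.empty():
                    break
        for result in async_results:
            result.get()  # re-raise worker exceptions, if any

def _split_text(text: str):
    yield from text.split()

if __name__ == "__main__":
    with multiprocessing.Pool(2) as pool:
        out = list(toy_iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 4))
    print(sorted(out))  # ['hello', 'hello', 'hello', 'hello', 'there', 'there', 'there', 'there']
```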