From ce667323eda7a46180111229f29846b3e25a1d4c Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Wed, 14 Dec 2022 18:30:06 +0100
Subject: [PATCH] use kwargs_iterable in iflatmap_unordered

---
 src/datasets/arrow_dataset.py  | 16 +++++++---------
 src/datasets/builder.py        | 37 +++++++++++++++++++-----------------
 src/datasets/utils/py_utils.py | 14 ++++++++------
 tests/test_py_utils.py         | 12 +++++++++---
 4 files changed, 43 insertions(+), 36 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index a9f05778f2a..967f07eced0 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1390,7 +1390,7 @@ def save_to_disk(
             leave=False,
             desc=f"Saving the dataset ({shards_done}/{num_shards} shards)",
         )
-        args_per_job = (
+        kwargs_per_job = (
             {
                 "job_id": shard_idx,
                 "shard": dataset.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
@@ -1403,7 +1403,9 @@
         shard_sizes = [None] * num_shards
         if num_proc > 1:
             with Pool(num_proc) as pool:
-                for job_id, done, content in iflatmap_unordered(pool, Dataset._save_to_disk_single, args_per_job):
+                for job_id, done, content in iflatmap_unordered(
+                    pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job
+                ):
                     if done:
                         shards_done += 1
                         pbar.set_description(f"Saving the dataset ({shards_done}/{num_shards} shards)")
@@ -1412,8 +1414,8 @@
                     else:
                         pbar.update(content)
         else:
-            for args in args_per_job:
-                for job_id, done, content in Dataset._save_to_disk_single(args):
+            for kwargs in kwargs_per_job:
+                for job_id, done, content in Dataset._save_to_disk_single(**kwargs):
                     if done:
                         shards_done += 1
                         pbar.set_description(f"Saving the dataset ({shards_done}/{num_shards} shards)")
@@ -1431,11 +1433,7 @@
             json.dump(sorted_keys_dataset_info, dataset_info_file, indent=2)
 
     @staticmethod
-    def _save_to_disk_single(arg):
-        job_id: Dataset = arg["job_id"]
-        shard: Dataset = arg["shard"]
-        fpath: str = arg["fpath"]
-        storage_options: Optional[dict] = arg["storage_options"]
+    def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_options: Optional[dict]):
         batch_size = config.DEFAULT_MAX_BATCH_SIZE
 
         if shard._indices is not None:
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 8af1fd953b7..ea509eca0b7 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -1447,7 +1447,7 @@ def _prepare_split(
             gen_kwargs = split_generator.gen_kwargs
             job_id = 0
             for job_id, done, content in self._prepare_split_single(
-                {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
+                gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args
             ):
                 if done:
                     result = content
@@ -1459,13 +1459,13 @@
                 [item] for item in result
             ]
         else:
-            args_per_job = [
+            kwargs_per_job = [
                 {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
                 for job_id, gen_kwargs in enumerate(
                     _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc)
                 )
             ]
-            num_jobs = len(args_per_job)
+            num_jobs = len(kwargs_per_job)
 
             examples_per_job = [None] * num_jobs
             bytes_per_job = [None] * num_jobs
@@ -1474,7 +1474,9 @@
             shard_lengths_per_job = [None] * num_jobs
 
             with Pool(num_proc) as pool:
-                for job_id, done, content in iflatmap_unordered(pool, self._prepare_split_single, args_per_job):
+                for job_id, done, content in iflatmap_unordered(
+                    pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
+                ):
                     if done:
                         # the content is the result of the job
                         (
@@ -1534,14 +1536,16 @@ def _rename_shard(shard_and_job: Tuple[int]):
         if self.info.features is None:
             self.info.features = features
 
-    def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
-        gen_kwargs: dict = arg["gen_kwargs"]
-        fpath: str = arg["fpath"]
-        file_format: str = arg["file_format"]
-        max_shard_size: int = arg["max_shard_size"]
-        split_info: SplitInfo = arg["split_info"]
-        check_duplicate_keys: bool = arg["check_duplicate_keys"]
-        job_id: int = arg["job_id"]
+    def _prepare_split_single(
+        self,
+        gen_kwargs: dict,
+        fpath: str,
+        file_format: str,
+        max_shard_size: int,
+        split_info: SplitInfo,
+        check_duplicate_keys: bool,
+        job_id: int,
+    ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
         generator = self._generate_examples(**gen_kwargs)
         writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
@@ -1788,12 +1792,9 @@ def _rename_shard(shard_id_and_job: Tuple[int]):
         if self.info.features is None:
             self.info.features = features
 
-    def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
-        gen_kwargs: dict = arg["gen_kwargs"]
-        fpath: str = arg["fpath"]
-        file_format: str = arg["file_format"]
-        max_shard_size: int = arg["max_shard_size"]
-        job_id: int = arg["job_id"]
+    def _prepare_split_single(
+        self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int
+    ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
         generator = self._generate_tables(**gen_kwargs)
         writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index 0c628cb6d9b..992ed5c0092 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -1335,25 +1335,27 @@ def copyfunc(func):
     return result
 
 
-X = TypeVar("X")
 Y = TypeVar("Y")
 
 
-def _write_generator_to_queue(queue: queue.Queue, func: Callable[[X], Iterable[Y]], arg: X) -> int:
-    for i, result in enumerate(func(arg)):
+def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y]], kwargs: dict) -> int:
+    for i, result in enumerate(func(**kwargs)):
         queue.put(result)
     return i
 
 
 def iflatmap_unordered(
     pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
-    func: Callable[[X], Iterable[Y]],
-    iterable: Iterable[X],
+    func: Callable[..., Iterable[Y]],
+    *,
+    kwargs_iterable: Iterable[dict],
 ) -> Iterable[Y]:
     manager_cls = Manager if isinstance(pool, multiprocessing.pool.Pool) else multiprocess.Manager
     with manager_cls() as manager:
         queue = manager.Queue()
-        async_results = [pool.apply_async(_write_generator_to_queue, (queue, func, arg)) for arg in iterable]
+        async_results = [
+            pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable
+        ]
         while True:
             try:
                 yield queue.get(timeout=0.05)
diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py
index f0b391fbf5c..57091b22bfd 100644
--- a/tests/test_py_utils.py
+++ b/tests/test_py_utils.py
@@ -240,6 +240,10 @@ def test_asdict():
         asdict([1, A(x=10, y="foo")])
 
 
+def _split_text(text: str):
+    return text.split()
+
+
 def _2seconds_generator_of_2items_with_timing(content):
     yield (time.time(), content)
     time.sleep(2)
@@ -249,14 +253,14 @@
 def test_iflatmap_unordered():
 
     with Pool(2) as pool:
-        out = list(iflatmap_unordered(pool, str.split, ["hello there"] * 10))
+        out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10))
         assert out.count("hello") == 10
         assert out.count("there") == 10
         assert len(out) == 20
 
     # check multiprocess from pathos (uses dill for pickling)
     with multiprocess.Pool(2) as pool:
-        out = list(iflatmap_unordered(pool, str.split, ["hello there"] * 10))
+        out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10))
         assert out.count("hello") == 10
         assert out.count("there") == 10
         assert len(out) == 20
@@ -264,7 +268,9 @@
     # check that we get items as fast as possible
     with Pool(2) as pool:
         out = []
-        for yield_time, content in iflatmap_unordered(pool, _2seconds_generator_of_2items_with_timing, ["a", "b"]):
+        for yield_time, content in iflatmap_unordered(
+            pool, _2seconds_generator_of_2items_with_timing, kwargs_iterable=[{"content": "a"}, {"content": "b"}]
+        ):
             assert yield_time < time.time() + 0.1, "we should each item directly after it was yielded"
             out.append(content)
         assert out.count("a") == 2
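
Illustrative usage sketch (not part of the patch above): after this change, iflatmap_unordered takes a keyword-only kwargs_iterable of dicts, and each dict is unpacked as keyword arguments of func inside a worker process, as the updated tests show. The helper repeat_words below is a made-up example function, not something from the datasets codebase.

# A minimal sketch of the new keyword-only API introduced by this patch.
from multiprocessing import Pool

from datasets.utils.py_utils import iflatmap_unordered


def repeat_words(text: str, times: int):
    # Hypothetical example job: each call yields several items; iflatmap_unordered
    # flattens the per-job iterables and yields items in whatever order workers produce them.
    for word in text.split():
        yield word * times


if __name__ == "__main__":
    # One dict of keyword arguments per job, unpacked as repeat_words(**kwargs) in a worker.
    kwargs_per_job = [{"text": "hello there", "times": n} for n in range(1, 4)]
    with Pool(2) as pool:
        for item in iflatmap_unordered(pool, repeat_words, kwargs_iterable=kwargs_per_job):
            print(item)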