From e6a19f3bd62a6cdb2aa94e9824716afe65dc3472 Mon Sep 17 00:00:00 2001
From: Thiago F Pappacena
Date: Wed, 5 Jul 2023 19:03:48 -0300
Subject: [PATCH] Remove ProcessPool subclass

---
 src/datasets/arrow_dataset.py  |  7 ++++---
 src/datasets/builder.py        |  6 +++---
 src/datasets/utils/py_utils.py | 26 +-------------------------
 tests/test_arrow_dataset.py    |  4 ++--
 4 files changed, 10 insertions(+), 33 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 63856b65b332..073e47aefca3 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -58,6 +58,7 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 from huggingface_hub import HfApi, HfFolder
+from multiprocess import Pool
 from requests import HTTPError
 
 from . import config
@@ -112,7 +113,7 @@
 from .utils.hub import hf_hub_url
 from .utils.info_utils import is_small_dataset
 from .utils.metadata import DatasetMetadata
-from .utils.py_utils import Literal, ProcessPool, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
+from .utils.py_utils import Literal, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
 from .utils.stratify import stratified_shuffle_split_generate_indices
 from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
 from .utils.typing import ListLike, PathLike
@@ -1504,7 +1505,7 @@ def save_to_disk(
         shard_lengths = [None] * num_shards
         shard_sizes = [None] * num_shards
         if num_proc > 1:
-            with ProcessPool(num_proc) as pool:
+            with Pool(num_proc) as pool:
                 with pbar:
                     for job_id, done, content in iflatmap_unordered(
                         pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job
@@ -3166,7 +3167,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
                 logger.info(
                     f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
                 )
-                with ProcessPool(len(kwargs_per_job)) as pool:
+                with Pool(len(kwargs_per_job)) as pool:
                     os.environ = prev_env
                     logger.info(f"Spawning {num_proc} processes")
                     with logging.tqdm(
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 9ed9d6760f72..389dda65e3ae 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -33,6 +33,7 @@
 
 import fsspec
 import pyarrow as pa
+from multiprocess import Pool
 from tqdm.contrib.concurrent import thread_map
 
 from . import config, utils
@@ -68,7 +69,6 @@
 from .utils.filelock import FileLock
 from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
 from .utils.py_utils import (
-    ProcessPool,
     classproperty,
     convert_file_size_to_int,
     has_sufficient_disk_space,
@@ -1543,7 +1543,7 @@ def _prepare_split(
             shards_per_job = [None] * num_jobs
             shard_lengths_per_job = [None] * num_jobs
 
-            with ProcessPool(num_proc) as pool:
+            with Pool(num_proc) as pool:
                 with pbar:
                     for job_id, done, content in iflatmap_unordered(
                         pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
@@ -1802,7 +1802,7 @@ def _prepare_split(
             shards_per_job = [None] * num_jobs
             shard_lengths_per_job = [None] * num_jobs
 
-            with ProcessPool(num_proc) as pool:
+            with Pool(num_proc) as pool:
                 with pbar:
                     for job_id, done, content in iflatmap_unordered(
                         pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index e4d463c9e32b..abac21e5c3a3 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -1330,33 +1330,9 @@ def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y
     return i
 
 
-class ProcessPool(multiprocess.pool.Pool):
-    """
-    A multiprocess.pool.Pool implementation that keeps track of child process' PIDs,
-    and can detect if a child process has been restarted.
-    """
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self._last_pool_pids = self._get_current_pool_pids()
-        self._has_restarted_subprocess = False
-
-    def _get_current_pool_pids(self) -> Set[int]:
-        return {f.pid for f in self._pool}
-
-    def has_restarted_subprocess(self) -> bool:
-        if self._has_restarted_subprocess:
-            # If the pool ever restarted a subprocess,
-            # we don't check the PIDs again.
-            return True
-        current_pids = self._get_current_pool_pids()
-        self._has_restarted_subprocess = current_pids != self._last_pool_pids
-        self._last_pool_pids = current_pids
-        return self._has_restarted_subprocess
-
 
 def iflatmap_unordered(
-    pool: ProcessPool,
+    pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
     func: Callable[..., Iterable[Y]],
     *,
     kwargs_iterable: Iterable[dict],
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 94c7fc145285..037ba655329d 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -1352,9 +1352,9 @@ def test_map_caching(self, in_memory):
         with self._caplog.at_level(WARNING):
             with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                 with patch(
-                    "datasets.arrow_dataset.ProcessPool",
+                    "datasets.arrow_dataset.Pool",
                     new_callable=PickableMagicMock,
-                    side_effect=datasets.arrow_dataset.ProcessPool,
+                    side_effect=datasets.arrow_dataset.Pool,
                 ) as mock_pool:
                     with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
                         dset_test1_data_files = list(dset_test1.cache_files)
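
Reviewer context, not part of the patch: with the ProcessPool subclass gone,
iflatmap_unordered accepts any plain pool, which is why every call site can
switch to `with Pool(...) as pool:`. Below is a self-contained sketch of that
call pattern. It is not the library's implementation (the real helper streams
results through a queue while jobs are still running, per the
_write_generator_to_queue context above, whereas this sketch collects each
job's results with imap_unordered), and the names iflatmap_unordered_sketch,
_call_with_kwargs, and _shard_worker are illustrative stand-ins. The stdlib
multiprocessing is used here so the sketch runs without the dill-based
multiprocess dependency; the patch itself imports Pool from multiprocess,
whose pools expose the same interface.

# Minimal sketch of the iflatmap_unordered call pattern on a plain Pool
# (hypothetical helper names; not the datasets implementation).
import multiprocessing.pool
from typing import Callable, Iterable, TypeVar

Y = TypeVar("Y")


def _call_with_kwargs(job):
    # Top-level function so it can be pickled and shipped to a worker process.
    func, kwargs = job
    # Materialize the worker's generator: generators cannot cross process
    # boundaries, but lists of picklable items can.
    return list(func(**kwargs))


def iflatmap_unordered_sketch(
    pool: multiprocessing.pool.Pool,
    func: Callable[..., Iterable[Y]],
    *,
    kwargs_iterable: Iterable[dict],
) -> Iterable[Y]:
    # One job per kwargs dict; flattened results are yielded in completion
    # order, not submission order.
    jobs = [(func, kwargs) for kwargs in kwargs_iterable]
    for results in pool.imap_unordered(_call_with_kwargs, jobs):
        yield from results


def _shard_worker(shard_id: int, num_items: int) -> Iterable[tuple]:
    # Stands in for a per-shard worker such as _prepare_split_single, which
    # yields (job_id, done, content)-style progress tuples.
    for i in range(num_items):
        yield (shard_id, i)


if __name__ == "__main__":
    # Same call shape as the patched call sites: a vanilla Pool is enough.
    kwargs_per_job = [{"shard_id": i, "num_items": 3} for i in range(4)]
    with multiprocessing.Pool(2) as pool:
        for shard_id, item in iflatmap_unordered_sketch(
            pool, _shard_worker, kwargs_iterable=kwargs_per_job
        ):
            print(f"shard {shard_id} produced item {item}")

Because this sketch materializes each job's output before yielding, one job's
results arrive all at once; the queue-based helper in py_utils interleaves
results across running jobs, which is what lets the call sites update their
per-shard progress bars as work completes.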