Remove ProcessPool subclass
pappacena committed Jul 5, 2023
1 parent 285603d commit e6a19f3
Showing 4 changed files with 10 additions and 33 deletions.
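This commit deletes the ProcessPool subclass from src/datasets/utils/py_utils.py and switches every call site to a plain multiprocess.Pool. The pattern all call sites now share is a pool driving iflatmap_unordered, which fans job kwargs out to workers and yields their outputs in completion order. Below is a minimal sketch of that pattern; the worker function is a hypothetical stand-in for the real Dataset._save_to_disk_single and self._prepare_split_single workers.

    from multiprocess import Pool  # dill-based drop-in for multiprocessing

    from datasets.utils.py_utils import iflatmap_unordered

    def _square_job(shard):
        # Hypothetical worker: each job yields results incrementally,
        # and iflatmap_unordered flattens them across all jobs.
        yield shard * shard

    if __name__ == "__main__":
        kwargs_per_job = [{"shard": i} for i in range(8)]
        with Pool(4) as pool:
            for result in iflatmap_unordered(
                pool, _square_job, kwargs_iterable=kwargs_per_job
            ):
                print(result)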
7 changes: 4 additions & 3 deletions src/datasets/arrow_dataset.py
@@ -58,6 +58,7 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 from huggingface_hub import HfApi, HfFolder
+from multiprocess import Pool
 from requests import HTTPError
 
 from . import config
@@ -112,7 +113,7 @@
 from .utils.hub import hf_hub_url
 from .utils.info_utils import is_small_dataset
 from .utils.metadata import DatasetMetadata
-from .utils.py_utils import Literal, ProcessPool, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
+from .utils.py_utils import Literal, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
 from .utils.stratify import stratified_shuffle_split_generate_indices
 from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
 from .utils.typing import ListLike, PathLike
@@ -1504,7 +1505,7 @@ def save_to_disk(
         shard_lengths = [None] * num_shards
         shard_sizes = [None] * num_shards
         if num_proc > 1:
-            with ProcessPool(num_proc) as pool:
+            with Pool(num_proc) as pool:
                 with pbar:
                     for job_id, done, content in iflatmap_unordered(
                         pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job
@@ -3166,7 +3167,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
             logger.info(
                 f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
             )
-            with ProcessPool(len(kwargs_per_job)) as pool:
+            with Pool(len(kwargs_per_job)) as pool:
                 os.environ = prev_env
                 logger.info(f"Spawning {num_proc} processes")
                 with logging.tqdm(
6 changes: 3 additions & 3 deletions src/datasets/builder.py
@@ -33,6 +33,7 @@
 
 import fsspec
 import pyarrow as pa
+from multiprocess import Pool
 from tqdm.contrib.concurrent import thread_map
 
 from . import config, utils
@@ -68,7 +69,6 @@
 from .utils.filelock import FileLock
 from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
 from .utils.py_utils import (
-    ProcessPool,
     classproperty,
     convert_file_size_to_int,
     has_sufficient_disk_space,
@@ -1543,7 +1543,7 @@ def _prepare_split(
         shards_per_job = [None] * num_jobs
         shard_lengths_per_job = [None] * num_jobs
 
-        with ProcessPool(num_proc) as pool:
+        with Pool(num_proc) as pool:
             with pbar:
                 for job_id, done, content in iflatmap_unordered(
                     pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
@@ -1802,7 +1802,7 @@ def _prepare_split(
         shards_per_job = [None] * num_jobs
         shard_lengths_per_job = [None] * num_jobs
 
-        with ProcessPool(num_proc) as pool:
+        with Pool(num_proc) as pool:
             with pbar:
                 for job_id, done, content in iflatmap_unordered(
                     pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
26 changes: 1 addition & 25 deletions src/datasets/utils/py_utils.py
@@ -1330,33 +1330,9 @@ def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y
     return i
 
 
-class ProcessPool(multiprocess.pool.Pool):
-    """
-    A multiprocess.pool.Pool implementation that keeps track of child process' PIDs,
-    and can detect if a child process has been restarted.
-    """
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self._last_pool_pids = self._get_current_pool_pids()
-        self._has_restarted_subprocess = False
-
-    def _get_current_pool_pids(self) -> Set[int]:
-        return {f.pid for f in self._pool}
-
-    def has_restarted_subprocess(self) -> bool:
-        if self._has_restarted_subprocess:
-            # If the pool ever restarted a subprocess,
-            # we don't check the PIDs again.
-            return True
-        current_pids = self._get_current_pool_pids()
-        self._has_restarted_subprocess = current_pids != self._last_pool_pids
-        self._last_pool_pids = current_pids
-        return self._has_restarted_subprocess
-
-
 def iflatmap_unordered(
-    pool: ProcessPool,
+    pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
     func: Callable[..., Iterable[Y]],
     *,
     kwargs_iterable: Iterable[dict],
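The deleted class existed only to track worker PIDs and report whether the pool had ever restarted a child process. If that check is needed without the subclass, the same comparison can be made from outside the pool. A minimal sketch, assuming access to the pool's private _pool worker list, the same implementation detail the removed _get_current_pool_pids read:

    from multiprocess import Pool

    def pool_pids(pool):
        # Snapshot the current worker PIDs via the private _pool list,
        # exactly as the removed _get_current_pool_pids did.
        return {worker.pid for worker in pool._pool}

    if __name__ == "__main__":
        with Pool(2) as pool:
            pids_before = pool_pids(pool)
            pool.map(abs, range(1000))  # any workload
            has_restarted = pool_pids(pool) != pids_before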
4 changes: 2 additions & 2 deletions tests/test_arrow_dataset.py
@@ -1352,9 +1352,9 @@ def test_map_caching(self, in_memory):
         with self._caplog.at_level(WARNING):
             with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                 with patch(
-                    "datasets.arrow_dataset.ProcessPool",
+                    "datasets.arrow_dataset.Pool",
                     new_callable=PickableMagicMock,
-                    side_effect=datasets.arrow_dataset.ProcessPool,
+                    side_effect=datasets.arrow_dataset.Pool,
                 ) as mock_pool:
                     with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
                         dset_test1_data_files = list(dset_test1.cache_files)
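The patch target changes because unittest.mock.patch replaces a name where it is looked up, not where it is defined: arrow_dataset now binds Pool at import time via "from multiprocess import Pool", so the test must patch datasets.arrow_dataset.Pool rather than multiprocess.Pool. A short illustration of that lookup rule, assuming the datasets package at this commit is importable:

    from unittest.mock import patch

    import datasets.arrow_dataset

    # Patching multiprocess.Pool would leave arrow_dataset's own reference
    # untouched; patching the name inside the using module replaces it.
    with patch("datasets.arrow_dataset.Pool") as mock_pool:
        assert datasets.arrow_dataset.Pool is mock_pool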
