
Commit

Merge branch 'tfp/avoid-stuck-map-operation' of https://github.com/pappacena/datasets into tfp/avoid-stuck-map-operation
pappacena committed Jul 5, 2023
2 parents 6bb15c1 + 45a5ee0 commit 285603d
Showing 3 changed files with 8 additions and 9 deletions.
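
The change is a single rename applied consistently across the three files: the bare Pool from the multiprocess package is replaced by a ProcessPool helper imported from datasets.utils.py_utils, in line with the branch name tfp/avoid-stuck-map-operation. ProcessPool's definition is not part of this commit, so the following is only a sketch under assumptions: a minimal ProcessPool modeled as a multiprocess Pool subclass that snapshots its workers' PIDs, so callers can notice that a worker died and was respawned instead of blocking forever on a task that will never finish.

```python
# Hypothetical sketch only: the real ProcessPool lives in
# datasets.utils.py_utils on this branch and is not shown in this commit.
from multiprocess import pool as mp_pool


class ProcessPool(mp_pool.Pool):
    """A multiprocess Pool that remembers the PIDs of its original workers.

    When a worker dies abruptly (e.g. killed by the OOM killer), the pool
    silently respawns it and the task it was running is lost, so anyone
    waiting on that task hangs forever. Comparing the current worker PIDs
    to the originals lets a caller detect the crash and raise instead.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Snapshot the workers spawned at construction time.
        self._original_pids = {worker.pid for worker in self._pool}

    def workers_changed(self) -> bool:
        # True once any original worker has died and been replaced.
        return {worker.pid for worker in self._pool} != self._original_pids
```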
7 changes: 3 additions & 4 deletions src/datasets/arrow_dataset.py
@@ -58,7 +58,6 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 from huggingface_hub import HfApi, HfFolder
-from multiprocess import Pool
 from requests import HTTPError
 
 from . import config
@@ -113,7 +112,7 @@
 from .utils.hub import hf_hub_url
 from .utils.info_utils import is_small_dataset
 from .utils.metadata import DatasetMetadata
-from .utils.py_utils import Literal, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
+from .utils.py_utils import Literal, ProcessPool, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
 from .utils.stratify import stratified_shuffle_split_generate_indices
 from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
 from .utils.typing import ListLike, PathLike
@@ -1505,7 +1504,7 @@ def save_to_disk(
         shard_lengths = [None] * num_shards
         shard_sizes = [None] * num_shards
         if num_proc > 1:
-            with Pool(num_proc) as pool:
+            with ProcessPool(num_proc) as pool:
                 with pbar:
                     for job_id, done, content in iflatmap_unordered(
                         pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job
@@ -3167,7 +3166,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
             logger.info(
                 f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
             )
-            with Pool(len(kwargs_per_job)) as pool:
+            with ProcessPool(len(kwargs_per_job)) as pool:
                 os.environ = prev_env
                 logger.info(f"Spawning {num_proc} processes")
                 with logging.tqdm(
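
Both pool call sites above consume (job_id, done, content) tuples from iflatmap_unordered, which dispatches one job per set of kwargs and streams results back as workers produce them. That helper is defined in datasets.utils.py_utils and is not part of this diff; the sketch below shows its assumed shape (a queue-based streaming helper, not the exact datasets implementation).

```python
# Simplified sketch of an iflatmap_unordered-style helper (assumed shape;
# the real one lives in datasets.utils.py_utils).
from queue import Empty

from multiprocess import Manager


def _write_generator_to_queue(queue, func, kwargs):
    # Worker side: stream every item the job yields into the shared queue.
    for item in func(**kwargs):
        queue.put(item)
    return True


def iflatmap_unordered(pool, func, *, kwargs_iterable):
    with Manager() as manager:
        queue = manager.Queue()
        async_results = [
            pool.apply_async(_write_generator_to_queue, (queue, func, kwargs))
            for kwargs in kwargs_iterable
        ]
        # Yield items as they arrive, in arbitrary order across jobs.
        while True:
            try:
                yield queue.get(timeout=0.05)
            except Empty:
                if all(r.ready() for r in async_results) and queue.empty():
                    break
        # Re-raise any exception a worker hit while producing items.
        for r in async_results:
            r.get(timeout=0.05)
```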
6 changes: 3 additions & 3 deletions src/datasets/builder.py
@@ -33,7 +33,6 @@
 
 import fsspec
 import pyarrow as pa
-from multiprocess import Pool
 from tqdm.contrib.concurrent import thread_map
 
 from . import config, utils
@@ -69,6 +68,7 @@
 from .utils.filelock import FileLock
 from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
 from .utils.py_utils import (
+    ProcessPool,
     classproperty,
     convert_file_size_to_int,
     has_sufficient_disk_space,
@@ -1543,7 +1543,7 @@ def _prepare_split(
         shards_per_job = [None] * num_jobs
         shard_lengths_per_job = [None] * num_jobs
 
-        with Pool(num_proc) as pool:
+        with ProcessPool(num_proc) as pool:
             with pbar:
                 for job_id, done, content in iflatmap_unordered(
                     pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
@@ -1802,7 +1802,7 @@ def _prepare_split(
         shards_per_job = [None] * num_jobs
         shard_lengths_per_job = [None] * num_jobs
 
-        with Pool(num_proc) as pool:
+        with ProcessPool(num_proc) as pool:
             with pbar:
                 for job_id, done, content in iflatmap_unordered(
                     pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
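
The builder changes mirror the two in arrow_dataset.py: the same pool swap at both _prepare_split variants. To see why the swap matters, here is a self-contained toy reproduction of the failure mode (stdlib multiprocessing rather than the multiprocess fork, which behaves the same way for this purpose): with a plain Pool, a worker killed mid-task leaves the parent blocked forever, while a PID check turns the hang into an immediate error.

```python
# Toy reproduction of the stuck-map failure mode; not datasets code.
import os
from multiprocessing import Pool


def work(i):
    if i == 3:
        os._exit(1)  # simulate a worker killed by the OS (e.g. OOM)
    return i * i


if __name__ == "__main__":
    with Pool(2) as pool:
        # Private attribute, but this is exactly what PID tracking needs.
        initial_pids = {w.pid for w in pool._pool}
        results = [pool.apply_async(work, (i,)) for i in range(6)]
        for r in results:
            while not r.ready():
                r.wait(timeout=0.05)
                # A bare r.get() here would block forever on the lost task;
                # checking worker PIDs detects the crash instead.
                if {w.pid for w in pool._pool} != initial_pids:
                    raise RuntimeError("A worker died abruptly; failing fast instead of hanging.")
            print(r.get())
```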
4 changes: 2 additions & 2 deletions tests/test_arrow_dataset.py
@@ -1352,9 +1352,9 @@ def test_map_caching(self, in_memory):
         with self._caplog.at_level(WARNING):
             with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                 with patch(
-                    "datasets.arrow_dataset.Pool",
+                    "datasets.arrow_dataset.ProcessPool",
                     new_callable=PickableMagicMock,
-                    side_effect=datasets.arrow_dataset.Pool,
+                    side_effect=datasets.arrow_dataset.ProcessPool,
                ) as mock_pool:
                     with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
                         dset_test1_data_files = list(dset_test1.cache_files)
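
The test change only renames the patch target to match the new import. The pattern itself is worth noting: patching with new_callable plus side_effect yields a mock that still constructs a real pool (so map actually runs) while recording the call for assertions. PickableMagicMock is a helper from the datasets test suite, presumably a MagicMock subclass that survives pickling; a generic equivalent with a plain MagicMock:

```python
# Generic version of the patch used in test_map_caching (plain MagicMock
# instead of the suite's PickableMagicMock helper).
from unittest.mock import MagicMock, patch

import datasets.arrow_dataset

with patch(
    "datasets.arrow_dataset.ProcessPool",
    new_callable=MagicMock,
    side_effect=datasets.arrow_dataset.ProcessPool,
) as mock_pool:
    # Code under test that spawns workers goes here; side_effect still
    # creates the real pool, and the mock records the call.
    ...

# After e.g. dset.map(..., num_proc=2) inside the block, the call count
# can be asserted with: mock_pool.assert_called()
```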
