use kwargs_iterable in iflatmap_unordered
lhoestq committed Dec 14, 2022
1 parent c2b38fa commit ce66732
Showing 4 changed files with 43 additions and 36 deletions.
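In short, iflatmap_unordered now receives the per-job arguments as an iterable of keyword-argument dicts, passed through the keyword-only parameter kwargs_iterable, instead of one positional argument per job; each worker call becomes func(**kwargs). A condensed before/after sketch using the names from the diff below:

    # before: one positional argument per job
    iflatmap_unordered(pool, Dataset._save_to_disk_single, args_per_job)
    # after: one dict of keyword arguments per job
    iflatmap_unordered(pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job)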
16 changes: 7 additions & 9 deletions src/datasets/arrow_dataset.py
@@ -1390,7 +1390,7 @@ def save_to_disk(
             leave=False,
             desc=f"Saving the dataset ({shards_done}/{num_shards} shards)",
         )
-        args_per_job = (
+        kwargs_per_job = (
             {
                 "job_id": shard_idx,
                 "shard": dataset.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
@@ -1403,7 +1403,9 @@ def save_to_disk(
         shard_sizes = [None] * num_shards
         if num_proc > 1:
             with Pool(num_proc) as pool:
-                for job_id, done, content in iflatmap_unordered(pool, Dataset._save_to_disk_single, args_per_job):
+                for job_id, done, content in iflatmap_unordered(
+                    pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job
+                ):
                     if done:
                         shards_done += 1
                         pbar.set_description(f"Saving the dataset ({shards_done}/{num_shards} shards)")
@@ -1412,8 +1414,8 @@ def save_to_disk(
                     else:
                         pbar.update(content)
         else:
-            for args in args_per_job:
-                for job_id, done, content in Dataset._save_to_disk_single(args):
+            for kwargs in kwargs_per_job:
+                for job_id, done, content in Dataset._save_to_disk_single(**kwargs):
                     if done:
                         shards_done += 1
                         pbar.set_description(f"Saving the dataset ({shards_done}/{num_shards} shards)")
@@ -1431,11 +1433,7 @@ def save_to_disk(
             json.dump(sorted_keys_dataset_info, dataset_info_file, indent=2)

     @staticmethod
-    def _save_to_disk_single(arg):
-        job_id: Dataset = arg["job_id"]
-        shard: Dataset = arg["shard"]
-        fpath: str = arg["fpath"]
-        storage_options: Optional[dict] = arg["storage_options"]
+    def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_options: Optional[dict]):
         batch_size = config.DEFAULT_MAX_BATCH_SIZE

         if shard._indices is not None:
37 changes: 19 additions & 18 deletions src/datasets/builder.py
@@ -1447,7 +1447,7 @@ def _prepare_split(
             gen_kwargs = split_generator.gen_kwargs
             job_id = 0
             for job_id, done, content in self._prepare_split_single(
-                {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
+                gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args
             ):
                 if done:
                     result = content
@@ -1459,13 +1459,13 @@ def _prepare_split(
                     [item] for item in result
                 ]
         else:
-            args_per_job = [
+            kwargs_per_job = [
                 {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
                 for job_id, gen_kwargs in enumerate(
                     _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc)
                 )
             ]
-            num_jobs = len(args_per_job)
+            num_jobs = len(kwargs_per_job)

             examples_per_job = [None] * num_jobs
             bytes_per_job = [None] * num_jobs
@@ -1474,7 +1474,9 @@ def _prepare_split(
             shard_lengths_per_job = [None] * num_jobs

             with Pool(num_proc) as pool:
-                for job_id, done, content in iflatmap_unordered(pool, self._prepare_split_single, args_per_job):
+                for job_id, done, content in iflatmap_unordered(
+                    pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
+                ):
                     if done:
                         # the content is the result of the job
                         (
@@ -1534,14 +1536,16 @@ def _rename_shard(shard_and_job: Tuple[int]):
         if self.info.features is None:
             self.info.features = features

-    def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
-        gen_kwargs: dict = arg["gen_kwargs"]
-        fpath: str = arg["fpath"]
-        file_format: str = arg["file_format"]
-        max_shard_size: int = arg["max_shard_size"]
-        split_info: SplitInfo = arg["split_info"]
-        check_duplicate_keys: bool = arg["check_duplicate_keys"]
-        job_id: int = arg["job_id"]
+    def _prepare_split_single(
+        self,
+        gen_kwargs: dict,
+        fpath: str,
+        file_format: str,
+        max_shard_size: int,
+        split_info: SplitInfo,
+        check_duplicate_keys: bool,
+        job_id: int,
+    ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:

         generator = self._generate_examples(**gen_kwargs)
         writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
@@ -1788,12 +1792,9 @@ def _rename_shard(shard_id_and_job: Tuple[int]):
         if self.info.features is None:
             self.info.features = features

-    def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
-        gen_kwargs: dict = arg["gen_kwargs"]
-        fpath: str = arg["fpath"]
-        file_format: str = arg["file_format"]
-        max_shard_size: int = arg["max_shard_size"]
-        job_id: int = arg["job_id"]
+    def _prepare_split_single(
+        self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int
+    ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:

         generator = self._generate_tables(**gen_kwargs)
         writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
14 changes: 8 additions & 6 deletions src/datasets/utils/py_utils.py
@@ -1335,25 +1335,27 @@ def copyfunc(func):
     return result


 X = TypeVar("X")
 Y = TypeVar("Y")


-def _write_generator_to_queue(queue: queue.Queue, func: Callable[[X], Iterable[Y]], arg: X) -> int:
-    for i, result in enumerate(func(arg)):
+def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y]], kwargs: dict) -> int:
+    for i, result in enumerate(func(**kwargs)):
         queue.put(result)
     return i


 def iflatmap_unordered(
     pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
-    func: Callable[[X], Iterable[Y]],
-    iterable: Iterable[X],
+    func: Callable[..., Iterable[Y]],
+    *,
+    kwargs_iterable: Iterable[dict],
 ) -> Iterable[Y]:
     manager_cls = Manager if isinstance(pool, multiprocessing.pool.Pool) else multiprocess.Manager
     with manager_cls() as manager:
         queue = manager.Queue()
-        async_results = [pool.apply_async(_write_generator_to_queue, (queue, func, arg)) for arg in iterable]
+        async_results = [
+            pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable
+        ]
         while True:
             try:
                 yield queue.get(timeout=0.05)
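A minimal, self-contained sketch of the new calling convention (the gen_numbers helper and its arguments are made up for illustration, and this assumes a datasets version that includes this commit):

    from multiprocessing import Pool

    from datasets.utils.py_utils import iflatmap_unordered

    def gen_numbers(start: int, count: int):
        # each job is a generator; everything it yields is forwarded to the caller
        for i in range(count):
            yield start + i

    if __name__ == "__main__":
        with Pool(2) as pool:
            kwargs_iterable = [{"start": 0, "count": 3}, {"start": 100, "count": 3}]
            out = list(iflatmap_unordered(pool, gen_numbers, kwargs_iterable=kwargs_iterable))
        print(sorted(out))  # [0, 1, 2, 100, 101, 102] -- results arrive in arbitrary order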
12 changes: 9 additions & 3 deletions tests/test_py_utils.py
@@ -240,6 +240,10 @@ def test_asdict():
         asdict([1, A(x=10, y="foo")])


+def _split_text(text: str):
+    return text.split()
+
+
 def _2seconds_generator_of_2items_with_timing(content):
     yield (time.time(), content)
     time.sleep(2)
@@ -249,22 +253,24 @@ def _2seconds_generator_of_2items_with_timing(content):
 def test_iflatmap_unordered():

     with Pool(2) as pool:
-        out = list(iflatmap_unordered(pool, str.split, ["hello there"] * 10))
+        out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10))
         assert out.count("hello") == 10
         assert out.count("there") == 10
         assert len(out) == 20

     # check multiprocess from pathos (uses dill for pickling)
     with multiprocess.Pool(2) as pool:
-        out = list(iflatmap_unordered(pool, str.split, ["hello there"] * 10))
+        out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10))
         assert out.count("hello") == 10
         assert out.count("there") == 10
         assert len(out) == 20

     # check that we get items as fast as possible
     with Pool(2) as pool:
         out = []
-        for yield_time, content in iflatmap_unordered(pool, _2seconds_generator_of_2items_with_timing, ["a", "b"]):
+        for yield_time, content in iflatmap_unordered(
+            pool, _2seconds_generator_of_2items_with_timing, kwargs_iterable=[{"content": "a"}, {"content": "b"}]
+        ):
             assert yield_time < time.time() + 0.1, "we should each item directly after it was yielded"
             out.append(content)
         assert out.count("a") == 2
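A side note on the test change: with kwargs_iterable, each job is invoked as func(**kwargs), so the target function needs a named parameter to bind to (here "text"); str.split has no such keyword parameter, hence the small module-level _split_text wrapper. A short illustration (the commented line is what would fail):

    _split_text(text="hello there")    # what iflatmap_unordered effectively calls per job
    # str.split(text="hello there")    # TypeError: str.split has no "text" parameter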

1 comment on commit ce66732

@github-actions
Show benchmarks

PyArrow==6.0.0


Benchmark: benchmark_array_xd.json (new / old (diff))

read_batch_formatted_as_numpy after write_array2d: 0.010602 / 0.011353 (-0.000751)
read_batch_formatted_as_numpy after write_flattened_sequence: 0.005754 / 0.011008 (-0.005255)
read_batch_formatted_as_numpy after write_nested_sequence: 0.112711 / 0.038508 (0.074203)
read_batch_unformated after write_array2d: 0.040895 / 0.023109 (0.017786)
read_batch_unformated after write_flattened_sequence: 0.346258 / 0.275898 (0.070360)
read_batch_unformated after write_nested_sequence: 0.415277 / 0.323480 (0.091797)
read_col_formatted_as_numpy after write_array2d: 0.009021 / 0.007986 (0.001035)
read_col_formatted_as_numpy after write_flattened_sequence: 0.004532 / 0.004328 (0.000203)
read_col_formatted_as_numpy after write_nested_sequence: 0.088894 / 0.004250 (0.084644)
read_col_unformated after write_array2d: 0.049800 / 0.037052 (0.012747)
read_col_unformated after write_flattened_sequence: 0.371306 / 0.258489 (0.112817)
read_col_unformated after write_nested_sequence: 0.400129 / 0.293841 (0.106288)
read_formatted_as_numpy after write_array2d: 0.043744 / 0.128546 (-0.084803)
read_formatted_as_numpy after write_flattened_sequence: 0.013864 / 0.075646 (-0.061783)
read_formatted_as_numpy after write_nested_sequence: 0.383952 / 0.419271 (-0.035319)
read_unformated after write_array2d: 0.054585 / 0.043533 (0.011052)
read_unformated after write_flattened_sequence: 0.342040 / 0.255139 (0.086901)
read_unformated after write_nested_sequence: 0.369239 / 0.283200 (0.086039)
write_array2d: 0.118985 / 0.141683 (-0.022698)
write_flattened_sequence: 1.729818 / 1.452155 (0.277663)
write_nested_sequence: 1.774670 / 1.492716 (0.281954)

Benchmark: benchmark_getitem_100B.json (new / old (diff))

get_batch_of_1024_random_rows: 0.231819 / 0.018006 (0.213813)
get_batch_of_1024_rows: 0.490665 / 0.000490 (0.490176)
get_first_row: 0.001285 / 0.000200 (0.001085)
get_last_row: 0.000104 / 0.000054 (0.000050)

Benchmark: benchmark_indices_mapping.json (new / old (diff))

select: 0.032319 / 0.037411 (-0.005092)
shard: 0.126846 / 0.014526 (0.112320)
shuffle: 0.136881 / 0.176557 (-0.039676)
sort: 0.181927 / 0.737135 (-0.555209)
train_test_split: 0.143070 / 0.296338 (-0.153269)

Benchmark: benchmark_iterating.json (new / old (diff))

read 5000: 0.458465 / 0.215209 (0.243256)
read 50000: 4.597164 / 2.077655 (2.519509)
read_batch 50000 10: 2.114882 / 1.504120 (0.610762)
read_batch 50000 100: 1.872692 / 1.541195 (0.331497)
read_batch 50000 1000: 1.935434 / 1.468490 (0.466943)
read_formatted numpy 5000: 0.794945 / 4.584777 (-3.789832)
read_formatted pandas 5000: 4.448981 / 3.745712 (0.703269)
read_formatted tensorflow 5000: 2.622075 / 5.269862 (-2.647787)
read_formatted torch 5000: 1.781003 / 4.565676 (-2.784674)
read_formatted_batch numpy 5000 10: 0.103423 / 0.424275 (-0.320852)
read_formatted_batch numpy 5000 1000: 0.014736 / 0.007607 (0.007129)
shuffled read 5000: 0.616522 / 0.226044 (0.390478)
shuffled read 50000: 6.109465 / 2.268929 (3.840537)
shuffled read_batch 50000 10: 2.738392 / 55.444624 (-52.706233)
shuffled read_batch 50000 100: 2.324060 / 6.876477 (-4.552416)
shuffled read_batch 50000 1000: 2.466920 / 2.142072 (0.324848)
shuffled read_formatted numpy 5000: 1.053398 / 4.805227 (-3.751829)
shuffled read_formatted_batch numpy 5000 10: 0.206516 / 6.500664 (-6.294148)
shuffled read_formatted_batch numpy 5000 1000: 0.076058 / 0.075469 (0.000589)

Benchmark: benchmark_map_filter.json (new / old (diff))

filter: 1.475212 / 1.841788 (-0.366576)
map fast-tokenizer batched: 17.670973 / 8.074308 (9.596665)
map identity: 16.478182 / 10.191392 (6.286790)
map identity batched: 0.230169 / 0.680424 (-0.450255)
map no-op batched: 0.034150 / 0.534201 (-0.500051)
map no-op batched numpy: 0.528322 / 0.579283 (-0.050961)
map no-op batched pandas: 0.537326 / 0.434364 (0.102962)
map no-op batched pytorch: 0.647465 / 0.540337 (0.107127)
map no-op batched tensorflow: 0.784431 / 1.386936 (-0.602505)
PyArrow==latest

Benchmark: benchmark_array_xd.json (new / old (diff))

read_batch_formatted_as_numpy after write_array2d: 0.008895 / 0.011353 (-0.002458)
read_batch_formatted_as_numpy after write_flattened_sequence: 0.005864 / 0.011008 (-0.005144)
read_batch_formatted_as_numpy after write_nested_sequence: 0.108831 / 0.038508 (0.070323)
read_batch_unformated after write_array2d: 0.046575 / 0.023109 (0.023465)
read_batch_unformated after write_flattened_sequence: 0.389691 / 0.275898 (0.113793)
read_batch_unformated after write_nested_sequence: 0.446171 / 0.323480 (0.122691)
read_col_formatted_as_numpy after write_array2d: 0.008526 / 0.007986 (0.000540)
read_col_formatted_as_numpy after write_flattened_sequence: 0.004535 / 0.004328 (0.000207)
read_col_formatted_as_numpy after write_nested_sequence: 0.083655 / 0.004250 (0.079405)
read_col_unformated after write_array2d: 0.046459 / 0.037052 (0.009407)
read_col_unformated after write_flattened_sequence: 0.390635 / 0.258489 (0.132146)
read_col_unformated after write_nested_sequence: 0.469813 / 0.293841 (0.175972)
read_formatted_as_numpy after write_array2d: 0.043466 / 0.128546 (-0.085080)
read_formatted_as_numpy after write_flattened_sequence: 0.013730 / 0.075646 (-0.061916)
read_formatted_as_numpy after write_nested_sequence: 0.390086 / 0.419271 (-0.029185)
read_unformated after write_array2d: 0.063999 / 0.043533 (0.020466)
read_unformated after write_flattened_sequence: 0.394216 / 0.255139 (0.139077)
read_unformated after write_nested_sequence: 0.442514 / 0.283200 (0.159314)
write_array2d: 0.126921 / 0.141683 (-0.014762)
write_flattened_sequence: 1.692541 / 1.452155 (0.240386)
write_nested_sequence: 1.884501 / 1.492716 (0.391784)

Benchmark: benchmark_getitem_100B.json (new / old (diff))

get_batch_of_1024_random_rows: 0.253914 / 0.018006 (0.235907)
get_batch_of_1024_rows: 0.492270 / 0.000490 (0.491780)
get_first_row: 0.004290 / 0.000200 (0.004090)
get_last_row: 0.000106 / 0.000054 (0.000052)

Benchmark: benchmark_indices_mapping.json (new / old (diff))

select: 0.032471 / 0.037411 (-0.004940)
shard: 0.130526 / 0.014526 (0.116000)
shuffle: 0.147004 / 0.176557 (-0.029552)
sort: 0.197063 / 0.737135 (-0.540073)
train_test_split: 0.148557 / 0.296338 (-0.147782)

Benchmark: benchmark_iterating.json (new / old (diff))

read 5000: 0.527672 / 0.215209 (0.312462)
read 50000: 5.241128 / 2.077655 (3.163473)
read_batch 50000 10: 2.672750 / 1.504120 (1.168630)
read_batch 50000 100: 2.404881 / 1.541195 (0.863687)
read_batch 50000 1000: 2.394574 / 1.468490 (0.926084)
read_formatted numpy 5000: 0.896397 / 4.584777 (-3.688380)
read_formatted pandas 5000: 4.630215 / 3.745712 (0.884503)
read_formatted tensorflow 5000: 4.575943 / 5.269862 (-0.693918)
read_formatted torch 5000: 1.906801 / 4.565676 (-2.658875)
read_formatted_batch numpy 5000 10: 0.106931 / 0.424275 (-0.317344)
read_formatted_batch numpy 5000 1000: 0.022814 / 0.007607 (0.015207)
shuffled read 5000: 0.676553 / 0.226044 (0.450508)
shuffled read 50000: 6.655181 / 2.268929 (4.386253)
shuffled read_batch 50000 10: 3.324338 / 55.444624 (-52.120286)
shuffled read_batch 50000 100: 2.850195 / 6.876477 (-4.026282)
shuffled read_batch 50000 1000: 2.941465 / 2.142072 (0.799393)
shuffled read_formatted numpy 5000: 1.076295 / 4.805227 (-3.728933)
shuffled read_formatted_batch numpy 5000 10: 0.207705 / 6.500664 (-6.292959)
shuffled read_formatted_batch numpy 5000 1000: 0.073119 / 0.075469 (-0.002351)

Benchmark: benchmark_map_filter.json (new / old (diff))

filter: 1.462291 / 1.841788 (-0.379497)
map fast-tokenizer batched: 17.777891 / 8.074308 (9.703583)
map identity: 16.104116 / 10.191392 (5.912724)
map identity batched: 0.171357 / 0.680424 (-0.509067)
map no-op batched: 0.020802 / 0.534201 (-0.513399)
map no-op batched numpy: 0.508140 / 0.579283 (-0.071143)
map no-op batched pandas: 0.511213 / 0.434364 (0.076849)
map no-op batched pytorch: 0.658604 / 0.540337 (0.118266)
map no-op batched tensorflow: 0.788271 / 1.386936 (-0.598665)
