
Sharded save_to_disk + multiprocessing #5268

Merged
merged 31 commits into from
Dec 14, 2022

Conversation

@lhoestq (Member) commented Nov 18, 2022

Added num_shards= and num_proc= to save_to_disk()

EDIT: also added max_shard_size= to save_to_disk(), and also num_shards= to push_to_hub

I also:

  • deprecated the fs parameter in favor of storage_options (for consistency with the rest of the lib) in save_to_disk and load_from_disk
  • always embed the image/audio data in arrow when doing save_to_disk
  • added a tqdm bar in save_to_disk
  • Use the MockFileSystem in tests for save_to_disk and load_from_disk
  • removed the unused integration tests with S3, since we can now test with mockfs instead of s3fs
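When num_proc= is used, the dataset is split into contiguous shards so each worker writes a distinct slice of rows. A minimal stdlib sketch of how contiguous shard boundaries can be computed — shard_bounds is an illustrative helper, not the library's actual code:

```python
def shard_bounds(n_rows: int, num_shards: int):
    # Contiguous shard boundaries: the first n_rows % num_shards shards
    # get one extra row, so shard sizes differ by at most one.
    div, mod = divmod(n_rows, num_shards)
    bounds = []
    start = 0
    for i in range(num_shards):
        end = start + div + (1 if i < mod else 0)
        bounds.append((start, end))
        start = end
    return bounds

print(shard_bounds(10, 4))  # [(0, 3), (3, 6), (6, 8), (8, 10)]
```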

TODO:

  • implement save_to_disk for DatasetDict
  • add tests for DatasetDict.save_to_disk
  • deprecate fs in DatasetDict.load_from_disk as well
  • update docs

Close #5263
Close #4196
Close #4351

@HuggingFaceDocBuilderDev commented Nov 18, 2022

The documentation is not available anymore as the PR was closed or merged.

@lhoestq lhoestq marked this pull request as ready for review November 21, 2022 18:04
Comment on lines 1350 to 1366
```python
if config.PYARROW_VERSION.major >= 8:
    for pa_table in table_iter(shard.data.table, batch_size=batch_size):
        writer.write_table(pa_table)
        num_examples_progress_update += len(pa_table)
        if time.time() > _time + refresh_rate:
            _time = time.time()
            yield job_id, False, num_examples_progress_update
            num_examples_progress_update = 0
else:
    for i in range(0, shard.num_rows, batch_size):
        pa_table = shard.data.slice(i, batch_size)
        writer.write_table(pa_table)
        num_examples_progress_update += len(pa_table)
        if time.time() > _time + refresh_rate:
            _time = time.time()
            yield job_id, False, num_examples_progress_update
            num_examples_progress_update = 0
```
@lhoestq (Member Author) commented:
I iterate on batches here to update the tqdm bar, but for old versions of pyarrow this may be too slow, since table_iter only works for pyarrow>=8.

I think we may have to implement table_iter even on old versions for performance reasons. It can be based on pa.Table.to_record_batches - let me know what you think
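The throttled-progress pattern in the diff above — yield an update at most once per refresh_rate seconds, plus a final one — can be sketched in plain Python. write_with_progress and its signature are assumptions for illustration, not the PR's actual helper:

```python
import time

def write_with_progress(batches, write, job_id, refresh_rate=0.05):
    # Write each batch, yielding (job_id, done, n_examples) progress
    # updates at most once per refresh_rate seconds, plus a final one
    # marking the job as done.
    last = time.time()
    pending = 0
    for batch in batches:
        write(batch)
        pending += len(batch)
        if time.time() > last + refresh_rate:
            last = time.time()
            yield job_id, False, pending
            pending = 0
    yield job_id, True, pending
```

Driving this generator to completion writes all batches while letting the consumer feed a tqdm bar from the yielded counts.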

@lhoestq (Member Author) commented:

OK, I just implemented pa.Table.to_reader for pyarrow < 8 on our datasets.table.Table. This way we don't have to check the pyarrow version anymore.

@mariosasko (Collaborator) left a comment:
Nice job!

Review threads: docs/source/filesystems.mdx; src/datasets/arrow_dataset.py (3 threads)
@lhoestq (Member Author) commented Dec 8, 2022

Added both num_shards and max_shard_size in push_to_hub/save_to_disk. Will take care of updating the tests later
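Assuming the two options are mutually exclusive and that a missing num_shards is derived by dividing the dataset's byte size by max_shard_size, the resolution logic might look like this sketch (resolve_num_shards is a hypothetical name, not the library's API):

```python
import math

def resolve_num_shards(dataset_nbytes, num_shards=None, max_shard_size=None):
    # num_shards wins if given; otherwise derive it from max_shard_size
    # by rounding up, so no shard exceeds the requested size.
    if num_shards is not None:
        return num_shards
    if max_shard_size is not None:
        return max(1, math.ceil(dataset_nbytes / max_shard_size))
    return 1

print(resolve_num_shards(1_200_000_000, max_shard_size=500_000_000))  # 3
```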

@lhoestq (Member Author) commented Dec 12, 2022

It's ready for a final review @mariosasko and @albertvillanova, let me know what you think :)

@mariosasko (Collaborator) left a comment:

Some nits.

Review threads (outdated): src/datasets/arrow_dataset.py (3 threads); src/datasets/dataset_dict.py (2 threads)
@lhoestq (Member Author) commented Dec 14, 2022

Took your comments into account, and also changed iflatmap_unordered to take an iterable of kwargs to make the code more readable :)
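The iflatmap_unordered pattern — run a generator function once per kwargs dict and flatten the yielded items in completion order — can be sketched with threads. The real implementation uses multiprocessing; this thread-based version is only illustrative, and the names below are assumptions:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def iflatmap_unordered(func, kwargs_iterable, max_workers=2):
    # Run func(**kwargs) for each kwargs dict and flatten the yielded
    # items in completion order (thread-based stand-in for the
    # multiprocessing version).
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(lambda kw: list(func(**kw)), kw)
                   for kw in kwargs_iterable]
        for fut in as_completed(futures):
            yield from fut.result()

def job(job_id, n):
    # Toy per-shard job: yield (job_id, i) for each of n items.
    for i in range(n):
        yield (job_id, i)

results = list(iflatmap_unordered(job, [{"job_id": 0, "n": 2}, {"job_id": 1, "n": 1}]))
```

Passing an iterable of kwargs dicts keeps each job's arguments named and self-describing, rather than relying on positional tuples.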

@mariosasko (Collaborator) left a comment:

Thanks, LGTM!

@lhoestq lhoestq merged commit 232a439 into main Dec 14, 2022
@lhoestq lhoestq deleted the sharded-save_to_disk branch December 14, 2022 18:22