Improve skip take shuffling and distributed #6965
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for implementing this so fast! Do you plan to land it soon?
distributed_dataset = split_dataset_by_node(distributed_dataset, rank=rank, world_size=world_size)
distributed_dataset = distributed_dataset.skip(count) if method == "skip" else distributed_dataset.take(count)
assert (
list(true_distributed_dataset)[count:]
This looks good to me!
distributed_dataset = distributed_dataset.skip(count) if method == "skip" else distributed_dataset.take(count)
distributed_dataset = split_dataset_by_node(distributed_dataset, rank=rank, world_size=world_size)
assert len(
list(true_distributed_dataset)[count // world_size :]
I don't know the implementation of split_dataset_by_node very well, so I'm just raising a concern about this test case. Please correct me if I'm wrong.
Assume world_size = 2 and count = 1, and that the underlying data is [0, 1, 2, ..., 9].
true_distributed_dataset on rank 0 would be [0, 2, 4, 6, 8]
true_distributed_dataset on rank 1 would be [1, 3, 5, 7, 9]
and after calling skip and then split_dataset_by_node:
distributed_dataset on rank 0 would be [1, 3, 5, 7, 9]
distributed_dataset on rank 1 would be [2, 4, 6, 8]
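The numbers in this comment can be reproduced with a small pure-Python model. This is only a sketch: the helper split_by_node below is a hypothetical stand-in that assumes examples are dealt to ranks round-robin, and it is not the library's split_dataset_by_node, whose behavior may differ (for example when whole shards are assigned to nodes).

```python
def split_by_node(examples, rank, world_size):
    # Hypothetical toy model (assumption): rank r keeps examples
    # r, r + world_size, r + 2 * world_size, ...
    return examples[rank::world_size]


data = list(range(10))
world_size, count = 2, 1

# Split only: what the test calls true_distributed_dataset.
split_only = {
    rank: split_by_node(data, rank, world_size) for rank in range(world_size)
}

# skip(count) on the full stream first, then split across ranks.
skip_then_split = {
    rank: split_by_node(data[count:], rank, world_size)
    for rank in range(world_size)
}

for rank in range(world_size):
    print(rank, split_only[rank], skip_then_split[rank])
# 0 [0, 2, 4, 6, 8] [1, 3, 5, 7, 9]
# 1 [1, 3, 5, 7, 9] [2, 4, 6, 8]
```

Under this assumption, skipping one example before the split shifts which examples each rank sees, so the ranks end up with different lengths (5 and 4).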
Yes, that's correct. At the moment the test doesn't guard against examples being skipped when count is not a multiple of world_size, but we can improve that later.
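To see which values of count the comparison in the test snippet can cover, the same kind of toy model can be swept over several counts. Again this is a sketch under a round-robin assumption; split_by_node and orders_agree are hypothetical helpers, not part of the library or of this PR.

```python
def split_by_node(examples, rank, world_size):
    # Hypothetical toy model (assumption): round-robin assignment by rank.
    return examples[rank::world_size]


def orders_agree(data, count, world_size):
    """True if skip-then-split matches the per-rank slice
    [count // world_size:] of the split-only data on every rank."""
    return all(
        split_by_node(data[count:], rank, world_size)
        == split_by_node(data, rank, world_size)[count // world_size:]
        for rank in range(world_size)
    )


data = list(range(10))
agreeing_counts = [c for c in range(6) if orders_agree(data, c, world_size=2)]
print(agreeing_counts)  # [0, 2, 4]
```

In this toy model the two sides match only when count is divisible by world_size, which is the gap described in the reply above.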
set the right behavior of skip/take depending on whether it's called after or before shuffle/split_by_node