Make IterableDataset.from_spark more efficient #5986
Conversation
…refetching of next partition. Also reordered the spark dataframe to be in the order it will be traversed, allowing prefetching to work better.
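For context, a minimal usage sketch of the API this PR targets (assuming an active SparkSession named `spark` and a `datasets` version that ships the Spark integration; the toy dataframe is purely illustrative):

from datasets import IterableDataset

# Illustrative dataframe; any Spark DataFrame works here.
df = spark.createDataFrame([(i, str(i)) for i in range(1000)], ["id", "text"])

# Examples are streamed partition by partition rather than materialized all at once.
ids = IterableDataset.from_spark(df)
for example in ids:
    ...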
Nice!
for row in rows:
    yield f"{partition_id}_{row_id}", row.asDict()
    row_id += 1
partition_df, size_of_partitions = reorder_dataframe_by_partition(df_with_partition_id, partition_order)
Instead of keeping track of partition sizes, it may be cleaner to just keep the part_id column but delete it from the row dict map before yielding it
… know which partition we are in, simply don't drop the part_id column, convert to pandas dataframe, and use that info
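A minimal sketch of that suggestion (illustrative only; the helper name `iter_examples` is hypothetical, and the `part_id` column name follows the diff below):

import pyspark.sql.functions as F

def iter_examples(df):
    # Tag every row with the id of the Spark partition it lives in.
    df_with_partition_id = df.select("*", F.spark_partition_id().alias("part_id"))
    for row in df_with_partition_id.toLocalIterator():
        row_as_dict = row.asDict()
        # Read the partition id, then strip it from the example dict,
        # instead of tracking partition sizes separately.
        part_id = row_as_dict.pop("part_id")
        yield part_id, row_as_dict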
@@ -6,6 +6,7 @@

import numpy as np
import pyarrow as pa
import pyspark
We shouldn't import pyspark here, since it would make pyspark a dependency of datasets.
should we put it inside the generator then?
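A small sketch of the lazy-import pattern being discussed (the function body is elided; the point is only where the import lives):

def _generate_iterable_examples(df, partition_order):
    # Importing pyspark inside the function (rather than at module level)
    # keeps pyspark an optional dependency of `datasets`: the import only
    # runs when the Spark code path is actually used.
    import pyspark

    ...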
All of the added comments can be removed; I think the code is pretty self-explanatory.
@@ -31,21 +32,37 @@ class SparkConfig(datasets.BuilderConfig):
    features: Optional[datasets.Features] = None


def reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: List[int]):
Add a leading underscore to indicate that this function shouldn't be called from outside this file (_reorder_dataframe_by_partition).
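One plausible shape for that helper, following the naming suggestion (a sketch, not the exact PR code; it assumes the DataFrame already carries a `part_id` column and that `List` is imported from `typing`):

def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: List[int]):
    # Select each partition in the desired traversal order and union the
    # pieces, so that iteration visits partitions in exactly this order.
    df_combined = df.where(f"part_id = {new_partition_order[0]}")
    for partition_id in new_partition_order[1:]:
        partition_df = df.where(f"part_id = {partition_id}")
        df_combined = df_combined.union(partition_df)
    return df_combined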
def _generate_iterable_examples(
    df: "pyspark.sql.DataFrame",
    partition_order: List[int],
):
    import pyspark

Add this back
row_id += 1
partition_df = _reorder_dataframe_by_partition(df_with_partition_id, partition_order)
row_id = 0
# pipeline partitions to hide latency
How about "Prefetch partitions in parallel"
row_id = 0
# pipeline partitions to hide latency
rows = partition_df.toLocalIterator(prefetchPartitions=True)
last_partition = -1  # keep track of the last partition so that we can know when to reset row_id = 0
Rename this variable to "current_partition". Also, you can remove the comment here.
row_as_dict = row.asDict()
part_id = row_as_dict['part_id']
row_as_dict.pop('part_id')
if last_partition != part_id:  # we are on new partition, reset row_id
You can remove the comment here
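Putting the review suggestions together, the iteration loop might end up looking roughly like this (a sketch, not the exact merged code; `df` and `partition_order` come from the enclosing builder):

def generate_fn():
    df_with_partition_id = df.select("*", pyspark.sql.functions.spark_partition_id().alias("part_id"))
    partition_df = _reorder_dataframe_by_partition(df_with_partition_id, partition_order)
    # Prefetch partitions in parallel to hide the latency of fetching the next one.
    rows = partition_df.toLocalIterator(prefetchPartitions=True)
    row_id = 0
    current_partition = -1
    for row in rows:
        row_as_dict = row.asDict()
        part_id = row_as_dict.pop("part_id")
        if current_partition != part_id:
            row_id = 0
            current_partition = part_id
        yield f"{part_id}_{row_id}", row_as_dict
        row_id += 1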
@lhoestq would you be able to review this please and also approve the workflow?
Sounds good to me :) feel free to run
The documentation is not available anymore as the PR was closed or merged.
Cool! I think we can merge once all comments have been addressed.
@lhoestq I just addressed the comments and I think we can move ahead with this!
Perfect! :)
Moved the code from using collect() to using toLocalIterator(), which allows prefetching the partitions that will be read next and thus improves performance when iterating.
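A rough before/after of the change summarized above (illustrative, not the exact diff):

# Before: collect() materializes the whole selection on the driver before
# any example can be yielded.
rows = partition_df.collect()

# After: toLocalIterator() streams rows one partition at a time, and with
# prefetchPartitions=True the next partition is fetched in the background
# while the current one is being yielded.
rows = partition_df.toLocalIterator(prefetchPartitions=True)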