IterableDataset Arrow formatting #5821
Conversation
Will need to take #5810 into account if it gets merged before this one.
Thanks for the enhancement!!
Some comments below...
```python
ex_iterable = ExamplesIterable(Dataset._generate_examples_from_shards, kwargs={"shards": shards})
ex_iterable = ArrowExamplesIterable(
    Dataset._generate_tables_from_shards,
    kwargs={"shards": shards, "batch_size": config.DEFAULT_MAX_BATCH_SIZE},
)
```
I am wondering if we should let users pass a custom `batch_size`.
yea I'm not sure - we can wait for some feedback on this and improve later imo
src/datasets/iterable_dataset.py (outdated)
```python
def _convert_to_arrow(
    iterable: Iterable[Tuple[Key, dict]],
    batch_size: int,
    drop_last_batch=False,
```
Missing type hint here.
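Presumably the fix looks something like the following sketch, assembled from the hunks above and below rather than the committed diff; the `Key` alias here is an assumption:

```python
from typing import Iterable, Iterator, Tuple, Union

import pyarrow as pa

Key = Union[int, str]  # assumption: matches the module's Key alias

def _convert_to_arrow(
    iterable: Iterable[Tuple[Key, dict]],
    batch_size: int,
    drop_last_batch: bool = False,  # the previously missing type hint
) -> Iterator[Tuple[Key, pa.Table]]:
    ...
```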
src/datasets/iterable_dataset.py (outdated)
```python
    batch_size: int,
    drop_last_batch=False,
) -> Iterator[Tuple[Key, pa.Table]]:
    """Iterate over sub-tables of size `batch_size`.
```
The definition in the docstring is the same as the one for `_batch_arrow_tables` below. Maybe we should add to the descriptions that they expect different iterables...
The naming of the two functions could also be aligned: right now they have very different names even though their functionality is analogous...
I like the idea of having different names, because one is expensive (it contains "convert" in its name) while the other one is not (it simply batches data).
I fixed the docstring and type hint.
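For readers following along, here is a minimal sketch (mine, not the PR's code) of the contrast: a `_convert_to_arrow`-style function pays the Python-to-Arrow conversion cost, e.g. via `pa.Table.from_pylist`, whereas `_batch_arrow_tables` below only regroups data that is already in Arrow.

```python
from itertools import islice
from typing import Dict, Iterable, Iterator, Tuple

import pyarrow as pa

def convert_to_arrow_sketch(
    iterable: Iterable[Tuple[str, Dict]], batch_size: int
) -> Iterator[Tuple[str, pa.Table]]:
    # Group `batch_size` python-dict examples, then pay the expensive
    # python -> Arrow conversion once per batch.
    iterator = iter(iterable)
    while batch := list(islice(iterator, batch_size)):
        keys, examples = zip(*batch)
        yield "_".join(map(str, keys)), pa.Table.from_pylist(list(examples))

# e.g. list(convert_to_arrow_sketch(((str(i), {"x": i}) for i in range(5)), 2))
# yields tables of 2, 2 and 1 rows with keys "0_1", "2_3" and "4"
```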
let me know if it sounds good to you now @albertvillanova :)
I think you could update the `IterableDataset.with_format` docstring, now that it supports the "arrow" format as well...
Just a comment about `_batch_arrow_tables` below.
```python
keys_buffer = []
chunks_buffer = []
chunks_buffer_size = 0
for key, pa_table in iterable:
    # Split each incoming table into record batches of at most `batch_size` rows.
    for chunk in pa_table.to_reader(max_chunksize=batch_size):
        if len(chunk) == 0:
            continue
        elif chunks_buffer_size + len(chunk) < batch_size:
            keys_buffer.append(key)
            chunks_buffer.append(chunk)
            chunks_buffer_size += len(chunk)
            continue
        elif chunks_buffer_size + len(chunk) == batch_size:
            keys_buffer.append(key)
            chunks_buffer.append(chunk)
            new_key = "_".join(str(_key) for _key in keys_buffer)
            yield new_key, pa.Table.from_batches(chunks_buffer)
            keys_buffer = []
            chunks_buffer = []
            chunks_buffer_size = 0
        else:
            # The chunk overflows the buffer: crop it to fill the current
            # batch exactly, and start the next buffer with the remainder.
            cropped_chunk_length = batch_size - chunks_buffer_size
            keys_buffer.append(f"{key}[:{cropped_chunk_length}]")
            chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
            new_key = "_".join(str(_key) for _key in keys_buffer)
            yield new_key, pa.Table.from_batches(chunks_buffer)
            keys_buffer = [f"{key}[{cropped_chunk_length}:]"]
            chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
            chunks_buffer_size = len(chunk) - cropped_chunk_length
if not drop_last_batch and chunks_buffer:
    # Flush the partial final batch unless the caller asked to drop it.
    new_key = "_".join(str(_key) for _key in keys_buffer)
    yield new_key, pa.Table.from_batches(chunks_buffer)
```
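A quick, self-contained illustration (not part of the PR) of the two pyarrow primitives this loop builds on: `Table.to_reader(max_chunksize=...)` splits a table into record batches of bounded size, and `Table.from_batches` reassembles batches into a table.

```python
import pyarrow as pa

table = pa.table({"x": list(range(10))})

# Split into record batches of at most 4 rows: sizes 4, 4, 2.
chunks = list(table.to_reader(max_chunksize=4))
print([len(chunk) for chunk in chunks])  # [4, 4, 2]

# Reassemble the first two batches into a new 8-row table.
print(pa.Table.from_batches(chunks[:2]).num_rows)  # 8
```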
This function just rebatches, from the original batch size returned by `self.generate_tables_fn(**self.kwargs)` to the target `batch_size`. To do that, it generates batches from Tables (using `Table.to_reader`) and then Tables from batches (using `Table.from_batches`).
I am just wondering if the implementation could be simpler... Not sure though.
What do you think about a naive approach like this?
```python
from itertools import islice  # needed for islice below

iterator = iter(pa_table.slice(i, 1) for _, pa_table in iterable for i in range(pa_table.num_rows))
batch = True
key = 0
while batch:
    batch = list(islice(iterator, batch_size))
    if batch:
        if drop_last_batch and len(batch) < batch_size:
            continue
        yield key, pa.concat_tables(batch)
        key += 1
```
I think iterating on batches of size 1 and grouping them again is slower. It may also return tables with too many record batches (one batch per row) to be used efficiently.
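The second point is easy to check (illustration, not from the thread): concatenating one-row slices gives a table whose columns are chunked one record batch per row.

```python
import pyarrow as pa

table = pa.table({"x": list(range(4))})
combined = pa.concat_tables(table.slice(i, 1) for i in range(4))
print(combined.num_rows)         # 4
print(combined["x"].num_chunks)  # 4 -> one chunk per row
```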
Then, feel free to merge as it is...
After your commit 028822a, you say `with_format` only supports "arrow" or None. Maybe we should fix all the tests in test_iterable_dataset.py that contain `.with_format("torch")`?
Otherwise, feel free to merge as it is.
they're updated in #5852
Adding an optional `.iter_arrow` method to examples iterables. This allows using Arrow formatting in `map`/`filter`. It will also be useful for torch formatting, since we can reuse the TorchFormatter that converts Arrow data to torch tensors.

Related to #5793 and #3444. Required for #5852.

Example: speed x10 in `map`.
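The PR's original example snippet isn't preserved in this copy of the thread; below is an illustrative sketch of the feature as described, where the dataset name and the "text" column are placeholders and the vectorized function is my own:

```python
import pyarrow.compute as pc
from datasets import load_dataset

# Stream a dataset and ask for Arrow-formatted batches.
ds = load_dataset("my_text_dataset", split="train", streaming=True)  # placeholder name
ds = ds.with_format("arrow")

# With the "arrow" format, `map` receives pyarrow.Table batches, so the
# function can be fully vectorized instead of running row by row.
def uppercase(pa_table):
    index = pa_table.schema.get_field_index("text")  # assumes a "text" column
    return pa_table.set_column(index, "text", pc.utf8_upper(pa_table["text"]))

ds = ds.map(uppercase, batched=True)
```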
Implementation details

I added an optional `iter_arrow` method to examples iterables. If an examples iterable has this method, it can be used to iterate over the examples in batches of Arrow tables.
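A hedged sketch of that pattern (the names follow the PR's description; the body is illustrative, not the committed implementation):

```python
import pyarrow as pa

class ArrowExamplesIterableSketch:
    """Examples iterable with an optional Arrow fast path."""

    def __init__(self, generate_tables_fn, kwargs):
        self.generate_tables_fn = generate_tables_fn
        self.kwargs = kwargs

    def __iter__(self):
        # Regular path: decode each Arrow table into python-dict examples.
        for key, pa_table in self.generate_tables_fn(**self.kwargs):
            for i, example in enumerate(pa_table.to_pylist()):
                yield f"{key}_{i}", example

    def iter_arrow(self):
        # Arrow fast path: yield (key, pa.Table) batches directly, so
        # map/filter can skip the python conversion entirely.
        yield from self.generate_tables_fn(**self.kwargs)
```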