[Data] Add option for parallelizing post-collation data batch operations in `DataIterator.iter_batches()` (#36842)

Currently, the `prefetch_batches` arg of `Dataset.iter_batches` is used to configure the number of preloaded batches on both the CPU and GPU; since there is typically much more CPU memory than GPU memory available, this couples the two and unnecessarily constrains the number of batches that can be prefetched on the CPU.

This PR adds a separate parameter, `_finalize_fn`: a user-defined function that is executed in a separate threadpool, allowing these post-collation steps to run in parallel with collation. For example, this is useful for the host-to-device transfer as the last step in producing a batch; that transfer is the default `_finalize_fn` used when `_collate_fn` is not specified. Note that when users provide their own `_collate_fn`, they should also handle the host-to-device transfer themselves outside of `_collate_fn` in order to maximize performance.
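As a usage sketch (not part of this commit), the two hooks can be exercised through `DataIterator.iter_batches()`. The `id` column name and the exact behavior of these private parameters are assumptions that may vary across Ray versions:

```python
import ray
import torch

ds = ray.data.range(32)

# Runs in the format/collate threadpool (CPU-bound work).
def collate(batch):
    return torch.as_tensor(batch["id"])

# Runs after collation; the intended use is a host-to-device copy such as
# `tensor.to("cuda")`. Kept on the CPU here so the sketch runs anywhere.
def finalize(tensor):
    return tensor

for batch in ds.iterator().iter_batches(
    batch_size=8,
    _collate_fn=collate,
    _finalize_fn=finalize,
):
    print(batch.shape)  # torch.Size([8])
```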

---------

Signed-off-by: Scott Lee <sjl@anyscale.com>
Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Co-authored-by: amogkam <amogkamsetty@yahoo.com>
scottjlee and amogkam authored Jul 7, 2023
1 parent a6f13e3 commit e55e1fe
Showing 9 changed files with 209 additions and 30 deletions.
32 changes: 21 additions & 11 deletions python/ray/data/_internal/block_batching/iter_batches.py
@@ -10,6 +10,7 @@
blocks_to_batches,
collate,
extract_data_from_batch,
finalize_batches,
format_batches,
make_async_gen,
resolve_block_refs,
@@ -30,6 +31,7 @@ def iter_batches(
batch_format: Optional[str] = "default",
drop_last: bool = False,
collate_fn: Optional[Callable[[DataBatch], Any]] = None,
finalize_fn: Optional[Callable[[Any], Any]] = None,
shuffle_buffer_min_size: Optional[int] = None,
shuffle_seed: Optional[int] = None,
ensure_copy: bool = False,
@@ -41,13 +43,12 @@
This takes a block iterator and creates batch_size batches, slicing,
unioning, shuffling, prefetching, and formatting blocks as needed.
The algorithm uses both pipeline parallelism and data parallelism:
If prefetch_batches=2, these are all the batches in flight:
[User thread] trains on Batch 0
- [Fetch thread] Batch 1 in output queue
- [Fetch thread] Batch 1 finalization + move to output queue
- [Worker thread 1] Batch 2 formatting + collating
- [Worker thread 2] Batch 3 formatting + collating
- [Raylet] Batches 4 + 5 fetched to local object store memory
@@ -66,7 +67,8 @@
4. Then, in a threadpool consisting of `prefetch_batches` threads:
a. Format the batches to the provided batch format.
b. Apply the collate function.
5. Fetch outputs from the threadpool, maintaining order of the batches.
5. Finalize each of the collated batches.
6. Fetch outputs from the threadpool, maintaining order of the batches.
Args:
block_refs: An iterator over block object references and their corresponding
@@ -86,6 +88,9 @@
as batches. Default is "default".
drop_last: Whether to drop the last batch if it's incomplete.
collate_fn: A function to apply to each data batch before returning it.
finalize_fn: A function to apply to each data batch after it has been collated.
This function is not run in a threadpool so it can be used for
memory-intensive operations such as GPU preloading.
shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a
local in-memory shuffle buffer, and this value will serve as the minimum
number of rows that must be in the local in-memory shuffle buffer in order
@@ -97,8 +102,7 @@
process. If set to greater than 0, a separate thread will be used to fetch
the specified amount of formatted batches from blocks. This improves
performance for non-CPU bound UDFs, allowing batch fetching compute and
formatting to be overlapped with the UDF. Defaults to 0 (no prefetching
enabled).
formatting to be overlapped with the UDF. Defaults to 1.
Returns:
An iterator over record batches.
@@ -119,7 +123,6 @@
def _async_iter_batches(
block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
) -> Iterator[DataBatch]:

# Step 1: Prefetch logical batches locally.
block_refs = prefetch_batches_locally(
block_ref_iter=block_refs,
@@ -152,7 +155,13 @@ def _async_iter_batches(
num_threadpool_workers=prefetch_batches,
)

# Step 5: Restore original order.
# Step 5: Finalize each batch.
if finalize_fn is not None:
batch_iter = finalize_batches(
batch_iter, finalize_fn=finalize_fn, stats=stats
)

# Step 6: Restore original order.
batch_iter: Iterator[Batch] = restore_original_order(batch_iter)

yield from extract_data_from_batch(batch_iter)
@@ -193,7 +202,7 @@ def _format_in_threadpool(
num_threadpool_workers: The number of threads to use in the threadpool.
"""

def threadpool_computations(
def threadpool_computations_format_collate(
batch_iter: Iterator[Batch],
) -> Iterator[Batch]:
# Step 4a: Format the batches.
@@ -209,13 +218,14 @@ def threadpool_computations(
yield from formatted_batch_iter

if num_threadpool_workers > 0:
return make_async_gen(
collated_iter = make_async_gen(
base_iterator=batch_iter,
fn=threadpool_computations,
fn=threadpool_computations_format_collate,
num_workers=num_threadpool_workers,
)
else:
return threadpool_computations(batch_iter)
collated_iter = threadpool_computations_format_collate(batch_iter)
return collated_iter


def prefetch_batches_locally(
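For intuition, the flow above can be pictured with a small self-contained sketch (plain Python, not Ray's actual `make_async_gen` machinery): CPU-bound format/collate work fans out across a threadpool, a single separate thread applies the finalize step, and results come back in submission order:

```python
import queue
import threading
from concurrent.futures import ThreadPoolExecutor

def pipeline(batches, collate, finalize, num_workers=2):
    """Yield finalize(collate(b)) for each batch, preserving input order."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        collated = pool.map(collate, batches)  # fan out; results stay ordered
        out = queue.Queue(maxsize=num_workers)  # bounded for backpressure

        def finalizer():
            # Single thread: finalization overlaps with ongoing collation.
            for item in collated:
                out.put(finalize(item))
            out.put(None)  # sentinel: no more batches

        threading.Thread(target=finalizer, daemon=True).start()
        while (item := out.get()) is not None:
            yield item

print(list(pipeline(range(4), lambda x: x * 10, lambda x: x + 1)))
# [1, 11, 21, 31]
```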
26 changes: 26 additions & 0 deletions python/ray/data/_internal/block_batching/util.py
@@ -173,13 +173,39 @@ def collate(
Args:
batch_iter: An iterator over formatted batches.
collate_fn: A function to apply to each batch.
stats: An optional stats object to record collation times.
"""
for batch in batch_iter:
with stats.iter_collate_batch_s.timer() if stats else nullcontext():
collated_batch = collate_fn(batch.data)
yield CollatedBatch(batch.batch_idx, collated_batch)


def finalize_batches(
batch_iter: Iterator[CollatedBatch],
finalize_fn: Callable[[Any], Any],
stats: Optional[DatasetStats] = None,
) -> Iterator[CollatedBatch]:
"""Returns an iterator with the provided finalize_fn applied to items of the batch
iterator.
This is the same as `collate` except the input batches can be of type Any.
Args:
batch_iter: An iterator over processed batches.
finalize_fn: A function to apply to each batch.
stats: An optional stats object to record finalization times.
Returns:
An iterator over batch index and the finalized batch.
"""
for batch in batch_iter:
with stats.iter_finalize_batch_s.timer() if stats else nullcontext():
finalized_batch = finalize_fn(batch.data)
yield CollatedBatch(batch.batch_idx, finalized_batch)


def extract_data_from_batch(batch_iter: Iterator[Batch]) -> Iterator[Any]:
for batch in batch_iter:
yield batch.data
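To make the `finalize_batches` contract concrete, here is a minimal stand-in (the real `Batch`/`CollatedBatch` types live in `ray.data._internal.block_batching.interfaces`; the two-field shape is an assumption):

```python
from collections import namedtuple

# Simplified stand-in: a batch index plus a payload of any type.
CollatedBatch = namedtuple("CollatedBatch", ["batch_idx", "data"])

def finalize_batches(batch_iter, finalize_fn):
    for batch in batch_iter:
        yield CollatedBatch(batch.batch_idx, finalize_fn(batch.data))

batches = [CollatedBatch(0, [1, 2]), CollatedBatch(1, [3, 4])]
for b in finalize_batches(iter(batches), finalize_fn=sum):
    print(b.batch_idx, b.data)  # 0 3, then 1 7
```

The batch index travels with the payload so that the later `restore_original_order` step can reorder results emitted by the parallel stages.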
2 changes: 2 additions & 0 deletions python/ray/data/_internal/iterator/pipelined_iterator.py
@@ -67,6 +67,7 @@ def iter_batches(
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
_collate_fn: Optional[Callable[[DataBatch], Any]] = None,
_finalize_fn: Optional[Callable[[Any], Any]] = None,
# Deprecated.
prefetch_blocks: int = 0,
) -> Iterator[DataBatch]:
@@ -79,6 +80,7 @@
local_shuffle_buffer_size=local_shuffle_buffer_size,
local_shuffle_seed=local_shuffle_seed,
_collate_fn=_collate_fn,
_finalize_fn=_finalize_fn,
prefetch_blocks=prefetch_blocks,
)

16 changes: 16 additions & 0 deletions python/ray/data/_internal/stats.py
@@ -249,6 +249,7 @@ def __init__(
self.iter_next_batch_s: Timer = Timer()
self.iter_format_batch_s: Timer = Timer()
self.iter_collate_batch_s: Timer = Timer()
self.iter_finalize_batch_s: Timer = Timer()
self.iter_total_blocked_s: Timer = Timer()
self.iter_user_s: Timer = Timer()
self.iter_total_s: Timer = Timer()
@@ -320,6 +321,7 @@ def to_summary(self) -> "DatasetStatsSummary":
self.iter_next_batch_s,
self.iter_format_batch_s,
self.iter_collate_batch_s,
self.iter_finalize_batch_s,
self.iter_total_blocked_s,
self.iter_user_s,
self.iter_total_s,
@@ -726,6 +728,8 @@ class IterStatsSummary:
format_time: Timer
# Time spent in collate fn, in seconds
collate_time: Timer
# Time spent in finalize_fn, in seconds
finalize_batch_time: Timer
# Total time user thread is blocked by iter_batches
block_time: Timer
# Time spent in user code, in seconds
@@ -754,6 +758,7 @@ def to_string(self) -> str:
or self.next_time.get()
or self.format_time.get()
or self.collate_time.get()
or self.finalize_batch_time.get()
):
out += "\nDataset iterator time breakdown:\n"
if self.block_time.get():
@@ -808,6 +813,16 @@ def to_string(self) -> str:
fmt(self.collate_time.avg()),
fmt(self.collate_time.get()),
)
if self.finalize_batch_time.get():
format_str = (
" * In host->device transfer: {} min, {} max, {} avg, {} total\n"
)
out += format_str.format(
fmt(self.finalize_batch_time.min()),
fmt(self.finalize_batch_time.max()),
fmt(self.finalize_batch_time.avg()),
fmt(self.finalize_batch_time.get()),
)

return out

@@ -875,6 +890,7 @@ def __init__(self, *, max_history: int = 3):
"iter_next_batch_s": Timer(),
"iter_format_batch_s": Timer(),
"iter_collate_batch_s": Timer(),
"iter_finalize_batch_s": Timer(),
"iter_user_s": Timer(),
"iter_total_s": Timer(),
}
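The new " * In host->device transfer" line is rendered from the `iter_finalize_batch_s` timer's aggregates; a toy equivalent of that formatting, with the `Timer` methods inferred from the calls above:

```python
class Timer:
    """Minimal accumulator mirroring the min()/max()/avg()/get() calls above."""

    def __init__(self):
        self._samples = []

    def add(self, seconds):
        self._samples.append(seconds)

    def min(self):
        return min(self._samples)

    def max(self):
        return max(self._samples)

    def avg(self):
        return sum(self._samples) / len(self._samples)

    def get(self):  # total accumulated seconds
        return sum(self._samples)

t = Timer()
for s in (0.002, 0.004, 0.003):
    t.add(s)
print(
    " * In host->device transfer: {:.3f} min, {:.3f} max, "
    "{:.3f} avg, {:.3f} total".format(t.min(), t.max(), t.avg(), t.get())
)
# * In host->device transfer: 0.002 min, 0.004 max, 0.003 avg, 0.009 total
```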
26 changes: 21 additions & 5 deletions python/ray/data/dataset_pipeline.py
@@ -1112,13 +1112,29 @@ def iter_torch_batches(
:py:meth:`Dataset.iter_torch_batches
<ray.data.Dataset.iter_torch_batches>` over the stream of output batches
from the pipeline."""
return DataIterator.iter_torch_batches(
self,

from ray.air._internal.torch_utils import (
convert_ndarray_batch_to_torch_tensor_batch,
)

if collate_fn is not None and (dtypes is not None or device is not None):
raise ValueError(
"collate_fn cannot be used with dtypes and device. It is expected that"
"the provided `collate_fn` will move the output Torch tensors to the"
"appropriate dtype and device."
)

if collate_fn is None:

def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]):
return convert_ndarray_batch_to_torch_tensor_batch(
batch, dtypes=dtypes, device=device
)

return self.iter_batches(
prefetch_blocks=prefetch_blocks,
batch_size=batch_size,
dtypes=dtypes,
device=device,
collate_fn=collate_fn,
_collate_fn=collate_fn,
drop_last=drop_last,
local_shuffle_buffer_size=local_shuffle_buffer_size,
local_shuffle_seed=local_shuffle_seed,
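Usage-wise, `DatasetPipeline.iter_torch_batches` now accepts either `dtypes`/`device` or a custom `collate_fn`, never both. A sketch (the `id` column and the `repeat()` pipeline construction are assumptions about the surrounding API):

```python
import ray
import torch

pipe = ray.data.range(8).repeat(2)  # a small DatasetPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"

# With a custom collate_fn, the caller owns dtype and device placement.
def collate_fn(batch):
    return torch.as_tensor(batch["id"]).to(device)

for batch in pipe.iter_torch_batches(batch_size=4, collate_fn=collate_fn):
    assert batch.device.type == device

# Passing collate_fn together with dtypes or device raises the ValueError above.
```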
58 changes: 45 additions & 13 deletions python/ray/data/iterator.py
@@ -94,6 +94,7 @@ def iter_batches(
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
_collate_fn: Optional[Callable[[DataBatch], Any]] = None,
_finalize_fn: Optional[Callable[[Any], Any]] = None,
# Deprecated.
prefetch_blocks: int = 0,
) -> Iterator[DataBatch]:
@@ -163,6 +164,15 @@ def drop_metadata(block_iterator):
for block_ref, metadata in block_iterator:
yield block_ref

# Legacy iter_batches does not have a distinction between
# collate_fn and finalize_fn since batches are not prefetched.
def collate_and_finalize(batch):
if _collate_fn is not None:
batch = _collate_fn(batch)
if _finalize_fn is not None:
batch = _finalize_fn(batch)
return batch

yield from batch_block_refs(
drop_metadata(block_iterator),
stats=stats,
@@ -171,7 +181,7 @@ def drop_metadata(block_iterator):
batch_size=batch_size,
batch_format=batch_format,
drop_last=drop_last,
collate_fn=_collate_fn,
collate_fn=collate_and_finalize,
shuffle_buffer_min_size=local_shuffle_buffer_size,
shuffle_seed=local_shuffle_seed,
)
@@ -184,6 +194,7 @@ def drop_metadata(block_iterator):
batch_format=batch_format,
drop_last=drop_last,
collate_fn=_collate_fn,
finalize_fn=_finalize_fn,
shuffle_buffer_min_size=local_shuffle_buffer_size,
shuffle_seed=local_shuffle_seed,
prefetch_batches=prefetch_batches,
@@ -285,13 +296,14 @@ def iter_torch_batches(
will be inferred from the tensor data.
device: The device on which the tensor should be placed; if None, the Torch
tensor will be constructed on the CPU.
collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
Potential use cases include collating along a dimension other than the
first, padding sequences of various lengths, or generally handling
batches of different length tensors. If not provided, the default
collate function is used which simply converts the batch of numpy
arrays to a batch of PyTorch tensors. This API is still experimental
and is subject to change.
collate_fn: A function to apply to each data batch before returning it. When
this parameter is specified, the user should manually handle the host
to device data transfer outside of collate_fn. Potential use cases
include collating along a dimension other than the first, padding
sequences of various lengths, or generally handling batches of different
length tensors. This API is still experimental and is subject to change.
This parameter cannot be used in conjunction with ``dtypes`` or
``device``.
drop_last: Whether to drop the last batch if it's incomplete.
local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
using a local in-memory shuffle buffer, and this value will serve as the
@@ -314,24 +326,43 @@

if collate_fn is not None and (dtypes is not None or device is not None):
raise ValueError(
"collate_fn cannot be used with dtypes and device. It is expected that"
"the provided `collate_fn` will move the output Torch tensors to the"
"appropriate dtype and device."
"collate_fn cannot be used with dtypes and device."
"You should manually move the output Torch tensors to the"
"desired dtype and device outside of collate_fn."
)

if collate_fn is None:

# Automatically move torch tensors to the appropriate device.
if device is None:
default_device = get_device()
if default_device.type != "cpu":
device = default_device

# The default collate_fn handles formatting and Tensor creation.
# Here, we set device=None to defer host to device data transfer
# to the subsequent finalize_fn.
def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]):
return convert_ndarray_batch_to_torch_tensor_batch(
batch, dtypes=dtypes, device=device
batch,
dtypes=dtypes,
device=None,
)

# The default finalize_fn handles the host to device data transfer.
# This is executed in a 1-thread pool separately from collate_fn
# to allow independent parallelism of these steps.
def finalize_fn(batch: Union["torch.Tensor", Dict[str, "torch.Tensor"]]):
if device is not None:
if isinstance(batch, dict):
for k, t in batch.items():
batch[k] = t.to(device=device)
else:
batch = batch.to(device=device)
return batch

else:
finalize_fn = None

yield from self.iter_batches(
prefetch_batches=prefetch_batches,
prefetch_blocks=prefetch_blocks,
@@ -340,6 +371,7 @@ def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]):
local_shuffle_buffer_size=local_shuffle_buffer_size,
local_shuffle_seed=local_shuffle_seed,
_collate_fn=collate_fn,
_finalize_fn=finalize_fn,
)

def iter_tf_batches(
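The net effect of the default path: tensor creation happens in the collate stage with `device=None`, and the device copy happens in the finalize stage. A standalone sketch of that split, where the `collate` helper is a simplified stand-in for `convert_ndarray_batch_to_torch_tensor_batch`:

```python
from typing import Dict, Union

import numpy as np
import torch

def collate(batch: Union[np.ndarray, Dict[str, np.ndarray]]):
    # Stage 1 (format/collate threadpool): build CPU tensors only;
    # device placement is deliberately deferred.
    if isinstance(batch, dict):
        return {k: torch.as_tensor(v) for k, v in batch.items()}
    return torch.as_tensor(batch)

def finalize(batch, device="cpu"):
    # Stage 2 (separate single-thread pool): host-to-device transfer.
    if isinstance(batch, dict):
        return {k: t.to(device=device) for k, t in batch.items()}
    return batch.to(device=device)

out = finalize(collate({"x": np.arange(4)}), device="cpu")
print(out["x"])  # tensor([0, 1, 2, 3])
```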