From e55e1fe8066c37e44a0922b67f0efe6e67c37c4e Mon Sep 17 00:00:00 2001
From: Scott Lee
Date: Thu, 6 Jul 2023 18:39:47 -0700
Subject: [PATCH] [Data] Add option for parallelizing post-collation data batch operations in `DataIterator.iter_batches()` (#36842)

Currently, the prefetch_batches arg of Dataset.iter_batches configures the
number of preloaded batches on both the CPU and the GPU; since GPU memory is
typically much scarcer than CPU memory, this constrains the number of batches
that can be prefetched on the CPU.

This PR adds a separate parameter, _finalize_fn: a user-defined function that
is executed in a separate threadpool, so that these post-collation steps can
run in parallel with batch formatting and collation. For example, this is
useful for host-to-device transfer as the last step in producing a batch;
this is the default _finalize_fn used when _collate_fn is not specified.
Note that when _collate_fn is provided by the user, they should also handle
the host-to-device transfer themselves outside of _collate_fn in order to
maximize performance.

---------

Signed-off-by: Scott Lee
Signed-off-by: amogkam
Co-authored-by: amogkam
---
 .../_internal/block_batching/iter_batches.py  | 32 ++++++----
 .../ray/data/_internal/block_batching/util.py | 26 +++++++++
 .../_internal/iterator/pipelined_iterator.py  |  2 +
 python/ray/data/_internal/stats.py            | 16 +++++
 python/ray/data/dataset_pipeline.py           | 26 +++++++--
 python/ray/data/iterator.py                   | 58 ++++++++++++++-----
 .../tests/block_batching/test_iter_batches.py | 40 +++++++++++++
 .../data/tests/block_batching/test_util.py    | 16 +++++
 python/ray/data/tests/test_iterator.py        | 23 +++++++-
 9 files changed, 209 insertions(+), 30 deletions(-)

diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py
index fd017dd55fa7..311e89d5bae9 100644
--- a/python/ray/data/_internal/block_batching/iter_batches.py
+++ b/python/ray/data/_internal/block_batching/iter_batches.py
@@ -10,6 +10,7 @@
     blocks_to_batches,
     collate,
     extract_data_from_batch,
+    finalize_batches,
     format_batches,
     make_async_gen,
     resolve_block_refs,
@@ -30,6 +31,7 @@ def iter_batches(
     batch_format: Optional[str] = "default",
     drop_last: bool = False,
     collate_fn: Optional[Callable[[DataBatch], Any]] = None,
+    finalize_fn: Optional[Callable[[Any], Any]] = None,
     shuffle_buffer_min_size: Optional[int] = None,
     shuffle_seed: Optional[int] = None,
     ensure_copy: bool = False,
@@ -41,13 +43,12 @@ def iter_batches(
     This takes a block iterator and creates batch_size batches, slicing,
     unioning, shuffling, prefetching, and formatting blocks as needed.
-
     The algorithm uses both pipeline parallelism and data parallelism:

     If prefetch_batches=2, these are all the batches in flight:

     [User thread] trains on Batch 0
-       - [Fetch thread] Batch 1 in output queue
+       - [Fetch thread] Batch 1 finalization + move to output queue
          - [Worker thread 1] Batch 2 formatting + collating
          - [Worker thread 2] Batch 3 formatting + collating
             - [Raylet] Batches 4 + 5 fetched to local object store memory
@@ -66,7 +67,8 @@
     4. Then, in a threadpool consisting of `prefetch_batches` threads:
        a. Format the batches to the provided batch format.
        b. Apply the collate function.
-    5. Fetch outputs from the threadpool, maintaining order of the batches.
+    5. Finalize each of the collated batches.
+    6. Fetch outputs from the threadpool, maintaining order of the batches.
Args: block_refs: An iterator over block object references and their corresponding @@ -86,6 +88,9 @@ def iter_batches( as batches. Default is "default". drop_last: Whether to drop the last batch if it's incomplete. collate_fn: A function to apply to each data batch before returning it. + finalize_fn: A function to apply to each data batch after it has been collated. + This function is not run in a threadpool so it can be used for + memory-intensive operations such as GPU preloading. shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a local in-memory shuffle buffer, and this value will serve as the minimum number of rows that must be in the local in-memory shuffle buffer in order @@ -97,8 +102,7 @@ def iter_batches( process. If set to greater than 0, a separate thread will be used to fetch the specified amount of formatted batches from blocks. This improves performance for non-CPU bound UDFs, allowing batch fetching compute and - formatting to be overlapped with the UDF. Defaults to 0 (no prefetching - enabled). + formatting to be overlapped with the UDF. Defaults to 1. Returns: An iterator over record batches. @@ -119,7 +123,6 @@ def iter_batches( def _async_iter_batches( block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], ) -> Iterator[DataBatch]: - # Step 1: Prefetch logical batches locally. block_refs = prefetch_batches_locally( block_ref_iter=block_refs, @@ -152,7 +155,13 @@ def _async_iter_batches( num_threadpool_workers=prefetch_batches, ) - # Step 5: Restore original order. + # Step 5: Finalize each batch. + if finalize_fn is not None: + batch_iter = finalize_batches( + batch_iter, finalize_fn=finalize_fn, stats=stats + ) + + # Step 6: Restore original order. batch_iter: Iterator[Batch] = restore_original_order(batch_iter) yield from extract_data_from_batch(batch_iter) @@ -193,7 +202,7 @@ def _format_in_threadpool( num_threadpool_workers: The number of threads to use in the threadpool. """ - def threadpool_computations( + def threadpool_computations_format_collate( batch_iter: Iterator[Batch], ) -> Iterator[Batch]: # Step 4a: Format the batches. @@ -209,13 +218,14 @@ def threadpool_computations( yield from formatted_batch_iter if num_threadpool_workers > 0: - return make_async_gen( + collated_iter = make_async_gen( base_iterator=batch_iter, - fn=threadpool_computations, + fn=threadpool_computations_format_collate, num_workers=num_threadpool_workers, ) else: - return threadpool_computations(batch_iter) + collated_iter = threadpool_computations_format_collate(batch_iter) + return collated_iter def prefetch_batches_locally( diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index b2f7370a6bb2..c2c40e923c32 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -173,6 +173,8 @@ def collate( Args: batch_iter: An iterator over formatted batches. + collate_fn: A function to apply to each batch. + stats: An optional stats object to record formatting times. """ for batch in batch_iter: with stats.iter_collate_batch_s.timer() if stats else nullcontext(): @@ -180,6 +182,30 @@ def collate( yield CollatedBatch(batch.batch_idx, collated_batch) +def finalize_batches( + batch_iter: Iterator[CollatedBatch], + finalize_fn: Callable[[Any], Any], + stats: Optional[DatasetStats] = None, +) -> Iterator[CollatedBatch]: + """Returns an iterator with the provided finalize_fn applied to items of the batch + iterator. 
+ + This is the same as `collate` except the input batches can be of type Any. + + Args: + batch_iter: An iterator over processed batches. + finalize_fn: A function to apply to each batch. + stats: An optional stats object to record formatting times. + + Returns: + An iterator over batch index and the finalized batch. + """ + for batch in batch_iter: + with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): + finalized_batch = finalize_fn(batch.data) + yield CollatedBatch(batch.batch_idx, finalized_batch) + + def extract_data_from_batch(batch_iter: Iterator[Batch]) -> Iterator[Any]: for batch in batch_iter: yield batch.data diff --git a/python/ray/data/_internal/iterator/pipelined_iterator.py b/python/ray/data/_internal/iterator/pipelined_iterator.py index d806d4d00f8b..d4778e531062 100644 --- a/python/ray/data/_internal/iterator/pipelined_iterator.py +++ b/python/ray/data/_internal/iterator/pipelined_iterator.py @@ -67,6 +67,7 @@ def iter_batches( local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, _collate_fn: Optional[Callable[[DataBatch], Any]] = None, + _finalize_fn: Optional[Callable[[Any], Any]] = None, # Deprecated. prefetch_blocks: int = 0, ) -> Iterator[DataBatch]: @@ -79,6 +80,7 @@ def iter_batches( local_shuffle_buffer_size=local_shuffle_buffer_size, local_shuffle_seed=local_shuffle_seed, _collate_fn=_collate_fn, + _finalize_fn=_finalize_fn, prefetch_blocks=prefetch_blocks, ) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 099d6361bd8b..2fb8795ed628 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -249,6 +249,7 @@ def __init__( self.iter_next_batch_s: Timer = Timer() self.iter_format_batch_s: Timer = Timer() self.iter_collate_batch_s: Timer = Timer() + self.iter_finalize_batch_s: Timer = Timer() self.iter_total_blocked_s: Timer = Timer() self.iter_user_s: Timer = Timer() self.iter_total_s: Timer = Timer() @@ -320,6 +321,7 @@ def to_summary(self) -> "DatasetStatsSummary": self.iter_next_batch_s, self.iter_format_batch_s, self.iter_collate_batch_s, + self.iter_finalize_batch_s, self.iter_total_blocked_s, self.iter_user_s, self.iter_total_s, @@ -726,6 +728,8 @@ class IterStatsSummary: format_time: Timer # Time spent in collate fn, in seconds collate_time: Timer + # Time spent in finalize_fn, in seconds + finalize_batch_time: Timer # Total time user thread is blocked by iter_batches block_time: Timer # Time spent in user code, in seconds @@ -754,6 +758,7 @@ def to_string(self) -> str: or self.next_time.get() or self.format_time.get() or self.collate_time.get() + or self.finalize_batch_time.get() ): out += "\nDataset iterator time breakdown:\n" if self.block_time.get(): @@ -808,6 +813,16 @@ def to_string(self) -> str: fmt(self.collate_time.avg()), fmt(self.collate_time.get()), ) + if self.finalize_batch_time.get(): + format_str = ( + " * In host->device transfer: {} min, {} max, {} avg, {} total\n" + ) + out += format_str.format( + fmt(self.finalize_batch_time.min()), + fmt(self.finalize_batch_time.max()), + fmt(self.finalize_batch_time.avg()), + fmt(self.finalize_batch_time.get()), + ) return out @@ -875,6 +890,7 @@ def __init__(self, *, max_history: int = 3): "iter_next_batch_s": Timer(), "iter_format_batch_s": Timer(), "iter_collate_batch_s": Timer(), + "iter_finalize_batch_s": Timer(), "iter_user_s": Timer(), "iter_total_s": Timer(), } diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py index 
160c23746a75..bbf1e083a260 100644 --- a/python/ray/data/dataset_pipeline.py +++ b/python/ray/data/dataset_pipeline.py @@ -1112,13 +1112,29 @@ def iter_torch_batches( :py:meth:`Dataset.iter_torch_batches ` over the stream of output batches from the pipeline.""" - return DataIterator.iter_torch_batches( - self, + + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + + if collate_fn is not None and (dtypes is not None or device is not None): + raise ValueError( + "collate_fn cannot be used with dtypes and device. It is expected that" + "the provided `collate_fn` will move the output Torch tensors to the" + "appropriate dtype and device." + ) + + if collate_fn is None: + + def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): + return convert_ndarray_batch_to_torch_tensor_batch( + batch, dtypes=dtypes, device=device + ) + + return self.iter_batches( prefetch_blocks=prefetch_blocks, batch_size=batch_size, - dtypes=dtypes, - device=device, - collate_fn=collate_fn, + _collate_fn=collate_fn, drop_last=drop_last, local_shuffle_buffer_size=local_shuffle_buffer_size, local_shuffle_seed=local_shuffle_seed, diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 205de26ed22d..39d21da173a9 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -94,6 +94,7 @@ def iter_batches( local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, _collate_fn: Optional[Callable[[DataBatch], Any]] = None, + _finalize_fn: Optional[Callable[[Any], Any]] = None, # Deprecated. prefetch_blocks: int = 0, ) -> Iterator[DataBatch]: @@ -163,6 +164,15 @@ def drop_metadata(block_iterator): for block_ref, metadata in block_iterator: yield block_ref + # Legacy iter_batches does not have a distinction between + # collate_fn and finalize_fn since batches are not prefetched. + def collate_and_finalize(batch): + if _collate_fn is not None: + batch = _collate_fn(batch) + if _finalize_fn is not None: + batch = _finalize_fn(batch) + return batch + yield from batch_block_refs( drop_metadata(block_iterator), stats=stats, @@ -171,7 +181,7 @@ def drop_metadata(block_iterator): batch_size=batch_size, batch_format=batch_format, drop_last=drop_last, - collate_fn=_collate_fn, + collate_fn=collate_and_finalize, shuffle_buffer_min_size=local_shuffle_buffer_size, shuffle_seed=local_shuffle_seed, ) @@ -184,6 +194,7 @@ def drop_metadata(block_iterator): batch_format=batch_format, drop_last=drop_last, collate_fn=_collate_fn, + finalize_fn=_finalize_fn, shuffle_buffer_min_size=local_shuffle_buffer_size, shuffle_seed=local_shuffle_seed, prefetch_batches=prefetch_batches, @@ -285,13 +296,14 @@ def iter_torch_batches( will be inferred from the tensor data. device: The device on which the tensor should be placed; if None, the Torch tensor will be constructed on the CPU. - collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch. - Potential use cases include collating along a dimension other than the - first, padding sequences of various lengths, or generally handling - batches of different length tensors. If not provided, the default - collate function is used which simply converts the batch of numpy - arrays to a batch of PyTorch tensors. This API is still experimental - and is subject to change. + collate_fn: A function to apply to each data batch before returning it. When + this parameter is specified, the user should manually handle the host + to device data transfer outside of collate_fn. 
Potential use cases + include collating along a dimension other than the first, padding + sequences of various lengths, or generally handling batches of different + length tensors. This API is still experimental and is subject to change. + This parameter cannot be used in conjunction with ``dtypes`` or + ``device``. drop_last: Whether to drop the last batch if it's incomplete. local_shuffle_buffer_size: If non-None, the data will be randomly shuffled using a local in-memory shuffle buffer, and this value will serve as the @@ -314,24 +326,43 @@ def iter_torch_batches( if collate_fn is not None and (dtypes is not None or device is not None): raise ValueError( - "collate_fn cannot be used with dtypes and device. It is expected that" - "the provided `collate_fn` will move the output Torch tensors to the" - "appropriate dtype and device." + "collate_fn cannot be used with dtypes and device." + "You should manually move the output Torch tensors to the" + "desired dtype and device outside of collate_fn." ) if collate_fn is None: - # Automatically move torch tensors to the appropriate device. if device is None: default_device = get_device() if default_device.type != "cpu": device = default_device + # The default collate_fn handles formatting and Tensor creation. + # Here, we set device=None to defer host to device data transfer + # to the subsequent finalize_fn. def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): return convert_ndarray_batch_to_torch_tensor_batch( - batch, dtypes=dtypes, device=device + batch, + dtypes=dtypes, + device=None, ) + # The default finalize_fn handles the host to device data transfer. + # This is executed in a 1-thread pool separately from collate_fn + # to allow independent parallelism of these steps. + def finalize_fn(batch: Union["torch.Tensor", Dict[str, "torch.Tensor"]]): + if device is not None: + if isinstance(batch, dict): + for k, t in batch.items(): + batch[k] = t.to(device=device) + else: + batch = batch.to(device=device) + return batch + + else: + finalize_fn = None + yield from self.iter_batches( prefetch_batches=prefetch_batches, prefetch_blocks=prefetch_blocks, @@ -340,6 +371,7 @@ def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): local_shuffle_buffer_size=local_shuffle_buffer_size, local_shuffle_seed=local_shuffle_seed, _collate_fn=collate_fn, + _finalize_fn=finalize_fn, ) def iter_tf_batches( diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index b76d51f41ed7..ea2dbc9e0a1a 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -1,4 +1,6 @@ import itertools +import queue +import threading import time from typing import Iterator, List, Tuple @@ -96,6 +98,44 @@ def test_restore_from_original_order(): assert idx == [0, 1, 2, 3] +def test_finalize_fn_uses_single_thread(ray_start_regular_shared): + """Tests that finalize_fn is not run with multiple threads.""" + block_refs_iter = itertools.starmap( + lambda block, metadata: (ray.put(block), metadata), + block_generator(num_blocks=20, num_rows=2), + ) + + q = queue.Queue() + semaphore = threading.Semaphore(value=1) + + def finalize_enforce_single_thread(batch): + already_acquired = not semaphore.acquire(blocking=False) + if already_acquired: + e = AssertionError("finalize_fn is being run concurrently.") + q.put(e, block=True) + semaphore.release() + return batch + + # Test that finalize_fn is called in a single thread, + 
+    # even if prefetch_batches is set.
+    output_batches = iter_batches(
+        block_refs_iter,
+        collate_fn=lambda batch: batch,
+        finalize_fn=finalize_enforce_single_thread,
+        prefetch_batches=4,
+    )
+
+    # Force execution of the iterator.
+    # This step should not raise an exception.
+    list(output_batches)
+
+    try:
+        e = q.get(block=False, timeout=0.1)
+        raise e
+    except queue.Empty:
+        pass
+
+
 # Test for 3 cases
 # 1. Batch size is less than block size
 # 2. Batch size is more than block size
diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py
index d44fa3673fe9..08928d5d6800 100644
--- a/python/ray/data/tests/block_batching/test_util.py
+++ b/python/ray/data/tests/block_batching/test_util.py
@@ -13,6 +13,7 @@
     _calculate_ref_hits,
     blocks_to_batches,
     collate,
+    finalize_batches,
     format_batches,
     make_async_gen,
     resolve_block_refs,
@@ -95,6 +96,21 @@ def collate_fn(batch):
         assert batch.data == pa.table({"bar": [1] * 2})


+def test_finalize():
+    def finalize_fn(batch):
+        return pa.table({"bar": [1] * 2})
+
+    batches = [
+        Batch(i, data)
+        for i, data in enumerate(block_generator(num_rows=2, num_blocks=2))
+    ]
+    batch_iter = finalize_batches(batches, finalize_fn=finalize_fn)
+
+    for i, batch in enumerate(batch_iter):
+        assert batch.batch_idx == i
+        assert batch.data == pa.table({"bar": [1] * 2})
+
+
 def test_make_async_gen_fail():
     """Tests that any errors raised in async threads are propagated to the main
     thread."""
diff --git a/python/ray/data/tests/test_iterator.py b/python/ray/data/tests/test_iterator.py
index 5e6056ba8dcf..9369772b5d44 100644
--- a/python/ray/data/tests/test_iterator.py
+++ b/python/ray/data/tests/test_iterator.py
@@ -1,5 +1,5 @@
 from typing import Dict
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 import numpy as np
 import pytest
@@ -141,10 +141,22 @@ def test_tf_conversion_pipeline(ray_start_regular_shared):

 def test_torch_conversion(ray_start_regular_shared):
     ds = ray.data.range(5)
     it = ds.iterator()
+    it.iter_batches = MagicMock()
+
     for batch in it.iter_torch_batches():
         assert isinstance(batch["id"], torch.Tensor)
         assert batch["id"].tolist() == list(range(5))

+    # When collate_fn is not specified, check that the default
+    # `_collate_fn` (handles formatting and Tensor creation)
+    # and `_finalize_fn` (handles host to device data transfer)
+    # are used in `DataIterator.iter_batches()`.
+    iter_batches_calls_kwargs = [a.kwargs for a in it.iter_batches.call_args_list]
+    assert all(
+        callable(kwargs["_collate_fn"]) and callable(kwargs["_finalize_fn"])
+        for kwargs in iter_batches_calls_kwargs
+    ), iter_batches_calls_kwargs
+

 def test_torch_conversion_pipeline(ray_start_regular_shared):
     ds = ray.data.range(5).repeat(2)
@@ -192,11 +204,20 @@ def collate_fn(batch: Dict[str, np.ndarray]):
         "ray.air._internal.torch_utils.get_device", lambda: torch.device("cuda")
     ):
         assert ray.air._internal.torch_utils.get_device().type == "cuda"
+
+        it.iter_batches = MagicMock()
         for batch in it.iter_torch_batches(collate_fn=collate_fn):
             assert batch.device.type == "cpu"
             assert isinstance(batch, torch.Tensor)
             assert batch.tolist() == list(range(5, 10))

+        # When collate_fn is specified, check that `_finalize_fn`
+        # is not used in `DataIterator.iter_batches()`.
+        iter_batches_calls_kwargs = [a.kwargs for a in it.iter_batches.call_args_list]
+        assert all(
+            kwargs["_finalize_fn"] is None for kwargs in iter_batches_calls_kwargs
+        ), iter_batches_calls_kwargs
+

 if __name__ == "__main__":
     import sys
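
As a usage illustration (a sketch, not part of the diff), the two code paths
described in the commit message can be exercised roughly as follows. The
dataset, batch size, and variable names here are arbitrary examples; only
`iter_torch_batches`, `collate_fn`, and the internal `_collate_fn` /
`_finalize_fn` split come from this change.

    import ray
    import torch

    ds = ray.data.range(8)
    it = ds.iterator()

    # Default path (no collate_fn): the built-in _collate_fn converts each
    # NumPy batch into CPU torch tensors, and the built-in _finalize_fn moves
    # them to the target device (when a non-CPU device is configured) in its
    # own single-thread pool, overlapping the host-to-device copy with
    # formatting and collation of later batches.
    for batch in it.iter_torch_batches(batch_size=4, prefetch_batches=2):
        assert isinstance(batch["id"], torch.Tensor)

    # Custom collate_fn: the default _finalize_fn is disabled, so keep the
    # collate function device-agnostic and move the returned tensors to the
    # target device yourself, outside of collate_fn.
    def to_tensor(batch):
        # batch is a dict of NumPy arrays, e.g. {"id": np.ndarray}.
        return torch.as_tensor(batch["id"])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    for batch in it.iter_torch_batches(batch_size=4, collate_fn=to_tensor):
        batch = batch.to(device)  # host-to-device transfer outside collate_fn

This mirrors what the updated test_iterator.py asserts: the default path
passes callables for both `_collate_fn` and `_finalize_fn` to
`DataIterator.iter_batches()`, while a user-supplied collate_fn leaves
`_finalize_fn` as None.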