From 3266fda7aa31f65ea0d503931bbddd5a659db3a0 Mon Sep 17 00:00:00 2001
From: Amog Kamsetty
Date: Tue, 21 Feb 2023 10:19:24 -0800
Subject: [PATCH] [AIR/Data] Add `collate_fn` to `iter_torch_batches` (#32412)

Adds a collate_fn argument to iter_batches and iter_torch_batches. This is
useful for any last-mile preprocessing that is applied directly to the batch
before it is used for training.

Closes #32224.

Signed-off-by: amogkam
Signed-off-by: Balaji Veeramani
Co-authored-by: Balaji Veeramani
Signed-off-by: Edward Oakes
---
 python/ray/air/_internal/torch_utils.py      | 12 +++++
 python/ray/data/_internal/block_batching.py  | 14 +++++-
 .../data/_internal/bulk_dataset_iterator.py  |  7 ++-
 .../_internal/pipelined_dataset_iterator.py  | 22 +++++-----
 python/ray/data/dataset.py                   | 36 ++++++++++++---
 python/ray/data/dataset_iterator.py          | 12 ++++-
 python/ray/data/dataset_pipeline.py          | 16 ++++++-
 .../ray/data/tests/test_dataset_iterator.py  | 44 +++++++++++++++++++
 8 files changed, 139 insertions(+), 24 deletions(-)

diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py
index 15be5b1d5d344..47e5600d458b4 100644
--- a/python/ray/air/_internal/torch_utils.py
+++ b/python/ray/air/_internal/torch_utils.py
@@ -128,11 +128,23 @@ def convert_ndarray_to_torch_tensor(
         ndarray: A NumPy ndarray that we wish to convert to a Torch Tensor.
         dtype: A Torch dtype for the created tensor; if None, the dtype will be
             inferred from the NumPy ndarray data.
+        device: The device on which the tensor(s) should be placed; if None, the Torch
+            tensor(s) will be constructed on the CPU.
 
     Returns: A Torch Tensor.
     """
     ndarray = _unwrap_ndarray_object_type_if_needed(ndarray)
+
+    # Object dtype cannot be converted into PyTorch Tensor.
+    if ndarray.dtype.type is np.object_:
+        raise RuntimeError(
+            "Numpy array of object dtype cannot be converted to a Torch Tensor. This "
+            "may be because the numpy array is a ragged tensor: it contains items of "
+            "different sizes. If using the `iter_torch_batches()` API, you can pass "
+            "in a `collate_fn` argument to specify custom logic to convert the Numpy "
+            "array batch to a Torch tensor batch."
+        )
 
     # The numpy array is not always writeable as it can come from the Ray object store.
     # Numpy will throw a verbose warning here, which we suppress, as we don't write
     # to the tensors. We also don't want to copy the array to avoid memory overhead.
diff --git a/python/ray/data/_internal/block_batching.py b/python/ray/data/_internal/block_batching.py
index 23fafc160ef40..febee5a06b072 100644
--- a/python/ray/data/_internal/block_batching.py
+++ b/python/ray/data/_internal/block_batching.py
@@ -3,7 +3,7 @@
 import queue
 import sys
 import threading
-from typing import Iterator, Optional, TypeVar, Union
+from typing import Any, Callable, Iterator, Optional, TypeVar, Union
 
 import ray
 from ray.actor import ActorHandle
@@ -39,6 +39,7 @@ def batch_block_refs(
     batch_size: Optional[int] = None,
     batch_format: str = "default",
     drop_last: bool = False,
+    collate_fn: Optional[Callable[[DataBatch], Any]] = None,
     shuffle_buffer_min_size: Optional[int] = None,
     shuffle_seed: Optional[int] = None,
     ensure_copy: bool = False,
@@ -68,6 +69,7 @@ def batch_block_refs(
             select ``pandas.DataFrame`` or "pyarrow" to select ``pyarrow.Table``.
             Default is "default".
         drop_last: Whether to drop the last batch if it's incomplete.
+        collate_fn: A function to apply to each data batch before returning it.
         shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using
             a local in-memory shuffle buffer, and this value will serve as the minimum
             number of rows that must be in the local in-memory shuffle buffer in order
@@ -114,6 +116,7 @@ def batch_block_refs(
         batch_size=batch_size,
         batch_format=batch_format,
         drop_last=drop_last,
+        collate_fn=collate_fn,
         shuffle_buffer_min_size=shuffle_buffer_min_size,
         shuffle_seed=shuffle_seed,
         ensure_copy=ensure_copy,
@@ -128,6 +131,7 @@ def batch_blocks(
     batch_size: Optional[int] = None,
     batch_format: str = "default",
     drop_last: bool = False,
+    collate_fn: Optional[Callable[[DataBatch], DataBatch]] = None,
    shuffle_buffer_min_size: Optional[int] = None,
     shuffle_seed: Optional[int] = None,
     ensure_copy: bool = False,
@@ -154,6 +158,14 @@ def batch_blocks(
         stats=stats,
     )
 
+    if collate_fn is not None:
+
+        def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]:
+            for batch in iterator:
+                yield collate_fn(batch)
+
+        batch_iter = batch_fn_iter(batch_iter)
+
     if prefetch_batches > 0:
         batch_iter = _make_async_gen(batch_iter, prefetch_buffer_size=prefetch_batches)
 
diff --git a/python/ray/data/_internal/bulk_dataset_iterator.py b/python/ray/data/_internal/bulk_dataset_iterator.py
index 487f125fcd760..abbe85ff2dd9c 100644
--- a/python/ray/data/_internal/bulk_dataset_iterator.py
+++ b/python/ray/data/_internal/bulk_dataset_iterator.py
@@ -1,4 +1,5 @@
-from typing import TYPE_CHECKING, Dict, Optional, Union, Iterator
+import numpy as np
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, Iterator
 
 from ray.data.block import DataBatch
 from ray.data.dataset_iterator import DatasetIterator
@@ -49,6 +50,9 @@ def iter_torch_batches(
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
         device: Optional[str] = None,
         drop_last: bool = False,
+        collate_fn: Optional[
+            Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
+        ] = None,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
     ) -> Iterator["TorchTensorBatchType"]:
@@ -57,6 +61,7 @@ def iter_torch_batches(
             batch_size=batch_size,
             dtypes=dtypes,
             device=device,
+            collate_fn=collate_fn,
             drop_last=drop_last,
             local_shuffle_buffer_size=local_shuffle_buffer_size,
             local_shuffle_seed=local_shuffle_seed,
diff --git a/python/ray/data/_internal/pipelined_dataset_iterator.py b/python/ray/data/_internal/pipelined_dataset_iterator.py
index 32d319f14f5f9..aab9d72f2486a 100644
--- a/python/ray/data/_internal/pipelined_dataset_iterator.py
+++ b/python/ray/data/_internal/pipelined_dataset_iterator.py
@@ -1,4 +1,5 @@
-from typing import TYPE_CHECKING, Dict, Optional, Union, Iterator
+import numpy as np
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, Iterator
 
 from ray.data.block import DataBatch
 from ray.data.dataset_iterator import DatasetIterator
@@ -56,27 +57,24 @@ def iter_torch_batches(
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
         device: Optional[str] = None,
         drop_last: bool = False,
+        collate_fn: Optional[
+            Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
+        ] = None,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
     ) -> Iterator["TorchTensorBatchType"]:
-        from ray.air._internal.torch_utils import (
-            convert_ndarray_batch_to_torch_tensor_batch,
-        )
         ds = self._get_next_dataset()
-        for batch in ds.iter_batches(
+        return ds.iter_torch_batches(
             prefetch_blocks=prefetch_blocks,
             batch_size=batch_size,
-            batch_format="numpy",
+            dtypes=dtypes,
+            device=device,
             drop_last=drop_last,
+            collate_fn=collate_fn,
             local_shuffle_buffer_size=local_shuffle_buffer_size,
             local_shuffle_seed=local_shuffle_seed,
-        ):
-            yield convert_ndarray_batch_to_torch_tensor_batch(
-                batch,
-                dtypes=dtypes,
-                device=device,
-            )
+        )
 
     def stats(self) -> str:
         return self._base_dataset_pipeline.stats()
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index a18d028d8f932..69a2e6a0d6e81 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2839,6 +2839,7 @@ def iter_batches(
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
+        _collate_fn: Optional[Callable[[DataBatch], Any]] = None,
     ) -> Iterator[DataBatch]:
         """Return a local batched iterator over the dataset.
 
@@ -2890,6 +2891,7 @@ def iter_batches(
             batch_size=batch_size,
             batch_format=batch_format,
             drop_last=drop_last,
+            collate_fn=_collate_fn,
             shuffle_buffer_min_size=local_shuffle_buffer_size,
             shuffle_seed=local_shuffle_seed,
         )
@@ -2904,6 +2906,9 @@ def iter_torch_batches(
         batch_size: Optional[int] = 256,
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
         device: Optional[str] = None,
+        collate_fn: Optional[
+            Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
+        ] = None,
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
@@ -2939,6 +2944,13 @@ def iter_torch_batches(
                 will be inferred from the tensor data.
             device: The device on which the tensor should be placed; if None, the Torch
                 tensor will be constructed on the CPU.
+            collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
+                Potential use cases include collating along a dimension other than the
+                first, padding sequences of various lengths, or generally handling
+                batches of tensors with different lengths. If not provided, the
+                default collate function is used, which simply converts the batch of
+                Numpy arrays to a batch of PyTorch tensors. This API is still
+                experimental and is subject to change.
             drop_last: Whether to drop the last batch if it's incomplete.
             local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
                 using a local in-memory shuffle buffer, and this value will serve as the
@@ -2957,19 +2969,29 @@ def iter_torch_batches(
             convert_ndarray_batch_to_torch_tensor_batch,
         )
 
-        for batch in self.iter_batches(
+        if collate_fn is not None and (dtypes is not None or device is not None):
+            raise ValueError(
+                "collate_fn cannot be used together with dtypes and device. It is "
+                "expected that the provided `collate_fn` will move the output Torch "
+                "tensors to the appropriate dtype and device."
+            )
+
+        if collate_fn is None:
+
+            def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]):
+                return convert_ndarray_batch_to_torch_tensor_batch(
+                    batch, dtypes=dtypes, device=device
+                )
+
+        yield from self.iter_batches(
             prefetch_blocks=prefetch_blocks,
             batch_size=batch_size,
             batch_format="numpy",
             drop_last=drop_last,
             local_shuffle_buffer_size=local_shuffle_buffer_size,
             local_shuffle_seed=local_shuffle_seed,
-        ):
-            yield convert_ndarray_batch_to_torch_tensor_batch(
-                batch,
-                dtypes=dtypes,
-                device=device,
-            )
+            _collate_fn=collate_fn,
+        )
 
     @ConsumptionAPI
     def iter_tf_batches(
diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py
index bc26de840c7ed..ec01044d3a608 100644
--- a/python/ray/data/dataset_iterator.py
+++ b/python/ray/data/dataset_iterator.py
@@ -1,7 +1,7 @@
 import abc
 import numpy as np
 import sys
-from typing import TYPE_CHECKING, Dict, List, Optional, Union, Iterator
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, Iterator
 
 from ray.air.util.data_batch_conversion import BlockFormat
 from ray.data.block import DataBatch
@@ -111,6 +111,9 @@ def iter_torch_batches(
         batch_size: Optional[int] = 256,
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
         device: Optional[str] = None,
+        collate_fn: Optional[
+            Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
+        ] = None,
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
@@ -145,6 +148,13 @@ def iter_torch_batches(
                 will be inferred from the tensor data.
             device: The device on which the tensor should be placed; if None, the Torch
                 tensor will be constructed on the CPU.
+            collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
+                Potential use cases include collating along a dimension other than the
+                first, padding sequences of various lengths, or generally handling
+                batches of tensors with different lengths. If not provided, the
+                default collate function is used, which simply converts the batch of
+                Numpy arrays to a batch of PyTorch tensors. This API is still
+                experimental and is subject to change.
             drop_last: Whether to drop the last batch if it's incomplete.
             local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
                 using a local in-memory shuffle buffer, and this value will serve as the
diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index d67c5f36540d9..515c88ac983e1 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -16,6 +16,8 @@
 )
 import warnings
 
+import numpy as np
+
 import ray
 from ray.air.util.data_batch_conversion import BlockFormat
 from ray.data._internal import progress_bar
@@ -54,6 +56,7 @@
     import pyarrow
     import tensorflow as tf
     import torch
+    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
 
 logger = logging.getLogger(__name__)
 
@@ -171,6 +174,7 @@ def iter_batches(
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
+        _collate_fn: Optional[Callable[[DataBatch], Any]] = None,
     ) -> Iterator[DataBatch]:
         """Return a local batched iterator over the data in the pipeline.
 
@@ -231,6 +235,7 @@ def iter_batches(
             batch_size=batch_size,
             batch_format=batch_format,
             drop_last=drop_last,
+            collate_fn=_collate_fn,
             shuffle_buffer_min_size=local_shuffle_buffer_size,
             shuffle_seed=local_shuffle_seed,
         )
@@ -1079,11 +1084,15 @@ def iter_torch_batches(
         *,
         prefetch_blocks: int = 0,
         batch_size: Optional[int] = 256,
-        batch_format: str = "default",
+        dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
+        device: Optional[str] = None,
+        collate_fn: Optional[
+            Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
+        ] = None,
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
-    ) -> Iterator[Union["torch.Tensor", Dict[str, "torch.Tensor"]]]:
+    ) -> Iterator["TorchTensorBatchType"]:
         """Call
         :py:meth:`Dataset.iter_torch_batches <ray.data.Dataset.iter_torch_batches>`
         over the stream of output batches from the pipeline."""
@@ -1091,6 +1100,9 @@ def iter_torch_batches(
             self,
             prefetch_blocks=prefetch_blocks,
             batch_size=batch_size,
+            dtypes=dtypes,
+            device=device,
+            collate_fn=collate_fn,
             drop_last=drop_last,
             local_shuffle_buffer_size=local_shuffle_buffer_size,
             local_shuffle_seed=local_shuffle_seed,
diff --git a/python/ray/data/tests/test_dataset_iterator.py b/python/ray/data/tests/test_dataset_iterator.py
index 6dc09604c3d99..0e9f55da7d014 100644
--- a/python/ray/data/tests/test_dataset_iterator.py
+++ b/python/ray/data/tests/test_dataset_iterator.py
@@ -1,5 +1,7 @@
 import pytest
+from typing import Dict
 
+import numpy as np
 import tensorflow as tf
 import torch
 
@@ -112,6 +114,48 @@ def test_torch_conversion(ray_start_regular_shared):
         assert batch["value"].tolist() == list(range(5))
 
 
+def test_torch_conversion_pipeline(ray_start_regular_shared):
+    ds = ray.data.range_table(5).repeat(2)
+    it = ds.iterator()
+
+    # First epoch.
+    for batch in it.iter_torch_batches():
+        assert isinstance(batch["value"], torch.Tensor)
+        assert batch["value"].tolist() == list(range(5))
+
+    # Second epoch.
+    for batch in it.iter_torch_batches():
+        assert isinstance(batch["value"], torch.Tensor)
+        assert batch["value"].tolist() == list(range(5))
+
+    # Fails on the third iteration: the pipeline only has two epochs.
+    with pytest.raises(StopIteration):
+        for batch in it.iter_torch_batches():
+            pass
+
+
+def test_torch_conversion_collate_fn(ray_start_regular_shared):
+    def collate_fn(batch: Dict[str, np.ndarray]):
+        return torch.as_tensor(batch["value"] + 5)
+
+    ds = ray.data.range_table(5)
+    it = ds.iterator()
+    for batch in it.iter_torch_batches(collate_fn=collate_fn):
+        assert isinstance(batch, torch.Tensor)
+        assert batch.tolist() == list(range(5, 10))
+
+    # Should fail: `collate_fn` cannot be combined with `dtypes` or `device`.
+    with pytest.raises(ValueError):
+        for batch in it.iter_torch_batches(collate_fn=collate_fn, dtypes=torch.float32):
+            assert isinstance(batch, torch.Tensor)
+            assert batch.tolist() == list(range(5, 10))
+
+    with pytest.raises(ValueError):
+        for batch in it.iter_torch_batches(collate_fn=collate_fn, device="cpu"):
+            assert isinstance(batch, torch.Tensor)
+            assert batch.tolist() == list(range(5, 10))
+
+
 if __name__ == "__main__":
     import sys
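
Example usage of the new `collate_fn` argument (an illustrative sketch, not
part of the patch; it mirrors the tests above, where `range_table` produces a
single "value" column):

    from typing import Dict

    import numpy as np
    import torch

    import ray

    ds = ray.data.range_table(8)

    def collate_fn(batch: Dict[str, np.ndarray]) -> torch.Tensor:
        # `collate_fn` cannot be combined with `dtypes` or `device`, so any
        # dtype casting and device placement happens inside the function.
        return torch.as_tensor(batch["value"], dtype=torch.float32)

    # Two batches of 4 rows each; each batch is already a torch.Tensor.
    for batch in ds.iter_torch_batches(batch_size=4, collate_fn=collate_fn):
        assert batch.dtype == torch.float32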