[AIR/Data] Add collate_fn to iter_torch_batches (ray-project#32412)
Adds a collate_fn argument to iter_batches and iter_torch_batches. This is useful for any last-mile preprocessing that's done directly on the batch to be used for training.

Closes ray-project#32224.
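
For illustration, a minimal sketch of the new argument in use (the dataset and the transform here are illustrative, not part of this commit):

import torch

import ray

ds = ray.data.range_table(8)  # a table with a single "value" column

def collate_fn(batch):
    # Batches arrive in Numpy format; apply last-mile preprocessing here
    # and return the Torch tensor(s) to train on.
    return torch.as_tensor(batch["value"] * 2)

for batch in ds.iter_torch_batches(batch_size=4, collate_fn=collate_fn):
    assert isinstance(batch, torch.Tensor)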

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
Co-authored-by: Balaji Veeramani <balaji@anyscale.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
2 people authored and edoakes committed Mar 22, 2023
1 parent 7c9f3c6 commit 3266fda
Showing 8 changed files with 139 additions and 24 deletions.
12 changes: 12 additions & 0 deletions python/ray/air/_internal/torch_utils.py
@@ -128,11 +128,23 @@ def convert_ndarray_to_torch_tensor(
ndarray: A NumPy ndarray that we wish to convert to a Torch Tensor.
dtype: A Torch dtype for the created tensor; if None, the dtype will be
inferred from the NumPy ndarray data.
device: The device on which the tensor(s) should be placed; if None, the Torch
tensor(s) will be constructed on the CPU.
Returns: A Torch Tensor.
"""
ndarray = _unwrap_ndarray_object_type_if_needed(ndarray)

    # Object dtype cannot be converted into PyTorch Tensor.
    if ndarray.dtype.type is np.object_:
        raise RuntimeError(
            "Numpy array of object dtype cannot be converted to a Torch "
            "Tensor. This may be because the numpy array is a ragged "
            "tensor: it contains items of different sizes. If using the "
            "`iter_torch_batches()` API, you can pass in a `collate_fn` "
            "argument to specify custom logic to convert the Numpy array "
            "batch to a Torch tensor batch."
        )

# The numpy array is not always writeable as it can come from the Ray object store.
# Numpy will throw a verbose warning here, which we suppress, as we don't write
# to the tensors. We also don't want to copy the array to avoid memory overhead.
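
A sketch of the ragged-tensor case this error message points to, and the collate_fn workaround it suggests (the variable-length "tokens" column is a hypothetical example, assumed to reach collate_fn as a ragged object array):

import torch
from torch.nn.utils.rnn import pad_sequence

import ray

# Rows with different-length sequences; the default Numpy-to-Torch
# conversion of such a batch would raise the RuntimeError above.
ds = ray.data.from_items([{"tokens": [1, 2]}, {"tokens": [3, 4, 5]}])

def collate_fn(batch):
    # Pad the ragged rows to a common length instead of failing.
    seqs = [torch.as_tensor(row) for row in batch["tokens"]]
    return pad_sequence(seqs, batch_first=True, padding_value=0)

for padded in ds.iter_torch_batches(batch_size=2, collate_fn=collate_fn):
    print(padded.shape)  # e.g. torch.Size([2, 3])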
14 changes: 13 additions & 1 deletion python/ray/data/_internal/block_batching.py
Expand Up @@ -3,7 +3,7 @@
import queue
import sys
import threading
-from typing import Iterator, Optional, TypeVar, Union
+from typing import Any, Callable, Iterator, Optional, TypeVar, Union

import ray
from ray.actor import ActorHandle
@@ -39,6 +39,7 @@ def batch_block_refs(
batch_size: Optional[int] = None,
batch_format: str = "default",
drop_last: bool = False,
collate_fn: Optional[Callable[[DataBatch], Any]] = None,
shuffle_buffer_min_size: Optional[int] = None,
shuffle_seed: Optional[int] = None,
ensure_copy: bool = False,
@@ -68,6 +69,7 @@ def batch_block_refs(
select ``pandas.DataFrame`` or "pyarrow" to select
``pyarrow.Table``. Default is "default".
drop_last: Whether to drop the last batch if it's incomplete.
collate_fn: A function to apply to each data batch before returning it.
shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a
local in-memory shuffle buffer, and this value will serve as the minimum
number of rows that must be in the local in-memory shuffle buffer in order
@@ -114,6 +116,7 @@ def batch_block_refs(
batch_size=batch_size,
batch_format=batch_format,
drop_last=drop_last,
collate_fn=collate_fn,
shuffle_buffer_min_size=shuffle_buffer_min_size,
shuffle_seed=shuffle_seed,
ensure_copy=ensure_copy,
@@ -128,6 +131,7 @@ def batch_blocks(
batch_size: Optional[int] = None,
batch_format: str = "default",
drop_last: bool = False,
collate_fn: Optional[Callable[[DataBatch], DataBatch]] = None,
shuffle_buffer_min_size: Optional[int] = None,
shuffle_seed: Optional[int] = None,
ensure_copy: bool = False,
@@ -154,6 +158,14 @@
stats=stats,
)

    if collate_fn is not None:

        def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]:
            for batch in iterator:
                yield collate_fn(batch)

        batch_iter = batch_fn_iter(batch_iter)

    if prefetch_batches > 0:
        batch_iter = _make_async_gen(batch_iter, prefetch_buffer_size=prefetch_batches)

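
The collate_fn wiring above is ordinary lazy iterator mapping; a self-contained sketch of the same pattern (all names here are illustrative):

from typing import Any, Callable, Iterator

def map_batches_lazily(
    batch_iter: Iterator[Any], fn: Callable[[Any], Any]
) -> Iterator[Any]:
    # Apply fn to each batch as it is pulled, preserving the streaming
    # behavior of the underlying iterator (nothing is materialized up front).
    for batch in batch_iter:
        yield fn(batch)

batches = iter([{"value": 1}, {"value": 2}])
doubled = map_batches_lazily(batches, lambda b: {"value": b["value"] * 2})
assert [b["value"] for b in doubled] == [2, 4]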
7 changes: 6 additions & 1 deletion python/ray/data/_internal/bulk_dataset_iterator.py
@@ -1,4 +1,5 @@
-from typing import TYPE_CHECKING, Dict, Optional, Union, Iterator
+import numpy as np
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, Iterator

from ray.data.block import DataBatch
from ray.data.dataset_iterator import DatasetIterator
@@ -49,6 +50,9 @@ def iter_torch_batches(
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
device: Optional[str] = None,
drop_last: bool = False,
collate_fn: Optional[
Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
] = None,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
) -> Iterator["TorchTensorBatchType"]:
@@ -57,6 +61,7 @@ def iter_torch_batches(
batch_size=batch_size,
dtypes=dtypes,
device=device,
collate_fn=collate_fn,
drop_last=drop_last,
local_shuffle_buffer_size=local_shuffle_buffer_size,
local_shuffle_seed=local_shuffle_seed,
22 changes: 10 additions & 12 deletions python/ray/data/_internal/pipelined_dataset_iterator.py
@@ -1,4 +1,5 @@
-from typing import TYPE_CHECKING, Dict, Optional, Union, Iterator
+import numpy as np
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, Iterator

from ray.data.block import DataBatch
from ray.data.dataset_iterator import DatasetIterator
@@ -56,27 +57,24 @@ def iter_torch_batches(
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
device: Optional[str] = None,
drop_last: bool = False,
collate_fn: Optional[
Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
] = None,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
) -> Iterator["TorchTensorBatchType"]:
-        from ray.air._internal.torch_utils import (
-            convert_ndarray_batch_to_torch_tensor_batch,
-        )
-
         ds = self._get_next_dataset()
-        for batch in ds.iter_batches(
+        return ds.iter_torch_batches(
             prefetch_blocks=prefetch_blocks,
             batch_size=batch_size,
-            batch_format="numpy",
+            dtypes=dtypes,
+            device=device,
             drop_last=drop_last,
+            collate_fn=collate_fn,
             local_shuffle_buffer_size=local_shuffle_buffer_size,
             local_shuffle_seed=local_shuffle_seed,
-        ):
-            yield convert_ndarray_batch_to_torch_tensor_batch(
-                batch,
-                dtypes=dtypes,
-                device=device,
-            )
+        )

def stats(self) -> str:
return self._base_dataset_pipeline.stats()
36 changes: 29 additions & 7 deletions python/ray/data/dataset.py
@@ -2839,6 +2839,7 @@ def iter_batches(
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
_collate_fn: Optional[Callable[[DataBatch], Any]] = None,
) -> Iterator[DataBatch]:
"""Return a local batched iterator over the dataset.
@@ -2890,6 +2891,7 @@ def iter_batches(
batch_size=batch_size,
batch_format=batch_format,
drop_last=drop_last,
collate_fn=_collate_fn,
shuffle_buffer_min_size=local_shuffle_buffer_size,
shuffle_seed=local_shuffle_seed,
)
@@ -2904,6 +2906,9 @@ def iter_torch_batches(
batch_size: Optional[int] = 256,
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
device: Optional[str] = None,
collate_fn: Optional[
Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
] = None,
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
@@ -2939,6 +2944,13 @@ def iter_torch_batches(
will be inferred from the tensor data.
device: The device on which the tensor should be placed; if None, the Torch
tensor will be constructed on the CPU.
collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
Potential use cases include collating along a dimension other than the
first, padding sequences of various lengths, or generally handling
batches of different-length tensors. If not provided, the default
collate function is used, which simply converts the batch of numpy
arrays to a batch of PyTorch tensors. This API is still experimental
and is subject to change.
drop_last: Whether to drop the last batch if it's incomplete.
local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
using a local in-memory shuffle buffer, and this value will serve as the
@@ -2957,19 +2969,29 @@ def iter_torch_batches(
             convert_ndarray_batch_to_torch_tensor_batch,
         )

-        for batch in self.iter_batches(
+        if collate_fn is not None and (dtypes is not None or device is not None):
+            raise ValueError(
+                "collate_fn cannot be used with dtypes and device. It is "
+                "expected that the provided `collate_fn` will move the output "
+                "Torch tensors to the appropriate dtype and device."
+            )
+
+        if collate_fn is None:
+
+            def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]):
+                return convert_ndarray_batch_to_torch_tensor_batch(
+                    batch, dtypes=dtypes, device=device
+                )
+
+        yield from self.iter_batches(
             prefetch_blocks=prefetch_blocks,
             batch_size=batch_size,
             batch_format="numpy",
             drop_last=drop_last,
             local_shuffle_buffer_size=local_shuffle_buffer_size,
             local_shuffle_seed=local_shuffle_seed,
-        ):
-            yield convert_ndarray_batch_to_torch_tensor_batch(
-                batch,
-                dtypes=dtypes,
-                device=device,
-            )
+            _collate_fn=collate_fn,
+        )

@ConsumptionAPI
def iter_tf_batches(
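
Since collate_fn is mutually exclusive with dtypes and device, a custom function is expected to handle dtype casting and device placement itself; a sketch (assuming CUDA may or may not be available):

import torch

import ray

device = "cuda" if torch.cuda.is_available() else "cpu"

def collate_fn(batch):
    # The collate_fn owns dtype and device placement; passing dtypes= or
    # device= alongside it raises the ValueError above.
    return torch.as_tensor(batch["value"], dtype=torch.float32).to(device)

ds = ray.data.range_table(4)
for batch in ds.iter_torch_batches(collate_fn=collate_fn):
    assert batch.dtype == torch.float32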
12 changes: 11 additions & 1 deletion python/ray/data/dataset_iterator.py
@@ -1,7 +1,7 @@
 import abc
+import numpy as np
 import sys
-from typing import TYPE_CHECKING, Dict, List, Optional, Union, Iterator
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, Iterator

from ray.air.util.data_batch_conversion import BlockFormat
from ray.data.block import DataBatch
@@ -111,6 +111,9 @@ def iter_torch_batches(
batch_size: Optional[int] = 256,
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
device: Optional[str] = None,
collate_fn: Optional[
Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
] = None,
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
@@ -145,6 +148,13 @@ def iter_torch_batches(
will be inferred from the tensor data.
device: The device on which the tensor should be placed; if None, the Torch
tensor will be constructed on the CPU.
collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
Potential use cases include collating along a dimension other than the
first, padding sequences of various lengths, or generally handling
batches of different-length tensors. If not provided, the default
collate function is used, which simply converts the batch of numpy
arrays to a batch of PyTorch tensors. This API is still experimental
and is subject to change.
drop_last: Whether to drop the last batch if it's incomplete.
local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
using a local in-memory shuffle buffer, and this value will serve as the
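
The same contract applies through a DatasetIterator; a sketch mirroring the tests added in this commit:

import torch

import ray

ds = ray.data.range_table(5)
it = ds.iterator()  # DatasetIterator over the dataset

for batch in it.iter_torch_batches(
    collate_fn=lambda b: torch.as_tensor(b["value"])
):
    assert batch.tolist() == list(range(5))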
16 changes: 14 additions & 2 deletions python/ray/data/dataset_pipeline.py
@@ -16,6 +16,8 @@
)
import warnings

import numpy as np

import ray
from ray.air.util.data_batch_conversion import BlockFormat
from ray.data._internal import progress_bar
@@ -54,6 +56,7 @@
import pyarrow
import tensorflow as tf
import torch
from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType


logger = logging.getLogger(__name__)
@@ -171,6 +174,7 @@ def iter_batches(
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
_collate_fn: Optional[Callable[[DataBatch], Any]] = None,
) -> Iterator[DataBatch]:
"""Return a local batched iterator over the data in the pipeline.
@@ -231,6 +235,7 @@ def iter_batches(
batch_size=batch_size,
batch_format=batch_format,
drop_last=drop_last,
collate_fn=_collate_fn,
shuffle_buffer_min_size=local_shuffle_buffer_size,
shuffle_seed=local_shuffle_seed,
)
@@ -1079,18 +1084,25 @@ def iter_torch_batches(
         *,
         prefetch_blocks: int = 0,
         batch_size: Optional[int] = 256,
-        batch_format: str = "default",
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
         device: Optional[str] = None,
+        collate_fn: Optional[
+            Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
+        ] = None,
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
-    ) -> Iterator[Union["torch.Tensor", Dict[str, "torch.Tensor"]]]:
+    ) -> Iterator["TorchTensorBatchType"]:
"""Call
:py:meth:`Dataset.iter_torch_batches <ray.data.Dataset.iter_torch_batches>`
over the stream of output batches from the pipeline."""
return Dataset.iter_torch_batches(
self,
prefetch_blocks=prefetch_blocks,
batch_size=batch_size,
dtypes=dtypes,
device=device,
collate_fn=collate_fn,
drop_last=drop_last,
local_shuffle_buffer_size=local_shuffle_buffer_size,
local_shuffle_seed=local_shuffle_seed,
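
Because the pipeline method delegates to Dataset.iter_torch_batches, the same argument works over a repeated or windowed pipeline; a sketch:

import torch

import ray

pipe = ray.data.range_table(5).repeat(2)  # a two-epoch DatasetPipeline

for batch in pipe.iter_torch_batches(
    collate_fn=lambda b: torch.as_tensor(b["value"])
):
    assert isinstance(batch, torch.Tensor)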
44 changes: 44 additions & 0 deletions python/ray/data/tests/test_dataset_iterator.py
@@ -1,5 +1,7 @@
import pytest
from typing import Dict

import numpy as np
import tensorflow as tf
import torch

@@ -112,6 +114,48 @@ def test_torch_conversion(ray_start_regular_shared):
assert batch["value"].tolist() == list(range(5))


def test_torch_conversion_pipeline(ray_start_regular_shared):
    ds = ray.data.range_table(5).repeat(2)
    it = ds.iterator()

    # First epoch.
    for batch in it.iter_torch_batches():
        assert isinstance(batch["value"], torch.Tensor)
        assert batch["value"].tolist() == list(range(5))

    # Second epoch.
    for batch in it.iter_torch_batches():
        assert isinstance(batch["value"], torch.Tensor)
        assert batch["value"].tolist() == list(range(5))

    # Fails on third iteration.
    with pytest.raises(StopIteration):
        for batch in it.iter_torch_batches():
            pass


def test_torch_conversion_collate_fn(ray_start_regular_shared):
    def collate_fn(batch: Dict[str, np.ndarray]):
        return torch.as_tensor(batch["value"] + 5)

    ds = ray.data.range_table(5)
    it = ds.iterator()
    for batch in it.iter_torch_batches(collate_fn=collate_fn):
        assert isinstance(batch, torch.Tensor)
        assert batch.tolist() == list(range(5, 10))

    # Should fail.
    with pytest.raises(ValueError):
        for batch in it.iter_torch_batches(collate_fn=collate_fn, dtypes=torch.float32):
            assert isinstance(batch, torch.Tensor)
            assert batch.tolist() == list(range(5, 10))

    with pytest.raises(ValueError):
        for batch in it.iter_torch_batches(collate_fn=collate_fn, device="cpu"):
            assert isinstance(batch, torch.Tensor)
            assert batch.tolist() == list(range(5, 10))


if __name__ == "__main__":
    import sys

