[AIR] Automatically move DatasetIterator torch tensors to correct device (ray-project#31753)

When DatasetIterator is used with Ray Train, automatically move the torch tensors returned by iter_torch_batches to the correct device.

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
amogkam authored and edoakes committed Mar 22, 2023
1 parent b17183f commit 03f48c3
Showing 10 changed files with 119 additions and 68 deletions.
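In practice, the documented training loops no longer need to pass `device=` or move tensors by hand. A minimal sketch of the new behavior, condensed from the updated doc_code examples below — the toy dataset, the stand-in `torch.nn.Linear` model, and the `use_gpu=True` setting are illustrative assumptions, not code from the repo:

```python
import ray
import torch
from ray import train
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker():
    # prepare_model moves the stand-in model to this worker's device.
    model = train.torch.prepare_model(torch.nn.Linear(1, 1))
    dataset_shard = session.get_dataset_shard("train")
    for _ in range(2):  # a couple of epochs, purely for illustration
        # No `device=` argument and no manual .to(device) calls: batches are
        # returned already on the device Ray Train assigned to this worker.
        for batch in dataset_shard.iter_torch_batches(batch_size=8, dtypes=torch.float):
            inputs, labels = torch.unsqueeze(batch["x"], 1), batch["y"]
            loss = torch.nn.functional.mse_loss(model(inputs).squeeze(1), labels)


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
    datasets={"train": ray.data.from_items([{"x": float(i), "y": 2.0 * i} for i in range(32)])},
)
trainer.fit()
```

Compare with the removed lines below, where every example had to pass `device=train.torch.get_device()` and call `.to(device)` on each batch.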
4 changes: 1 addition & 3 deletions doc/source/ray-air/doc_code/hvd_trainer.py
@@ -47,11 +47,9 @@ def train_loop_per_worker():
    for epoch in range(num_epochs):
        model.train()
        for batch in dataset_shard.iter_torch_batches(
-            batch_size=32, dtypes=torch.float, device=train.torch.get_device()
+            batch_size=32, dtypes=torch.float
        ):
            inputs, labels = torch.unsqueeze(batch["x"], 1), batch["y"]
-            inputs.to(device)
-            labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
2 changes: 1 addition & 1 deletion doc/source/ray-air/doc_code/torch_trainer.py
@@ -39,7 +39,7 @@ def train_loop_per_worker():

    for epoch in range(num_epochs):
        for batches in dataset_shard.iter_torch_batches(
-            batch_size=32, dtypes=torch.float, device=train.torch.get_device()
+            batch_size=32, dtypes=torch.float
        ):
            inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
            output = model(inputs)
32 changes: 4 additions & 28 deletions python/ray/data/_internal/bulk_dataset_iterator.py
@@ -1,5 +1,4 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Iterator
-import warnings

from ray.data.block import DataBatch
from ray.data.dataset_iterator import DatasetIterator
@@ -9,6 +8,7 @@
    import torch
    from ray.data import Dataset
    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
+    from ray.train._internal.dataset_iterator import TrainDatasetIterator


class BulkDatasetIterator(DatasetIterator):
@@ -87,31 +87,7 @@ def to_tf(
    def stats(self) -> str:
        return self._base_dataset.stats()

-    def _with_backward_compat(self) -> DatasetIterator:
-        return BulkDatasetIteratorWithBackwardCompat(self)
+    def _to_train_iterator(self) -> "TrainDatasetIterator":
+        from ray.train._internal.dataset_iterator import TrainDatasetIterator

-
-class BulkDatasetIteratorWithBackwardCompat(BulkDatasetIterator):
-    def __init__(
-        self,
-        dataset_iterator: BulkDatasetIterator,
-    ):
-        self._dataset_iterator = dataset_iterator
-
-    def __getattr__(self, name):
-        if name == "_dataset_iterator":
-            raise AttributeError
-
-        if hasattr(self._dataset_iterator, name):
-            return getattr(self._dataset_iterator, name)
-
-        warnings.warn(
-            "session.get_dataset_shard returns a ray.data.DatasetIterator "
-            "instead of a Dataset as of Ray v2.3. "
-            "Use iter_torch_batches(), to_tf(), or iter_batches() to "
-            "iterate over one epoch. See "
-            "https://docs.ray.io/en/latest/data/api/dataset_iterator.html "
-            "for full DatasetIterator docs."
-        )
-
-        return getattr(self._dataset_iterator._base_dataset, name)
+        return TrainDatasetIterator(self)
32 changes: 4 additions & 28 deletions python/ray/data/_internal/pipelined_dataset_iterator.py
@@ -1,5 +1,4 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Iterator
-import warnings

from ray.data import Dataset
from ray.data.block import DataBatch
@@ -10,6 +9,7 @@
    import torch
    from ray.data import DatasetPipeline
    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
+    from ray.train._internal.dataset_iterator import TrainDatasetIterator


class PipelinedDatasetIterator(DatasetIterator):
@@ -105,31 +105,7 @@ def to_tf(
    def stats(self) -> str:
        return self._base_dataset_pipeline.stats()

-    def _with_backward_compat(self) -> DatasetIterator:
-        return PipelinedDatasetIteratorWithBackwardCompat(self)
+    def _to_train_iterator(self) -> "TrainDatasetIterator":
+        from ray.train._internal.dataset_iterator import TrainDatasetIterator

-
-class PipelinedDatasetIteratorWithBackwardCompat(PipelinedDatasetIterator):
-    def __init__(
-        self,
-        dataset_iterator: PipelinedDatasetIterator,
-    ):
-        self._dataset_iterator = dataset_iterator
-
-    def __getattr__(self, name):
-        if name == "_dataset_iterator":
-            raise AttributeError
-
-        if hasattr(self._dataset_iterator, name):
-            return getattr(self._dataset_iterator, name)
-
-        warnings.warn(
-            "session.get_dataset_shard returns a ray.data.DatasetIterator "
-            "instead of a DatasetPipeline as of Ray v2.3. "
-            "Use iter_torch_batches(), to_tf(), or iter_batches() to "
-            "iterate over one epoch. See "
-            "https://docs.ray.io/en/latest/data/api/dataset_iterator.html "
-            "for full DatasetIterator docs."
-        )
-
-        return getattr(self._dataset_iterator._base_dataset_pipeline, name)
+        return TrainDatasetIterator(self)
11 changes: 8 additions & 3 deletions python/ray/data/dataset_iterator.py
@@ -9,6 +9,7 @@
    import tensorflow as tf
    import torch
    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
+    from ray.train._internal.dataset_iterator import TrainDatasetIterator


if sys.version_info >= (3, 8):
@@ -247,9 +248,13 @@ def iter_epochs(self, max_epoch: int = -1) -> None:
"iter_torch_batches(), or to_tf()."
)

@abc.abstractmethod
def _with_backward_compat(self) -> "DatasetIterator":
def _to_train_iterator(self) -> "TrainDatasetIterator":
"""
Provide backwards compatibility for AIR users.
Convert this DatasetIterator to one that is specific
to Ray Train Trainers.
The Train-specific iterator has training specific logic,
for example, automatically moving batches to GPU when GPU training
is enabled.
"""
raise NotImplementedError
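The concrete iterators above satisfy this hook by wrapping themselves. A brief caller-side sketch of how a plain iterator is converted — this mirrors the dataset_spec.py change later in the commit, and the toy `range` dataset is only illustrative:

```python
import ray

ds = ray.data.range(32)

# Plain DatasetIterator, as returned by Dataset.iterator().
it = ds.iterator()

# Train-specific wrapper: same iteration API, but iter_torch_batches() will
# also place tensors on the worker's device when called inside a Train session.
train_it = it._to_train_iterator()
```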
67 changes: 67 additions & 0 deletions python/ray/train/_internal/dataset_iterator.py
@@ -0,0 +1,67 @@
from typing import Iterator, Optional, TYPE_CHECKING
import warnings

from ray.data.block import DataBatch
from ray.data.dataset_iterator import DatasetIterator
from ray.train.error import SessionMisuseError

if TYPE_CHECKING:
    import tensorflow as tf
    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType


class TrainDatasetIterator(DatasetIterator):
    """A DatasetIterator with Ray Train specific logic.

    Args:
        dataset_iterator: The base dataset iterator.
    """

    def __init__(
        self,
        dataset_iterator: DatasetIterator,
    ):
        self._dataset_iterator = dataset_iterator

    def iter_batches(self, *args, **kwargs) -> Iterator["DataBatch"]:
        return self._dataset_iterator.iter_batches(*args, **kwargs)

    def iter_torch_batches(
        self, *, device: Optional[str] = None, **kwargs
    ) -> Iterator["TorchTensorBatchType"]:

        # Automatically move torch tensors to the appropriate device.
        if device is None:
            from ray.train.torch import get_device

            try:
                device = get_device()
            except SessionMisuseError:
                pass

        return self._dataset_iterator.iter_torch_batches(device=device, **kwargs)

    def to_tf(self, *args, **kwargs) -> "tf.data.Dataset":
        return self._dataset_iterator.to_tf(*args, **kwargs)

    def stats(self) -> str:
        return self._dataset_iterator.stats()

    def __getattr__(self, name):
        if name == "_dataset_iterator":
            raise AttributeError

        if hasattr(self._dataset_iterator, name):
            return getattr(self._dataset_iterator, name)

        # Warning for backwards compatibility.
        warnings.warn(
            "session.get_dataset_shard returns a ray.data.DatasetIterator "
            "instead of a Dataset/DatasetPipeline as of Ray v2.3. "
            "Use iter_torch_batches(), to_tf(), or iter_batches() to "
            "iterate over one epoch. See "
            "https://docs.ray.io/en/latest/data/api/dataset_iterator.html "
            "for full DatasetIterator docs."
        )

        return getattr(self._dataset_iterator._base_dataset, name)
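Two details of the wrapper are easy to miss: when there is no active Train session, get_device() raises SessionMisuseError and `device` simply stays None; and attribute lookups that neither the wrapper nor the wrapped iterator define fall through to the legacy Dataset, with the deprecation warning above. A rough usage sketch outside of a Trainer (assuming a running local Ray cluster; the toy dataset is illustrative):

```python
import ray

ds = ray.data.range(8)

# Wrap a plain iterator the same way get_dataset_shards() does in dataset_spec.py.
train_it = ds.iterator()._to_train_iterator()

# Preferred DatasetIterator API: one pass over the data.
for batch in train_it.iter_batches(batch_size=4):
    print(batch)

# Legacy Dataset attribute: not defined on either iterator class, so
# __getattr__ emits the deprecation warning and delegates to _base_dataset.
print(train_it.count())  # -> 8
```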
2 changes: 1 addition & 1 deletion python/ray/train/_internal/dataset_spec.py
@@ -232,7 +232,7 @@ def get_dataset_shards(
                dataset_splits = [dataset] * len(training_worker_handles)

            for i, dataset_split in enumerate(dataset_splits):
-                dataset_splits[i] = dataset_split.iterator()._with_backward_compat()
+                dataset_splits[i] = dataset_split.iterator()._to_train_iterator()

            for i in range(len(dataset_splits)):
                dataset_dict_splits[i][key] = dataset_splits[i]
4 changes: 1 addition & 3 deletions python/ray/train/horovod/horovod_trainer.py
@@ -127,11 +127,9 @@ def train_loop_per_worker():
    for epoch in range(num_epochs):
        model.train()
        for batch in dataset_shard.iter_torch_batches(
-            batch_size=32, dtypes=torch.float, device=train.torch.get_device()
+            batch_size=32, dtypes=torch.float
        ):
            inputs, labels = torch.unsqueeze(batch["x"], 1), batch["y"]
-            inputs.to(device)
-            labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
31 changes: 31 additions & 0 deletions python/ray/train/tests/test_gpu.py
@@ -4,12 +4,14 @@

from unittest.mock import patch
import pytest
+import numpy as np
import torch
import torchvision
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler

import ray
+import ray.data
from ray.exceptions import RayTaskError
from ray.air import session
from ray import tune
@@ -335,6 +337,35 @@ def train_fn():
trainer.fit()


+@pytest.mark.parametrize("use_gpu", (True, False))
+def test_torch_iter_torch_batches_auto_device(ray_start_4_cpus_2_gpus, use_gpu):
+    """
+    Tests that iter_torch_batches in TorchTrainer worker function uses the
+    default device.
+    """
+
+    def train_fn():
+        dataset = session.get_dataset_shard("train")
+        for batch in dataset.iter_torch_batches(dtypes=torch.float, device="cpu"):
+            assert str(batch.device) == "cpu"
+
+        # Autodetect
+        for batch in dataset.iter_torch_batches(dtypes=torch.float):
+            assert str(batch.device) == str(train.torch.get_device())
+
+    dataset = ray.data.from_numpy(np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]).T)
+    # Test that this works outside a Train function
+    for batch in dataset.iter_torch_batches(dtypes=torch.float, device="cpu"):
+        assert str(batch.device) == "cpu"
+
+    trainer = TorchTrainer(
+        train_fn,
+        scaling_config=ScalingConfig(num_workers=2, use_gpu=use_gpu),
+        datasets={"train": dataset},
+    )
+    trainer.fit()
+
+
if __name__ == "__main__":
import sys

2 changes: 1 addition & 1 deletion python/ray/train/torch/torch_trainer.py
@@ -182,7 +182,7 @@ def train_loop_per_worker():
            # Iterate over epochs and batches
            for epoch in range(num_epochs):
                for batches in dataset_shard.iter_torch_batches(batch_size=32,
-                    dtypes=torch.float, device=train.torch.get_device()):
+                    dtypes=torch.float):
                    # Add batch or unsqueeze as an additional dimension [32, x]
                    inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
