forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AIR] Automatically move `DatasetIterator` torch tensors to correct device (ray-project#31753)

When DatasetIterator is used with Ray Train, automatically move the torch tensors returned by iter_torch_batches to the correct device.

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
- Loading branch information
Showing
10 changed files
with
119 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from typing import Iterator, Optional, TYPE_CHECKING | ||
import warnings | ||
|
||
from ray.data.block import DataBatch | ||
from ray.data.dataset_iterator import DatasetIterator | ||
from ray.train.error import SessionMisuseError | ||
|
||
if TYPE_CHECKING: | ||
import tensorflow as tf | ||
from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType | ||
|
||
|
||
class TrainDatasetIterator(DatasetIterator):
    """A DatasetIterator with Ray Train specific logic.

    Wraps a base ``DatasetIterator`` and delegates to it. The only behavior
    added on top of plain delegation is that ``iter_torch_batches`` fills in
    a default ``device`` from ``ray.train.torch.get_device`` when the caller
    does not pass one. Attributes that are not found on the wrapped iterator
    are looked up on its ``_base_dataset`` (with a warning) for backwards
    compatibility.

    Args:
        dataset_iterator: The base dataset iterator.
    """

    def __init__(
        self,
        dataset_iterator: DatasetIterator,
    ):
        self._dataset_iterator = dataset_iterator

    def iter_batches(self, *args, **kwargs) -> Iterator["DataBatch"]:
        """Forward to the wrapped iterator's ``iter_batches`` unchanged."""
        return self._dataset_iterator.iter_batches(*args, **kwargs)

    def iter_torch_batches(
        self, *, device: Optional[str] = None, **kwargs
    ) -> Iterator["TorchTensorBatchType"]:
        """Iterate over torch batches, defaulting ``device`` from Ray Train.

        Args:
            device: Device passed through to the wrapped iterator. When
                ``None``, the result of ``ray.train.torch.get_device()`` is
                used if that call succeeds; otherwise ``None`` is forwarded.
            **kwargs: Forwarded unchanged to the wrapped iterator's
                ``iter_torch_batches``.

        Returns:
            Whatever the wrapped iterator's ``iter_torch_batches`` returns.
        """
        # Automatically move torch tensors to the appropriate device.
        if device is None:
            # Imported inside the method so that importing this module does
            # not import ``ray.train.torch``.
            from ray.train.torch import get_device

            try:
                device = get_device()
            except SessionMisuseError:
                # Deliberate best-effort: presumably raised when no Train
                # session is active (TODO confirm). In that case keep
                # ``device=None`` and let the wrapped iterator decide.
                pass

        return self._dataset_iterator.iter_torch_batches(device=device, **kwargs)

    def to_tf(self, *args, **kwargs) -> "tf.data.Dataset":
        """Forward to the wrapped iterator's ``to_tf`` unchanged."""
        return self._dataset_iterator.to_tf(*args, **kwargs)

    def stats(self) -> str:
        """Return the wrapped iterator's ``stats()``."""
        return self._dataset_iterator.stats()

    def __getattr__(self, name):
        """Resolve attributes that normal lookup did not find on this wrapper.

        Lookup order:
          1. The wrapped dataset iterator.
          2. The wrapped iterator's ``_base_dataset``; a warning is emitted
             when the attribute is found there.

        Raises:
            AttributeError: If ``name`` cannot be resolved in either place,
                or if ``_dataset_iterator`` itself has not been set.
        """
        # ``__getattr__`` only runs after normal lookup has failed, so getting
        # here for ``_dataset_iterator`` means it was never assigned. Fail
        # immediately: the ``self._dataset_iterator`` access below would
        # otherwise re-enter this method without end.
        if name == "_dataset_iterator":
            raise AttributeError(name)

        wrapped = self._dataset_iterator

        # Single EAFP lookup instead of hasattr() followed by getattr().
        try:
            return getattr(wrapped, name)
        except AttributeError:
            pass

        # Resolve on the base dataset *before* warning, so that a name that
        # exists nowhere raises a plain AttributeError without a misleading
        # backwards-compatibility warning.
        value = getattr(wrapped._base_dataset, name)

        # Warning for backwards compatibility. ``stacklevel=2`` attributes it
        # to the code that accessed the attribute rather than to this module.
        warnings.warn(
            "session.get_dataset_shard returns a ray.data.DatasetIterator "
            "instead of a Dataset/DatasetPipeline as of Ray v2.3. "
            "Use iter_torch_batches(), to_tf(), or iter_batches() to "
            "iterate over one epoch. See "
            "https://docs.ray.io/en/latest/data/api/dataset_iterator.html "
            "for full DatasetIterator docs.",
            stacklevel=2,
        )

        return value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters