-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AIR] Automatically move DatasetIterator
torch tensors to correct device
#31753
Conversation
DatasetIterator
torch tensors to correct device
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
try: | ||
device = get_device() | ||
except SessionMisuseError: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this error is raised when a Train session is never initialized, right? Are we ignoring the error here in case this iterator is used outside of a Ray Train context, making this ~a no-op?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep thats right!
…evice (ray-project#31753) When DatasetIterator is used with Ray Train, automatically move the torch tensors returned by iter_torch_batches to the correct device. Signed-off-by: amogkam <amogkamsetty@yahoo.com> Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
When
DatasetIterator
is used with Ray Train, automatically move the torch tensors returned byiter_torch_batches
to the correct device.Why are these changes needed?
Related issue number
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.