-
Notifications
You must be signed in to change notification settings - Fork 486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove obsolete device special cases #6197
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
4513eb3
Remove obsolete device special cases
will-cromar c438e3a
remove comments
will-cromar 825686b
remove device from ConvertTo
will-cromar 95cf14c
remove device from `ConvertToRaw`
will-cromar 76d825f
formatting
will-cromar 6720594
Add test_dtypes
will-cromar ac3d63e
Merge branch 'master' of github.com:pytorch/xla into wcromar/remove-o…
will-cromar 92e9a2c
formatting
will-cromar 5421d64
remove complex128 test
will-cromar bb04b65
Generalize `test_dtypes`
will-cromar 37fff0f
remove complex128 test again
will-cromar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from absl.testing import absltest, parameterized | ||
import torch | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.runtime as xr | ||
|
||
|
||
class TestDtypes(parameterized.TestCase): | ||
|
||
@parameterized.parameters(torch.float16, torch.float32, torch.float64, | ||
torch.bfloat16, torch.complex64) | ||
def test_float_round_trip(self, dtype: torch.dtype): | ||
t = torch.randn((3, 3), dtype=dtype) | ||
xt = t.to(xm.xla_device()) | ||
torch.testing.assert_close(xt.cpu(), t) | ||
|
||
@parameterized.parameters( | ||
torch.uint8, | ||
torch.int8, | ||
torch.int16, | ||
torch.int32, | ||
torch.int64, | ||
) | ||
def test_int_round_trip(self, dtype: torch.dtype): | ||
t = torch.randint(0, 128, (3, 3), dtype=dtype) | ||
xt = t.to(xm.xla_device()) | ||
torch.testing.assert_close(xt.cpu(), t) | ||
|
||
def test_bool_round_trip(self): | ||
t = torch.randint(0, 2, (3, 3), dtype=torch.bool) | ||
xt = t.to(xm.xla_device()) | ||
torch.testing.assert_close(xt.cpu(), t) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jeffhataws do you know what dtype that nuron device does not support?
@will-cromar in the long term I think we might want to come up with a mechanism for each backend to register what kind of dtypes they support and how do they want to map pytorch type to xla type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, if we keep this downcasting behavior, I will consolidate it in the
DevicePlugin
API (see #6242)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah it is best to keep fp32 since it is currently supported as in https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-features/data-types.html.