Remove obsolete device special cases #6197
Conversation
Thanks! Can you add tests for U16, F64, etc. on TPU CI to make sure they are actually natively supported?
I tested this manually, but I'll make sure there's a unit test for each one.
Added a test that iterates through the dtypes that both torch and XLA support. I found the list of supported dtypes on TPU internally, and it does not include complex128, so using that dtype on TPU will trigger an error. The others are all supported now. We have two options for complex128.
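A minimal sketch of the kind of check being described, assuming the standard torch_xla API (the dtype list and test name here are illustrative, not the exact test added in this PR):

```python
import torch
import torch_xla.core.xla_model as xm

# Illustrative list of dtypes that both torch and XLA support;
# complex128 is excluded because TPU does not support it.
DTYPES = [
    torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32,
    torch.int64, torch.bfloat16, torch.float16, torch.float32,
    torch.float64, torch.complex64,
]

def test_dtypes_natively_supported():
  device = xm.xla_device()
  for dtype in DTYPES:
    t = torch.zeros(4, dtype=dtype, device=device)
    xm.mark_step()  # materialize; an unsupported dtype would error here
    assert t.cpu().dtype == dtype
```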
Yeah... I don't know how many people actually use complex128. Is it supported on GPU? As we are increasing our effort on GPU this year, I don't want to regress PyTorch/XLA:GPU.
This PR won't affect behavior on GPU. I'm basically just removing the downcast when creating tensors on TPUs; we never changed the dtype on GPU. If the underlying XLA:GPU client doesn't support complex128 (it probably does), it would just throw an error.
@@ -163,8 +153,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
   if (UseBF16()) {
     return xla::PrimitiveType::BF16;
   }
-  if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) ||
-      hw_type == XlaDeviceType::NEURON) {
+  if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) {
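To illustrate the effect of dropping `IsTpuDevice(hw_type)` from this condition, a rough sketch of the user-visible change on TPU (the `_get_xla_tensors_text` debug binding is used only to inspect the XLA-side type; the before/after behavior in the comments paraphrases this PR's change):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.ones(4, dtype=torch.float64, device=xm.xla_device())

# Before this change, TPU was special-cased in MaybeDowncastToXlaDeviceType:
# the device-side value became F32 even though t.dtype still reported
# torch.float64. With the special case removed, the value stays F64 on TPU;
# only the explicit downcast flags and the NEURON case still convert.
print(torch_xla._XLAC._get_xla_tensors_text([t]))  # IR should now show f64[4]
```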
@jeffhataws do you know which dtypes the Neuron device does not support?
@will-cromar in the long term, I think we might want to come up with a mechanism for each backend to register which dtypes it supports and how it wants to map PyTorch types to XLA types.
> @will-cromar in the long term, I think we might want to come up with a mechanism for each backend to register which dtypes it supports and how it wants to map PyTorch types to XLA types.
Yeah, if we keep this downcasting behavior, I will consolidate it in the DevicePlugin API (see #6242).
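Purely as a hypothetical sketch of what such a per-backend registration could look like on top of the DevicePlugin API proposed in #6242 (none of these class or method names exist today):

```python
import torch

class DevicePlugin:
  """Hypothetical dtype hooks for the plugin API proposed in #6242."""

  def dtype_overrides(self):
    # torch dtype -> torch dtype the backend wants to store instead.
    # Empty by default: no downcasting.
    return {}

  def unsupported_dtypes(self):
    # Dtypes the backend should reject with a clear error.
    return set()

class TpuPlugin(DevicePlugin):
  def unsupported_dtypes(self):
    return {torch.complex128}

class NeuronPlugin(DevicePlugin):
  def dtype_overrides(self):
    # Hypothetical mapping for double-precision types the hardware lacks.
    return {torch.float64: torch.float32, torch.complex128: torch.complex64}
```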
Yeah, it is best to keep fp32 since it is currently supported, as described in https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-features/data-types.html.
Hmm, both CPU and GPU fail the test I added. I'll just remove this test case for now.