Remove obsolete device special cases #6197

Merged: 11 commits merged into master on Jan 3, 2024

Conversation

will-cromar (Collaborator)

No description provided.

@will-cromar marked this pull request as ready for review on December 18, 2023 at 22:25
@will-cromar changed the title from "[WIP] Remove obsolete device special cases" to "Remove obsolete device special cases" on Dec 18, 2023
@will-cromar requested a review from JackCaoG on December 18, 2023 at 22:25
@JackCaoG (Collaborator) left a comment


Thanks! Can you add tests for U16, F64, etc. on TPU CI to make sure they are actually natively supported?

@will-cromar (Collaborator, Author)

> Thanks! Can you add tests for U16, F64, etc. on TPU CI to make sure they are actually natively supported?

I tested this manually, but I'll make sure there's a unit test for each one.

@will-cromar (Collaborator, Author)

> Thanks! Can you add tests for U16, F64, etc. on TPU CI to make sure they are actually natively supported?

Added a test that iterates through the dtypes that both torch and XLA support (a rough sketch of its shape follows the options below).

I found the list of dtypes supported on TPU internally; it does not include complex128, so using that dtype on TPU will trigger an error. The others are all supported now. We have two options for complex128:

  1. Change our behavior to throw an error for unsupported dtypes. IMO this is the safer behavior. Plus, SPMD obscures the actual device type, which makes it hard to determine what the "real" device is without accessing runtime state.
  2. Add downcasting back for complex128 only.
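
For reference, a minimal sketch of what such a parameterized round-trip test might look like (the real test lives in test/pjrt/test_dtypes.py per the traceback later in this thread, but its exact contents are not shown here, so the dtype list and structure below are assumptions):

    # Hypothetical sketch of the dtype round-trip test discussed above.
    from absl.testing import absltest, parameterized
    import torch
    import torch_xla.core.xla_model as xm


    class TestDtypes(parameterized.TestCase):

      @parameterized.parameters(torch.float16, torch.bfloat16, torch.float32,
                                torch.float64, torch.complex64)
      def test_float_round_trip(self, dtype):
        # Move a tensor to the XLA device and back; the values only survive
        # unchanged if the dtype is handled natively rather than silently
        # downcast.
        t = torch.randn(3, 3).to(dtype)
        xt = t.to(xm.xla_device())
        torch.testing.assert_close(xt.cpu(), t)


    if __name__ == '__main__':
      absltest.main()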

@JackCaoG (Collaborator) commented Jan 2, 2024

Yeah... I don't know how many people actually use complex128. Is it supported on GPU? As we are increasing our effort on GPU this year, I don't want to regress PyTorch/XLA:GPU.

@will-cromar (Collaborator, Author)

> Yeah... I don't know how many people actually use complex128. Is it supported on GPU? As we are increasing our effort on GPU this year, I don't want to regress PyTorch/XLA:GPU.

This PR won't affect behavior on GPU. I'm basically just removing the downcast when creating tensors on TPUs, and we never changed the dtype on GPU. If the underlying XLA:GPU client doesn't support complex128 (it probably does), then it would just throw an error.

@@ -163,8 +153,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
   if (UseBF16()) {
     return xla::PrimitiveType::BF16;
   }
-  if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) ||
-      hw_type == XlaDeviceType::NEURON) {
+  if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) {
Collaborator

@jeffhataws do you know what dtypes the Neuron device does not support?

@will-cromar in the long term, I think we might want to come up with a mechanism for each backend to register which dtypes they support and how they want to map PyTorch types to XLA types.

@will-cromar (Collaborator, Author)

> @will-cromar in the long term, I think we might want to come up with a mechanism for each backend to register which dtypes they support and how they want to map PyTorch types to XLA types.

Yeah, if we keep this downcasting behavior, I will consolidate it in the DevicePlugin API (see #6242)
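
For illustration only, a rough sketch of what such a per-backend hook could look like if it were layered on the DevicePlugin API from #6242; the `supported_dtypes` method, the plugin name, and the library path below are all hypothetical, not existing API:

    # Hypothetical sketch: `supported_dtypes` is not part of the existing
    # DevicePlugin API; it only illustrates the kind of registration hook
    # discussed above.
    import torch
    from torch_xla.experimental import plugins


    class MyAcceleratorPlugin(plugins.DevicePlugin):
      def library_path(self) -> str:
        # Path to this backend's PJRT plugin binary (placeholder path).
        return "/path/to/pjrt_plugin_my_accelerator.so"

      def supported_dtypes(self):
        # Hypothetical hook: dtypes this backend handles natively; a common
        # code path could downcast or reject anything else.
        return {torch.float32, torch.bfloat16, torch.int32, torch.int64}


    # Registration as in the plugin API (assumed usage).
    plugins.register_plugin("MY_ACCELERATOR", MyAcceleratorPlugin())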

Collaborator


Yeah, it is best to keep fp32 since it is currently supported, as described in https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-features/data-types.html.

@will-cromar (Collaborator, Author)

Hmm, both CPU and GPU fail the test I added with complex128:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/absl/testing/parameterized.py", line 323, in bound_param_test
    return test_method(self, testcase_params)
  File "/tmp/pytorch/xla/test/pjrt/test_dtypes.py", line 22, in test_float_round_trip
    torch.testing.assert_close(xt.cpu(), t)
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_comparison.py", line 1520, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 9 / 9 (100.0%)
Greatest absolute difference: 1.3942428785896892 at index (2, 1) (up to 1e-07 allowed)
Greatest relative difference: 0.9994087028336766 at index (1, 0) (up to 1e-07 allowed)

I'll just remove this test case for now.

@will-cromar merged commit 0cd6f10 into master on Jan 3, 2024
21 checks passed
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Jan 3, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024