Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data-type promotion to stack. #7091

Merged
merged 2 commits into from
May 23, 2024

Conversation

ysiraichi
Copy link
Collaborator

Fix: #7083

This PR adds data-type promotion to stack operation. Previously, there was none. So, the kernel implicitly expected the arguments to be of the same data-type. This might not be the case when using AMP.

cc @miladm @JackCaoG

@ysiraichi ysiraichi requested a review from JackCaoG May 22, 2024 00:42
@@ -3158,8 +3158,12 @@ at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self,

at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
at::ScalarType result_type = at::native::result_type(tensors);
std::vector<at::Tensor> c_tensors(tensors.size());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is stack expecting input tensor to be CPU? std::vector<at::Tensor> c_tensors will return a list of tenosrs on CPU right?

Copy link
Collaborator Author

@ysiraichi ysiraichi May 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. Unless I'm missing something, they are casted tensors, on XLA.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then I am abit confused. Reading your code, you init the c_tensors vector which I assume they will be cpu tensors since you didn;t provide the device type. In the later code you only update the dtype of these c_tensors, I don't know when are they moved to the XLA device.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a summary of what this code is doing: considering the arguments tensors (a list of XLA tensors) and dim, the function:

  1. Computes the common data-type of all tensors: result_type
  2. Converts each tensor to the common data-type, storing the result in c_tensors (as in "cast tensors")
  3. Calls tensor_methods::stack with the casted tensors

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. transform is called with tensors.begin()..

@ysiraichi ysiraichi force-pushed the ysiraichi/fix-stack-dtype-promotion branch from 57352fb to 5fbcdd9 Compare May 22, 2024 14:18
@ysiraichi ysiraichi requested a review from JackCaoG May 22, 2024 18:54
@ysiraichi ysiraichi merged commit a299f33 into master May 23, 2024
20 checks passed
qihqi pushed a commit that referenced this pull request May 29, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[torchbench] timm_efficientdet training failing on non-dynamo.
2 participants