Add data-type promotion to stack #7091
Conversation
@@ -3158,8 +3158,12 @@ at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self,

at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  at::ScalarType result_type = at::native::result_type(tensors);
  std::vector<at::Tensor> c_tensors(tensors.size());
Is stack expecting the input tensors to be on CPU? std::vector<at::Tensor> c_tensors will return a list of tensors on CPU, right?
I don't think so. Unless I'm missing something, they are cast tensors, on XLA.
Then I am a bit confused. Reading your code, you initialize the c_tensors vector, which I assume will hold CPU tensors since you didn't provide a device type. The later code only updates the dtype of these c_tensors; I don't see when they are moved to the XLA device.
Here's a summary of what this code is doing: given the arguments tensors (a list of XLA tensors) and dim, the function:
- Computes the common data-type of all tensors: result_type
- Converts each tensor to the common data-type, storing the results in c_tensors (as in "cast tensors")
- Calls tensor_methods::stack with the cast tensors
Oh, I see. transform is called with tensors.begin()..
Force-pushed from 57352fb to 5fbcdd9
Fix: #7083

This PR adds data-type promotion to the stack operation. Previously, there was none, so the kernel implicitly expected all arguments to be of the same data-type. This might not be the case when using AMP.

cc @miladm @JackCaoG