-
Notifications
You must be signed in to change notification settings - Fork 486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for _unsafe_index
.
#5707
Conversation
@@ -626,6 +626,13 @@ at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, | |||
return view_copy_symint(self, c10::fromIntArrayRefSlow(size)); | |||
} | |||
|
|||
at::Tensor XLANativeFunctions::_unsafe_index( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we know when will this method being used(how does it get dispatched to this) and create a test for it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to adding the tests. For more info, the cpp unit tests can verify that this new op you lowered is properly invoked via metric assertions as such https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor_1.cpp#L35.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get it as a node in the FX graph generated from AOTAutograd (using openxla
backend). I will investigate it further and add tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@peterbell10 added this function. These functions are an artefact to signal the compiler to avoid performing any boundary checks on the indices that they use. These are used in decompositions where we know that the indices will be in the right range, like the backward of functions, or a decomposition where we know that the inputs are bounded to the right range.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that you also have its cousin _unsafe_index_put
Thanks for adding the cpp tests! For some reason, I can see that the CI fails at build stage but I can't seem to see any logs. I've just re-triggered the CI, let's see if this run is good. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
This sounds like a bug, either the fallback creates a CPU tensor with XLA dispatch keys, or the falback mechanism doesn't handle lists of tensors and some of the arguments were not moved to CPU. |
* Add support for `_unsafe_index`. * Fix lint issues. * Add tests.
* Add support for `_unsafe_index`. * Fix lint issues. * Add tests.
* Add support for `_unsafe_index`. * Fix lint issues. * Add tests.
* Add support for `_unsafe_index`. * Fix lint issues. * Add tests.
* Add support for `_unsafe_index`. * Fix lint issues. * Add tests.
This PR adds support for
_unsafe_index
by aliasingindex
.Problem:
_unsafe_index
was hitting the CPU fallback execution flow, which moved the tensors to CPU (includingself
)at::index
self
lives on CPU)This problem affected a few Torchbench benchmarks running with
openxla
backend:Super_SloMo
,doctr_det_predictor
,pytorch_unet
, ...