Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward XLATensorImpl::is_contiguous_custom to TensorImpl. #8032

Open
wants to merge 12 commits into
base: master
Choose a base branch
from

Conversation

ysiraichi
Copy link
Collaborator

This PR fixes #7998. Instead of always returning true, we forward this call to the base class TensorImpl::is_contiguous_custom().

The reason is that after pytorch/pytorch#135498 is merged, XLA tensors' metadata might stop reflecting on the actual XLA storage. Which means that the tensors' strides might not always be contiguous. Whenever that happens, tensor.is_contiguous() call should be consistent with the tensors' strides.

cc @miladm @JackCaoG @alanwaketan

@JackCaoG
Copy link
Collaborator

test failure seems real?

@JackCaoG
Copy link
Collaborator

yea test_memory_format_preserved_after_permute_xla still failing.

@@ -1068,7 +1068,7 @@ def test_backward_optimization_barrier(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
self.assertIn(
'%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
'%opt-barrier.38 = (f32[1,64]{1,0}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{1,0}, f32[1]{0}, f32[2,64]{1,0}) %tuple.37)',
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed for fixing a CI failure in pytorch/pytorch#135237. @JackCaoG let me know if you think this should not be happening.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes in this file are needed because it seems that PyTorch is running these tests with Python 3.8. Thus, the subscript operator is not allowed for types.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hopefully pytorch/pytorch#135278 can be merged and we don't have this issue anymore in the future...

@miladm
Copy link
Collaborator

miladm commented Sep 23, 2024

@JackCaoG do we want to add this PR to 2.5 release? (knowing it has some upstream dependencies to consider - @ysiraichi please reference the dependencies for clarity)

@JackCaoG
Copy link
Collaborator

no, I don't want to add features to 2.5 releases at this point.

// If functionalization is disabled, the tensors' metadata aren't being
// updated w.r.t. the output of meta functions. Therefore, we fallback to the
// old behavior returning true, always.
if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just add a IsFunctionalizationDisabled(use a static bool to store the value of this env var) in https://github.com/pytorch/xla/blob/ea8c47a345a29c6a1b1bdf4ee38a9159b07a980f/torch_xla/csrc/xla_graph_executor.h or in tensor.h instead of keep getting the from the env var

@ysiraichi
Copy link
Collaborator Author

I don't think the approach in this PR will be enough for propagating the metadata. I think I will have to modify the functionalizeFallback function instead, so that all operations propagate the metadata correctly. This will probably add more overhead to the execution, though. @JackCaoG what do you think?

@JackCaoG
Copy link
Collaborator

would be good to measure those overheads, my experience is that most of the C++ ops are pretty fast except creating large objects and calculating hash. It might be not too bad in this case too.

@ysiraichi
Copy link
Collaborator Author

The reason why I'm asking this is because of #7923. Basically, what I'm saying is that we will have to call the meta functions of every operation that goes through the functional fallback function (e.g. add(), clone(), etc). That might end up calling a Python meta function.

@JackCaoG what do you think?

@JackCaoG
Copy link
Collaborator

I am already concern about the python meta function overhead we have today so ideally I don't want to introduce more. Do you mind give it a try and see how many tracing time overhad does it introduce?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants