
Make XLAShardedTensor use default __torch_function__ #6625

Merged 4 commits into master on Mar 2, 2024

Conversation

JackCaoG (Collaborator)

Companion PR for pytorch/pytorch#120632 (comment), though I have to admit I don't fully understand why the XLA tests would fail without this.

JackCaoG requested a review from yeounoh on February 27, 2024 at 20:22
albanD commented Feb 27, 2024

Thanks for the quick PR!
It's completely fine to land this before the PyTorch-side PR, since this is the current default behavior and a no-op right now.
Just let me know once this propagates and I can rebase the PyTorch-side PR to include this change.

albanD commented Feb 27, 2024

> I have to admit I don't fully understand why the XLA tests would fail without this.

Me neither.
The main change in my PR is disabling this default handler. What the default handler does is automatically wrap anything your __torch_dispatch__ handler returns into your subclass. Meaning that if your __torch_dispatch__ handler returns a plain Tensor, the default __torch_function__ handler will wrap it back into your subclass. This is usually not the expected behavior, hence the fact that we're changing it.

But from a quick look at what you do, it shouldn't change anything here...
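
To make the wrapping behavior concrete, here is a minimal sketch; MyTensor is a hypothetical subclass invented purely for illustration and is not part of this PR or of XLA:

```python
import torch

# Hypothetical subclass with no handlers of its own, so the default
# __torch_function__ handler applies.
class MyTensor(torch.Tensor):
    pass

t = torch.ones(3).as_subclass(MyTensor)
out = t + 1

# Under the current default handler, the result of the op is automatically
# re-wrapped into the subclass:
print(type(out))  # <class '__main__.MyTensor'>
```

With the default handler disabled, as pytorch/pytorch#120632 proposes, `out` would come back as a plain torch.Tensor unless the subclass re-wraps it itself.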

JackCaoG (Collaborator Author)

Hmm, Dynamo started to complain after adding this torch_dispatch change.

albanD commented Feb 27, 2024

Oh, I see the same issue happening on Int16Tensor in the other PR. Let me look into that.

albanD left a comment


Well, this seems to fix it based on a quick experiment, haha: pytorch/pytorch#120799

Will investigate the Dynamo failure as well, but this should make this PR green!

```diff
@@ -108,6 +108,9 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
     r.global_tensor = elem.detach() if r.requires_grad else elem
     return r

+  def __torch_function__(cls, func, types, args=(), kwargs=None):
```

Suggested change:
```diff
-  def __torch_function__(cls, func, types, args=(), kwargs=None):
+  @classmethod
+  def __torch_function__(cls, func, types, args=(), kwargs=None):
```
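
For context, the complete override presumably just spells out today's default by delegating to the parent implementation; a minimal sketch, assuming delegation via super() (the full method body is not shown in this excerpt):

```python
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    # Delegate to torch.Tensor's default handler so XLAShardedTensor keeps
    # the current wrapping behavior even after the PyTorch-side default changes.
    return super().__torch_function__(func, types, args, kwargs)
```

Note that __torch_function__ is expected to be defined as a classmethod (it receives cls as its first argument), which is what the suggestion above fixes.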

JackCaoG (Collaborator Author)

OK, let me update and give it another try.

JackCaoG (Collaborator Author)

@albanD OK great, it seems like this fixes the CI failure. Let me remove the torch pin and merge it to unblock you.

JackCaoG (Collaborator Author) commented Mar 1, 2024

@yeounoh can you take a look at this PR and merge it to unblock Alban?

yeounoh (Contributor) left a comment


LGTM, thanks @albanD.

yeounoh merged commit 5a113af into master on Mar 2, 2024 (18 checks passed)