How does fake tensor work with tensor subclasses in torch.compile? #136287
Comments
I tried running the repro with compile turned off:
And I get a similar shape mismatch error from DTensor:
So it seems the tutorial is wrong in eager mode? If you compile a model that is supposed to raise an error about incorrect shapes, then the fake tensor error you are seeing is expected (FakeTensor raises a similar error when it does fake tensor propagation).
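For reference, fake tensor propagation computes output shapes (via meta kernels) without touching real data, so a shape mismatch surfaces at trace time just as it would in eager mode. A minimal sketch, assuming `FakeTensorMode` from the private module `torch._subclasses.fake_tensor` (an internal API that may change between releases); the function name `fake_mm_shapes` is hypothetical:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode


def fake_mm_shapes():
    # Factory calls inside FakeTensorMode produce FakeTensors: metadata only, no data
    with FakeTensorMode():
        x = torch.empty(128, 1024)
        w = torch.empty(512, 1024)

        # Correct orientation: (128, 1024) @ (1024, 512) -> (128, 512)
        out = torch.mm(x, w.t())
        ok_shape = tuple(out.shape)
        is_fake = isinstance(out, FakeTensor)

        # Wrong orientation: inner dims 1024 vs 512 disagree, so the shape check raises
        try:
            torch.mm(x, w)
            mismatch_raised = False
        except RuntimeError:
            mismatch_raised = True
    return ok_shape, is_fake, mismatch_raised


ok_shape, is_fake, mismatch_raised = fake_mm_shapes()
```

In other words, if the eager program has a shape bug, the compiled program is expected to hit the same bug during fake propagation.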
I can repro this with
If it also repros without compile at all (e.g. removing compile and setting
OK, it turns out this specific issue is caused by the transpose op currently being implemented as an in-place op; that's why it fails the second time we run it. I just updated the PR, and now there is a new error.
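The "fails the second time" symptom can be reproduced with plain tensors. This is a minimal sketch, not the torchao code: the hypothetical helpers `mm_with_inplace_t` and `mm_with_t` stand in for a linear layer that transposes its weight before `torch.mm`. Because an in-place transpose mutates the weight, each call flips its shape, and the second call sees the wrong orientation.

```python
import torch


def mm_with_inplace_t(x, w):
    # Buggy pattern: t_() transposes w in place, so the caller's weight
    # flips between (512, 1024) and (1024, 512) on every call
    return torch.mm(x, w.t_())


def mm_with_t(x, w):
    # Out-of-place: t() returns a transposed view and leaves w unchanged
    return torch.mm(x, w.t())


x = torch.randn(128, 1024)

w1 = torch.randn(512, 1024)
first = mm_with_inplace_t(x, w1)  # works; w1 is now (1024, 512)
try:
    # w1.t_() turns it back into (512, 1024) before mm runs -> shape mismatch
    mm_with_inplace_t(x, w1)
    second_call_failed = False
except RuntimeError:
    second_call_failed = True

w2 = torch.randn(512, 1024)
outs = [mm_with_t(x, w2) for _ in range(2)]  # same result shape on every call
```

The out-of-place version is idempotent across runs, which is what a traced or compiled graph implicitly assumes.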
Oh, that latest error should actually be fixed by #136266 (comment)
🐛 Describe the bug
I'm working on an example for quantized tensor subclass + DTensor (tensor parallel) + compile: pytorch/ao#785
The test works in eager mode, but currently fails with a shape mismatch under compile.
Input shape: (128, 1024); linear weight shape: (512, 1024), i.e. (out_features, in_features).
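With these shapes the weight has to be transposed before the matmul: a linear layer computes `x @ w.T`, i.e. (128, 1024) @ (1024, 512) -> (128, 512). A small sketch of that arithmetic with plain tensors (random data, no subclass or DTensor involved):

```python
import torch
import torch.nn.functional as F

# Shapes taken from the issue: input (batch, in_features), weight (out_features, in_features)
x = torch.randn(128, 1024)
w = torch.randn(512, 1024)

# F.linear computes x @ w.T, so the weight must be transposed before torch.mm
y_linear = F.linear(x, w)
y_mm = torch.mm(x, w.t())  # (128, 1024) @ (1024, 512) -> (128, 512)

# Without the transpose the inner dimensions (1024 vs 512) do not match
try:
    torch.mm(x, w)
    mm_without_t_failed = False
except RuntimeError:
    mm_without_t_failed = True
```

So if a fake tensor reports a mismatch in `torch.mm`, the first thing to check is whether the weight it sees still has the (512, 1024) orientation.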
Error in the torch.mm op with fake tensors:
The transpose implementation looks like the following:
It seems that the fake tensor did not pick up the changes to the shape in this case.
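The actual transpose implementation is not preserved on this page. As a purely hypothetical sketch (not the torchao or DTensor code), a wrapper subclass can implement transpose out-of-place by unwrapping its inner tensor, running the op, and returning a fresh wrapper whose metadata matches the new shape. The class name `WrapperTensor` and its `inner` attribute are assumptions for illustration, and `torch.Tensor._make_wrapper_subclass` is a private PyTorch API.

```python
import torch
from torch.utils._pytree import tree_map


class WrapperTensor(torch.Tensor):
    """Hypothetical wrapper subclass holding a plain tensor in `inner`."""

    # Skip default __torch_function__ handling; all ops go through __torch_dispatch__
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, inner):
        # The outer tensor's shape/dtype/device are copied from `inner` at construction
        r = torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )
        r.inner = inner
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            return t.inner if isinstance(t, WrapperTensor) else t

        def wrap(t):
            return WrapperTensor(t) if isinstance(t, torch.Tensor) else t

        # Run the op on the inner tensors, then build a new wrapper for each output,
        # so out-of-place ops like t() never mutate an existing wrapper's shape
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)


w = WrapperTensor(torch.randn(512, 1024))
wt = w.t()  # new wrapper with shape (1024, 512); w keeps (512, 1024)
y = torch.mm(torch.randn(128, 1024), wt)  # dispatched to inner tensors -> (128, 512)
```

With this design, an in-place op such as `w.t_()` would transpose `w.inner` without updating the wrapper's own size metadata, which is one way a subclass's outer shape can go stale; treat this as an illustration rather than a diagnosis of the bug above.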
Repro:
with-proxy torchrun --standalone --nnodes=1 --nproc-per-node=4 tutorials/developer_api_guide/tensor_parallel.py
Versions
main
cc @ezyang @albanD @chauhang @penguinwu @eellison @zou3519 @bdhirsh