-
Notifications
You must be signed in to change notification settings - Fork 22.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix embedding_backward_dense decomp with broadcasting #95499
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95499
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 4919634: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. It might be worth figuring out why the OpInfo inputs didn't exercise this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this is right, thanks for the fix @bdhirsh
Fixes #95182 cc ngimel for another decomp fix. For this one, I tried auditing the CPU and CUDA kernels for `embedding_backward_dense` and just could not figure out where the `unsqueeze(1)` was supposed to be coming from. In the failing example, our tensor shapes are `(2, 4, 3)` and `(2, 4)`, and so I just assumed that the existing decomp had a typo - we should be unsqueezing the last dim, instead of dim index 1. That fixes the repro, and the existing decomp + meta tests appear to be passing. cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned]
ghstack-source-id: dfea7c497fde1ad0c3030bc7bc6fe6e790de8954 Pull Request resolved: #95499
Fixes #95182 cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire ngimel for another decomp fix. For this one, I tried auditing the CPU and CUDA kernels for `embedding_backward_dense` and just could not figure out where the `unsqueeze(1)` was supposed to be coming from. In the failing example, our tensor shapes are `(2, 4, 3)` and `(2, 4)`, and so I just assumed that the existing decomp had a typo - we should be unsqueezing the last dim, instead of dim index 1. That fixes the repro, and the existing decomp + meta tests appear to be passing. cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned]
Fixes pytorch/pytorch#95182 Pull Request resolved: pytorch/pytorch#95499 Approved by: https://github.com/ezyang, https://github.com/ngimel
Fixes pytorch/pytorch#95182 Pull Request resolved: pytorch/pytorch#95499 Approved by: https://github.com/ezyang, https://github.com/ngimel
Fixes pytorch/pytorch#95182 Pull Request resolved: pytorch/pytorch#95499 Approved by: https://github.com/ezyang, https://github.com/ngimel
Fixes pytorch/pytorch#95182 Pull Request resolved: pytorch/pytorch#95499 Approved by: https://github.com/ezyang, https://github.com/ngimel
Fixes pytorch/pytorch#95182 Pull Request resolved: pytorch/pytorch#95499 Approved by: https://github.com/ezyang, https://github.com/ngimel
Fixes pytorch#95182 Pull Request resolved: pytorch#95499 Approved by: https://github.com/ezyang, https://github.com/ngimel
Fixes #95182
cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire @ngimel for another decomp fix. For this one, I tried auditing the CPU and CUDA kernels for
embedding_backward_dense
and just could not figure out where theunsqueeze(1)
was supposed to be coming from. In the failing example, our tensor shapes are(2, 4, 3)
and(2, 4)
, and so I just assumed that the existing decomp had a typo - we should be unsqueezing the last dim, instead of dim index 1. That fixes the repro, and the existing decomp + meta tests appear to be passing.Stack from ghstack (oldest at bottom):
cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire