[benchmarks] Fix AMP setup for torchbench models. #7067
Conversation
Confirmed it also fixes #6833.
Hmm, according to https://github.com/pytorch/xla/blob/master/docs/amp.md we should be able to use
That document is correct. The problem is that I didn't notice XLA:CUDA is supposed to run with CUDA autocast, i.e.
```python
# https://github.com/pytorch/xla/issues/6511
if self.is_accelerator_cuda():
    # For inductor and XLA:CUDA, we use CUDA autocast.
    autocast = torch.cuda.amp.autocast
```
I guess `torch.cuda.amp.autocast` is the same as `torch.amp.autocast("cuda")`?
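For reference, a quick sketch of why the two spellings behave the same (this assumes a CUDA-enabled PyTorch build; the tensors here are only illustrative):

```python
import torch

# torch.cuda.amp.autocast is a thin wrapper that forwards to
# torch.amp.autocast with device_type="cuda", so both context managers
# behave the same way for CUDA ops.
x = torch.randn(8, 8, device="cuda")
w = torch.randn(8, 8, device="cuda")

with torch.cuda.amp.autocast():
    y1 = x @ w  # matmul runs in float16 under CUDA autocast

with torch.amp.autocast(device_type="cuda"):
    y2 = x @ w  # same dtype behavior as the wrapper above

print(y1.dtype, y2.dtype)  # torch.float16 torch.float16
```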
Do you need to set `kwargs["device_type"] = "xla"` for the XLA:GPU case?
Not really. `torch.cuda.amp.autocast` already does that.
Fix: #6556 (and, possibly, #6833)

This PR fixes the benchmarks script when running with AMP. Previously, we were calling `torch.amp.autocast(..., device_type="xla")` for both XLA:CUDA and XLA:TPU. However, we should be using `torch.cuda.amp.autocast` for XLA:CUDA (see this for more details).

Context: after #6518, Super_Slomo inference started being run using AMP. However, due to #6511, that PR tried to mimic `torch_xla.amp.autocast` behavior, using `torch.amp.autocast`.
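To make the intended setup concrete, here is a minimal sketch of the selection logic described above. The names `choose_autocast`, `is_cuda`, and `run_model_step` are illustrative assumptions, not the actual benchmarks-script API; only the CUDA-vs-XLA autocast choice comes from this PR.

```python
import torch

def choose_autocast(is_cuda: bool, **kwargs):
    """Hypothetical helper mirroring the fix described above."""
    # https://github.com/pytorch/xla/issues/6511
    if is_cuda:
        # XLA:CUDA (and inductor) use CUDA autocast, which already sets
        # device_type="cuda" internally, so no device_type kwarg is needed.
        return torch.cuda.amp.autocast, kwargs
    # XLA:TPU keeps the generic autocast with the XLA device type
    # (assumes torch_xla has been imported so the "xla" backend is registered).
    kwargs["device_type"] = "xla"
    return torch.amp.autocast, kwargs

# Usage sketch:
# autocast, ac_kwargs = choose_autocast(is_cuda=True)
# with autocast(**ac_kwargs):
#     run_model_step()
```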
cc @miladm @JackCaoG @vanbasten23 @zpcore