optimizer.step() takes too long in XLA/TPU #7923
@bdhirsh FYI, this is what we discussed last week. There are two issues, I think.
Pasting some of the chats between Brian and me to add more context:
Another thing you could do for now is to use a larger batch size. Your issue here is that you are tracing bound, but increasing the batch size will increase the device execution time while keeping the tracing time constant. You will be able to utilize the device better.
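For context, the trace/execute split Brian describes can be measured separately; below is a rough sketch assuming the standard `torch_xla` lazy-tensor helpers (the model and step function here are invented placeholders):

```python
import time

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 128).to(device)  # placeholder model

def train_step(batch):
    # Hypothetical step: with lazy tensors this only records the graph.
    return model(batch).sum()

batch = torch.randn(8, 128, device=device)

t0 = time.perf_counter()
loss = train_step(batch)      # tracing: roughly constant cost per step
xm.mark_step()                # cut the graph and dispatch it to the device
traced = time.perf_counter()

xm.wait_device_ops()          # block until device execution finishes
done = time.perf_counter()

# Larger batches grow device time but not tracing time, amortizing the
# fixed per-step tracing cost that this issue is bound by.
print(f"trace+dispatch: {traced - t0:.3f}s  device: {done - traced:.3f}s")
```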
Thanks for the response! Yes, increasing the batch size would help a lot, and I usually deal with large batch sizes. But sometimes for a large pod I have to use a small batch per core, and that's when I found the problem.
Hi @JackCaoG, is it OK to assign this ticket to you?
@bytetriper Could you check if the issue still happens using this branch? |
@bytetriper Have you run this code on XLA:GPU? If so, do you see a similar performance issue there?
I have tried running [...]. That said, I did notice that after my branch, the execution of [...]. Note: in this case, [...]. @JackCaoG, could you help check whether my branch fixes this issue?
I will try to repro on my end, but I am busy these two days.
Hi. No, I haven't. But I think this is a TPU issue: XLA:TPU does not support fancy AdamW tricks like foreach/fused, only a naive optimizer, which in my opinion contributes a lot to this issue.
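For reference, the stock optimizer exposes these implementation choices as constructor flags; a minimal sketch (the placeholder model is invented, and the claim that XLA/TPU falls back to the single-tensor path comes from this thread):

```python
import torch

model = torch.nn.Linear(8, 8)  # placeholder model
# foreach/fused select multi-tensor or fused kernels on eager backends; per
# this thread, XLA/TPU uses the naive single-tensor implementation regardless.
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, foreach=False, fused=False)
```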
@bytetriper Still, in the branch I mentioned above, I managed to reduce the latency of a few AdamW operations. Could you check if that solves the issue (even on TPU)?
@JackCaoG do we plan to land this PR in 2.5? |
I think the fix should stay in nightly. |
This PR adds new meta functions for `lerp`, `addcmul`, and `addcdiv` (including their respective in-place versions). These functions only had refs implementations, which was the root cause of significant overhead ([issue][1]) when running the `AdamW` optimizer step on the PyTorch/XLA backend. Running the meta functions resulted in the following improvements:

- `lerp` calls: 1,550ms to 140ms (10x)
- `addcdiv` calls: 640ms to 350ms (1.8x)
- `addcmul` calls: 620ms to 300ms (2.05x)

[1]: pytorch/xla#7923
ghstack-source-id: f08891d8ecfd949a298ab6603534297caaf9deaf
Pull Request resolved: #136909
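For readers unfamiliar with meta functions: a meta (or "fake") kernel computes only output shapes and dtypes, so tracing backends can skip decomposing the op into many small refs ops, which is where the overhead in this issue came from. Below is a hypothetical sketch using the public `torch.library` API with an invented op name; it is not the PR's actual code, which registers meta kernels for the existing ATen ops:

```python
import torch

# Invented custom op standing in for ops like addcmul.
@torch.library.custom_op("demo::addcmul_like", mutates_args=())
def addcmul_like(x: torch.Tensor, t1: torch.Tensor, t2: torch.Tensor,
                 value: float = 1.0) -> torch.Tensor:
    return x + value * t1 * t2

# The fake (meta) kernel: computes only the output shape/dtype, touching no
# data, so tracing the op is cheap.
@addcmul_like.register_fake
def _(x, t1, t2, value=1.0):
    shape = torch.broadcast_shapes(x.shape, t1.shape, t2.shape)
    return x.new_empty(shape)
```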
🐛 Bug

Common optimizers like Adam/AdamW take too long in `optimizer.step()` for small models. I tested a small ViT with 5.8M parameters, and `torch.optim.AdamW` takes ~0.2s for a single step.

To Reproduce
Below is a minimal example that can be run with `python test.py`. It trains a 5.8M-parameter ViT on fake image data with a local batch size of 8. The average training speed measured on a TPU v4-8 is 4.2 it/s, which is much lower than expected. A one-step trace screenshot (not preserved here) shows that the majority of the time is spent in `optimizer.step()`.
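The original script was not preserved in this thread; the following is a hypothetical reconstruction of a comparable repro, with an invented stand-in model sized near the reported ~5.8M parameters:

```python
# Hypothetical reconstruction of test.py (the original was lost in this thread).
import time

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Invented stand-in for the ~5.8M-parameter ViT from the report; layer sizes
# are chosen only to land near that parameter count.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 224 * 224, 38),
    nn.ReLU(),
    nn.Linear(38, 1000),
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

for step in range(20):
    # Fake image data, local batch size 8, as in the report.
    images = torch.randn(8, 3, 224, 224, device=device)
    labels = torch.randint(0, 1000, (8,), device=device)

    optimizer.zero_grad()
    loss_fn(model(images), labels).backward()

    t0 = time.perf_counter()
    optimizer.step()   # lazy tensors: this measures tracing, the reported cost
    xm.mark_step()     # cut the graph and dispatch it to the device
    print(f"step {step}: optimizer.step trace took "
          f"{time.perf_counter() - t0:.3f}s")
```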
Expected behavior

A much faster `optimizer.step()`.
Environment
Additional context
As suggested by @JackCaoG, setting `XLA_DISABLE_FUNCTIONALIZATION=1` mitigates the problem with a ~5x speedup. I would love to see more detailed info about this. Also, only `_single_tensor_adamw` is supported in XLA/TPU now. Any plan to support optimizations like fused/foreach/CPUAdamW?
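For completeness, a minimal sketch of applying the suggested flag; the variable name comes from this thread, while setting it before `torch_xla` is imported is an assumption about safe ordering:

```python
import os

# The flag must be in the environment before torch_xla initializes its runtime,
# so set it at the very top of the script (or export it from the shell instead).
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch_xla.core.xla_model as xm  # imported only after the flag is set
```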