32-bit optimizer update error despite gradients being the same #1185
Comments
@matthewdouglas @TimDettmers any insights? Thanks!
Hi @Edenzzzz,
Make sure that this chunk is contiguous, as F.optimizer_update_32bit ultimately treats it as 1D:
dist_low_rank_grad = torch.load("dist_low_rank_grad.pt").contiguous()
I was able to reproduce your results, and after this change I believe I'm seeing the desired result.
RTX 3060, CUDA 12.4, torch==2.2.2+cu121, bitsandbytes==0.43.1
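For context, here is a minimal torch-only sketch of why a dim-1 chunk misbehaves when a kernel treats it as flat 1D memory (the shapes are illustrative, not taken from the attached tensors):

```python
import torch

# A small "full" gradient; chunking along dim 1 returns a strided view, not a copy.
full = torch.arange(12, dtype=torch.float32).reshape(3, 4)
chunk = torch.chunk(full, 2, dim=1)[0]
print(chunk.is_contiguous())  # False: stride (4, 1) for a (3, 2) shape

# What a kernel that reads numel() consecutive elements from the data pointer sees:
raw = full.flatten()[: chunk.numel()].reshape(chunk.shape)
print(torch.equal(raw, chunk))  # False: the flat memory is not the logical chunk

# .contiguous() materializes the chunk into dense row-major storage,
# so the flat memory and the logical values agree again.
fixed = chunk.contiguous()
print(torch.equal(fixed.view(-1).reshape(chunk.shape), chunk))  # True
```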
Thanks a lot! This worked. The non-contiguous tensor came from torch.chunk and torch.distributed.all_gather.
Yes, exactly. The C++ kernel assumes the tensor is row-major and only knows the total number of elements.
Adding a contiguity check seems reasonable to me, so a PR to add that sounds good!
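A hedged sketch of what such a contiguity check could look like in the Python wrapper, before the flat 1D kernel is launched. This is illustrative only, not the actual bitsandbytes code; the helper name and argument names are assumptions:

```python
import torch

def _assert_contiguous(**tensors: torch.Tensor) -> None:
    """Fail fast if any tensor handed to the flat 1D kernel is a strided view."""
    for name, t in tensors.items():
        if t is not None and not t.is_contiguous():
            raise ValueError(
                f"{name} must be contiguous (got strides {t.stride()} for shape "
                f"{tuple(t.shape)}); call .contiguous() before the optimizer update."
            )

# e.g., invoked inside a wrapper around F.optimizer_update_32bit:
# _assert_contiguous(g=grad, p=param, state1=state1, state2=state2)
```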
System Info
A100 GPU, torch 2.1, CUDA 12.1, bitsandbytes 0.43.1
Reproduction
The tensors to be loaded are zipped here:
grads.zip
My results show that most updates on the same grad chunk diverged.
Expected behavior
This comes from adapting the GaLore optimizer for tensor parallelism, while testing whether the distributed and original optimizers match in precision.
Here the gradient is sharded along dim 1 by tensor parallelism, and the corresponding grad chunk clearly matches the original. However, after the optimizer step the chunks are not exactly the same. I first suspected this was due to quantization statistics, but using 32-bit states and disabling quantization consistently reproduces this bug.
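Since the original reproduction script is not shown above, here is a hedged stand-in sketch of the divergence described: the same gradient values, once as a contiguous copy and once as a non-contiguous dim-1 chunk, are fed to the 32-bit Adam update. It uses the high-level bnb.optim.Adam interface rather than F.optimizer_update_32bit directly; the shapes, learning rate, and choice of Adam are assumptions.

```python
import torch
import bitsandbytes as bnb

torch.manual_seed(0)
full_grad = torch.randn(64, 128, device="cuda")
grad_chunk = torch.chunk(full_grad, 2, dim=1)[0]       # non-contiguous view, shape (64, 64)

p_ref = torch.zeros(64, 64, device="cuda", requires_grad=True)
p_dist = torch.zeros(64, 64, device="cuda", requires_grad=True)
p_ref.grad = grad_chunk.contiguous()                   # same values, dense row-major layout
p_dist.grad = grad_chunk                               # strided view; flat memory differs

opt_ref = bnb.optim.Adam([p_ref], lr=1e-2, optim_bits=32)
opt_dist = bnb.optim.Adam([p_dist], lr=1e-2, optim_bits=32)
opt_ref.step()
opt_dist.step()

# If the kernel reads the non-contiguous grad as flat memory, it consumes the
# wrong elements, so the two updates diverge despite identical logical grads.
print(torch.allclose(p_ref, p_dist))
```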