Grad-Norm spike on transformer depth change #52
Comments
It is difficult to pinpoint the exact source of your instabilities. I recommend carefully examining the learning rate (LR) and weight decay, especially since these issues seem to arise during the warmup phase. For reference, we use 64 GPUs for training the 1B model (which closely resembles your configuration). This gives a total batch size in tokens of batch_size * seq_len * n_gpu = 4 * 4096 * 64 = 1,048,576 tokens per step. I am not sure about your exact configuration, but if you are using a different batch size, you may need to adjust the learning rate accordingly; a smaller batch size typically requires a smaller learning rate. Changing the model size also calls for some hyperparameter adjustments: for example, the 7B uses an LR of 1e-3 with weight decay 0.1, while the 1B uses an LR of 3e-3 with weight decay 0.033.
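As a rough illustration of the arithmetic above, here is a minimal sketch assuming a square-root LR scaling heuristic (the comment does not prescribe a specific rule; linear scaling is a common alternative). The helper name and the 8-GPU example setup are hypothetical:

```python
def tokens_per_step(batch_size: int, seq_len: int, n_gpu: int) -> int:
    """Total tokens consumed per optimizer step."""
    return batch_size * seq_len * n_gpu

# Reference setup from the comment above: 1B model on 64 GPUs.
ref_tokens = tokens_per_step(batch_size=4, seq_len=4096, n_gpu=64)
assert ref_tokens == 1_048_576  # 4 * 4096 * 64

# Hypothetical smaller setup: same per-GPU batch, only 8 GPUs.
my_tokens = tokens_per_step(batch_size=4, seq_len=4096, n_gpu=8)

# Square-root heuristic: shrink the LR with the token-budget ratio.
ref_lr = 3e-3  # 1B LR from the comment above
scaled_lr = ref_lr * (my_tokens / ref_tokens) ** 0.5
print(f"tokens/step: {my_tokens}, suggested LR ~ {scaled_lr:.2e}")
```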
Thank you, this helps a lot!
Thanks for open-sourcing the code.
I was training some base transformers on the fineweb_edu_10bt_shuffled dataset. xmer_full is the config on the left, and xmer_210m_10L is the exact same code and execution, except with 10 layers.
There is a grad-norm spike, as indicated in the earlier figure, which is odd. Are there some other configurations I should change? I would assume that xmer_210m_10L would do better on train loss, but perhaps, beyond the n_layers change, I am missing some other parameters needed to ensure stability? (The architecture and code are exactly the same.)
I would appreciate pointers on other parameters that need to change when I increase the number of layers, or pointers from your experience on why this spike happens (and hinders the loss); one common depth-related suspect is sketched after this post.
Thanks!
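One commonly cited stabilizer when increasing depth is rescaling the initialization of the residual-branch output projections by 1/sqrt(2 * n_layers), GPT-2 style, so the residual stream's variance does not grow with depth. Below is a minimal PyTorch sketch; the function name and the module-name suffixes (`wo.weight`, `w2.weight`) are assumptions for illustration, not taken from this repo:

```python
import math

import torch
import torch.nn as nn


def scale_residual_init_(model: nn.Module, n_layers: int) -> None:
    """Rescale residual-branch output projections by 1/sqrt(2 * n_layers).

    GPT-2-style depth scaling: each transformer block adds two residual
    branches (attention and MLP), so their output projections are shrunk
    to keep the residual-stream variance roughly depth-independent.
    The name suffixes below are hypothetical and depend on your model.
    """
    scale = 1.0 / math.sqrt(2 * n_layers)
    with torch.no_grad():
        for name, param in model.named_parameters():
            # Attention output projection and MLP down-projection.
            if name.endswith(("wo.weight", "w2.weight")):
                param.mul_(scale)
```

Pairing this with gradient clipping (e.g., clipping the grad norm at 1.0) during warmup is another common mitigation for spikes like the one described, though neither is guaranteed to be what this repo does.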