
Grad-Norm spike on transformer depth change #52

Closed

akhauriyash opened this issue Nov 7, 2024 · 3 comments

@akhauriyash
[screenshots: training configs and grad-norm curves for xmer_full (left) and xmer_210m_10L]

Thanks for open-sourcing the code.

I was training a base transformer on the fineweb_edu_10bt_shuffled dataset. xmer_full is the config on the left, and xmer_210m_10L is the exact same code and execution, except with 10 layers.

There is a grad-norm spike, as indicated in the figure above, which is odd. Are there other configuration values I should change? I would expect xmer_210m_10L to do better on train loss, but perhaps, besides the n_layers change, I am missing some other parameters needed for stability? (The architecture and code are otherwise identical.)
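To be concrete about the change, the two runs differ only in depth; a minimal sketch (field names and the other values are illustrative placeholders, not the actual config schema):

```python
# Illustrative sketch only: the two runs are identical except for depth.
# Field names and the other values are placeholders, not the real config.
xmer_full = dict(dim=1024, n_heads=16, n_layers=12, seq_len=4096)
xmer_210m_10L = {**xmer_full, "n_layers": 10}  # only n_layers differs
```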

[screenshot: training curves]

I would appreciate pointers on any other parameters that need to change when increasing the number of layers, or on why, in your experience, this spike happens (and hurts the loss).

Thanks!

akhauriyash changed the title from "Grad-Norm spike on transformer configuration change" to "Grad-Norm spike on transformer depth change" on Nov 7, 2024
@mathuvu
Contributor

mathuvu commented Nov 18, 2024

It is difficult to pinpoint the exact source of your instabilities. I recommend carefully examining the learning rate (LR) and weight decay, especially since these issues seem to arise during the warmup phase. For reference, we use 64 GPUs to train the 1B model (which closely resembles your configuration). With this setup, the total batch size in tokens per step is batch_size * seq_len * n_gpu = 4 * 4096 * 64 = 1,048,576.
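For concreteness, the per-step token count implied by those numbers:

```python
# Tokens processed per optimizer step for the 1B reference setup above.
batch_size = 4      # sequences per GPU per step
seq_len = 4096      # tokens per sequence
n_gpu = 64          # data-parallel GPUs

tokens_per_step = batch_size * seq_len * n_gpu
print(tokens_per_step)  # 1048576
```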

I am not sure about your exact configuration, but if you are using a different batch size, you may need to adjust the learning rate accordingly; in particular, a smaller batch size typically calls for a smaller learning rate. Changing the model size also requires adjusting some hyperparameters: for example, the 7B uses an LR of 1e-3 with weight decay 0.1, while the 1B uses an LR of 3e-3 with weight decay 0.033.
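As a rough illustration of the kind of adjustment meant here (square-root LR scaling with the token batch size is a common heuristic, not necessarily the exact recipe used for these models):

```python
# Heuristic sketch: scale the LR with the square root of the token batch size.
# The reference values are the 1B numbers quoted above; the scaling rule itself
# is a common rule of thumb, not necessarily what was used for these runs.
ref_tokens_per_step = 4 * 4096 * 64   # 1,048,576 tokens per step (1B run)
ref_lr = 3e-3                         # 1B learning rate

def scaled_lr(tokens_per_step: int) -> float:
    return ref_lr * (tokens_per_step / ref_tokens_per_step) ** 0.5

# Example: same per-GPU batch size on 8 GPUs instead of 64.
print(f"{scaled_lr(4 * 4096 * 8):.2e}")  # ~1.06e-03
```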

@mathuvu
Contributor

mathuvu commented Nov 18, 2024

For the 1B training run, when examining the gradient norm, I did not observe any significant spikes at the beginning of the run.

Here is the grad norm for the 1B over the course of training:

[screenshot: grad-norm curve for the 1B run]

@akhauriyash
Author

Thank you, this helps a lot!
It might have something to do with the dataset; switching from fineweb to dclm made the issue go away.
