
Divergence in pretraining BERT large #244

Closed
formiel opened this issue Nov 21, 2019 · 6 comments

Comments

@formiel

formiel commented Nov 21, 2019

I encountered a very weird problem in which the validation perplexity and MLM accuracy were extremely bad after the first epoch of training. Specifically:

  • Epoch 0: MLM accuracy 62.23 - Perplexity 6.43.
  • Epoch 1 (resumed from the checkpoint of Epoch 0): MLM accuracy 4.34 - Perplexity 1452.28.

One pass over all of my training data takes 2.5 epochs, and each epoch takes about 10 hours to finish. I trained this model on multiple nodes using Slurm. Due to the time limit of my cluster, I only submit jobs of around 10 hours, so that one epoch is run and checkpointed per job. The next job resumes from the checkpoint of the previous job and continues training. However, after the first job, I observed the very poor validation perplexity and MLM accuracy shown above.

I find this very strange, and I suspect there may be some configuration differences between the nodes of my cluster that prevent them from synchronising and collecting gradients properly...

However, I would like to understand how the model on the master GPU collects the gradient signal from the model replicas on the other devices. Is this process handled by Slurm, by the PyTorch distributed package, or by both? Also, is it normal for the master GPU to save a checkpoint while some of the other devices are still in evaluation mode and have not yet finished scoring the validation and test sets?

Thank you so much in advance for your help!

@formiel
Author

formiel commented Nov 28, 2019

I found the same divergence issue in #140. I trained this BERT large model using the training configuration recommended in the RoBERTa paper, i.e. optimizer="adam_inverse_sqrt,lr=0.0004,warmup_updates=30000,beta1=0.9,beta2=0.98,weight_decay=0.01,eps=0.000001". The model is trained on 512 GPUs (128 nodes) with a per-GPU batch size of 8 and gradient accumulation of 2, which I believe gives a total batch size of 8192.

I wonder whether my actual effective batch size is much smaller than 8192, for instance if each model replica updates its gradients on its own and the gradients are only gathered into the checkpoint at the end by the master GPU (by averaging or in some other way). In that case my effective batch size could be just 16, which would explain the instability given that the model is very large (24 layers).

Could you please explain how the gradients are gathered and how to compute the actual effective batch size when the model is trained on multiple nodes and multiple GPUs?

Thank you so much!

@glample
Contributor

glample commented Nov 28, 2019

Hi,

The gradients are accumulated and averaged at the end. So if you train with 32 GPUs instead of 8, you multiply the effective batch size by 4. An equivalent way to do this if you don't have 4 nodes is to use --accumulate_gradients 4, which performs 4 forward/backward calls before making an update; this is equivalent to training with 4x larger batches.
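
For reference, here is a minimal sketch of how gradient accumulation and cross-replica averaging combine into the effective batch size (the train_step helper and its arguments are illustrative, not the actual XLM/Apex code, which performs the all-reduce inside the backward pass):

    import torch
    import torch.distributed as dist

    # Numbers from this thread: 512 GPUs, per-GPU batch of 8,
    # --accumulate_gradients 2  ->  effective batch = 8 * 2 * 512 = 8192.
    ACCUMULATE_GRADIENTS = 2

    def train_step(model, optimizer, micro_batches, world_size):
        """Hypothetical step: accumulate local gradients over several
        micro-batches, average them across replicas, then update once."""
        optimizer.zero_grad()
        for batch in micro_batches[:ACCUMULATE_GRADIENTS]:
            loss = model(batch).mean()
            # Scale so the accumulated gradient is an average over micro-batches.
            (loss / ACCUMULATE_GRADIENTS).backward()
        # Average gradients across all replicas (shown explicitly here;
        # DistributedDataParallel / Apex normally does this during backward).
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad /= world_size
        optimizer.step()

So the effective batch size is per-GPU batch × accumulate_gradients × number of GPUs, with a single optimizer update per effective batch.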

@netw0rkf10w

@glample In the case above, the effective batch size is 8192, yet training becomes unstable right at the second epoch, which is very surprising to me...

@glample
Contributor

glample commented Nov 28, 2019

Do you have the same issue with 12 or 16 layers? We also observed divergence issues with 24 layers, but they usually tend to appear at a later stage of training. One thing that helped was to use an L2 regularization on the logits of the self-attention (if you print them, you may observe that they converge to huge values). Another was to use a larger epsilon in the layer norm (around 1e-6, for instance). These things should help, but there is no guarantee that the model will never diverge.
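
For concreteness, here is a minimal sketch of both mitigations on a toy single-head attention block (the class name, the att_reg coefficient and the logit_penalty attribute are hypothetical illustrations, not the XLM implementation):

    import torch
    import torch.nn as nn

    class RegularizedSelfAttention(nn.Module):
        """Toy single-head self-attention with (1) an L2 penalty on the
        attention logits and (2) a larger layer-norm epsilon."""

        def __init__(self, dim, att_reg=1e-4, ln_eps=1e-6):
            super().__init__()
            self.q = nn.Linear(dim, dim)
            self.k = nn.Linear(dim, dim)
            self.v = nn.Linear(dim, dim)
            self.layer_norm = nn.LayerNorm(dim, eps=ln_eps)  # eps ~1e-6 instead of 1e-12
            self.att_reg = att_reg

        def forward(self, x):
            q, k, v = self.q(x), self.k(x), self.v(x)
            scores = q @ k.transpose(-2, -1) / (x.size(-1) ** 0.5)  # attention logits
            # L2 penalty on the logits; add self.logit_penalty to the training
            # loss so the logits cannot grow unbounded during training.
            self.logit_penalty = self.att_reg * (scores ** 2).mean()
            out = scores.softmax(dim=-1) @ v
            return self.layer_norm(x + out)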

There is this paper that came out recently: https://openreview.net/pdf?id=SylO2yStDr that maybe you can try? The idea is very simple to implement: at training time, each transformer layer is dropped with (for instance) 10% probability. This apparently makes training much more stable, and it only takes a few lines to implement, but we have not tried it with this codebase yet.

@formiel
Author

formiel commented Dec 28, 2019

Sorry for the late reply, I couldn't find the time to advance on this earlier.

@glample @aconneau Thanks again for your active support (I thanked you in our recent Flaubert submission, cf. the Acknowledgements section).

Do you have the same issue with 12 or 16 layers?

I did not train the 16 layer model, but I had no convergence issue at all with the 12 layer (i.e. base) model.

Another was to use a larger epsilon in the layer norm (something around 1e-6, for instance).

Thanks. I didn't notice that the default value was very small (1e-12). The value used in fairseq (and also in RoBERTa) is 1e-5. Since I followed RoBERTa's configuration to train my models, I think using 1e-5 or 1e-6 would be better.
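
For example, in PyTorch this is just the eps argument of nn.LayerNorm (a generic illustration; the hidden size 768 here is simply the BERT-base value):

    import torch.nn as nn

    # Default in this codebase was 1e-12 (following BERT); fairseq/RoBERTa use 1e-5.
    layer_norm = nn.LayerNorm(768, eps=1e-5)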

There is this paper that came out recently: https://openreview.net/pdf?id=SylO2yStDr that maybe you can try?

Thanks a lot for the reference! I have implemented LayerDrop; it looks like this in src/model/transformer.py, line 391:

        # transformer layers
        for i in range(self.n_layers):
            # LayerDrop
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue
            # self attention
            ...

It seems to work. However, training speed decreased significantly:

  • Without LayerDrop: 48.04 sent/s - 3709.03 words/s
  • With LayerDrop rate 0.2: 28.51 sent/s - 2166.86 words/s

This is very strange because, theoretically, LayerDrop should make training faster, not slower. Do you have any idea what could cause this?

Thank you again for your help!

@formiel formiel changed the title Unusual increases in perplexity (and decrease in MLM accuracy) in pretraining BERT Divergence in pretraining BERT large Dec 28, 2019
@formiel
Author

formiel commented Dec 29, 2019

It seems that the problem lies in the backward passes or in the gradient-gathering steps, because I tried timing the forward passes and there was indeed a speedup.

I added the following code to count the number of layers and the running time of each forward pass:

        # transformer layers
        num_skips = 0
        start = time.time()
        for i in range(self.n_layers):
            # LayerDrop
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                num_skips += 1
                continue
            # self attention
            ...
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        logger.info('{} layers took {}s'.format(self.n_layers - num_skips, time.time() - start))

Training BERT base on one node with 2 GPUs, I obtained:

Without LayerDrop:

INFO - 12/29/19 01:18:32 - 0:00:21 - 12 layers took 0.028100252151489258s
INFO - 12/29/19 01:18:33 - 0:00:21 - 12 layers took 0.02810955047607422s
INFO - 12/29/19 01:18:33 - 0:00:21 - 12 layers took 0.028097867965698242s
INFO - 12/29/19 01:18:33 - 0:00:22 - 12 layers took 0.028098344802856445s
INFO - 12/29/19 01:18:34 - 0:00:22 - 12 layers took 0.028106689453125s
INFO - 12/29/19 01:18:34 - 0:00:22 - 35 - 48.33 sent/s - 3677.72 words/s - MLM-fr: 13.9210 - - model LR: 9.2486e-07
INFO - 12/29/19 01:18:34 - 0:00:22 - 12 layers took 0.028125286102294922s
INFO - 12/29/19 01:18:34 - 0:00:23 - 12 layers took 0.028150558471679688s
INFO - 12/29/19 01:18:35 - 0:00:23 - 12 layers took 0.028117656707763672s
INFO - 12/29/19 01:18:35 - 0:00:23 - 12 layers took 0.028110265731811523s
INFO - 12/29/19 01:18:35 - 0:00:24 - 12 layers took 0.02816009521484375s
INFO - 12/29/19 01:18:36 - 0:00:24 - 40 - 48.21 sent/s - 3586.95 words/s - MLM-fr: 13.0384 - - model LR: 1.0498e-06

With LayerDrop = 0.5:

INFO - 12/29/19 01:11:54 - 0:00:24 - 5 layers took 0.011696815490722656s
INFO - 12/29/19 01:11:55 - 0:00:25 - 4 layers took 0.009411334991455078s
INFO - 12/29/19 01:11:55 - 0:00:25 - 8 layers took 0.01871204376220703s
INFO - 12/29/19 01:11:55 - 0:00:26 - 7 layers took 0.016376256942749023s
INFO - 12/29/19 01:11:56 - 0:00:26 - 6 layers took 0.014084577560424805s
INFO - 12/29/19 01:11:56 - 0:00:26 - 35 - 36.86 sent/s - 2797.94 words/s - MLM-fr: 16.6314 - - model LR: 9.2486e-07
INFO - 12/29/19 01:11:56 - 0:00:27 - 6 layers took 0.014055728912353516s
INFO - 12/29/19 01:11:57 - 0:00:27 - 6 layers took 0.01404428482055664s
INFO - 12/29/19 01:11:57 - 0:00:27 - 7 layers took 0.016412734985351562s
INFO - 12/29/19 01:11:58 - 0:00:28 - 6 layers took 0.014067649841308594s
INFO - 12/29/19 01:11:58 - 0:00:28 - 8 layers took 0.01872563362121582s
INFO - 12/29/19 01:11:58 - 0:00:29 - 40 - 35.64 sent/s - 2715.41 words/s - MLM-fr: 16.4965 - - model LR: 1.0498e-06

The forward passes are clearly faster with LayerDrop, while the overall speed is slower (36 sent/s vs. 48 sent/s). And it is even worse in multi-node training.

I also tried LayerDrop in another implementation (following this PyTorch official tutorial), and did not have this issue (LayerDrop is faster overall).

I wonder if this is related to the use of Apex for distributed training...
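
One way to narrow this down (a sketch assuming a standard PyTorch loop, not the exact XLM/Apex one) would be to time the backward pass and the optimizer step separately, synchronizing the GPU so the asynchronous CUDA kernels do not distort the measurements:

    import time
    import torch

    def timed_backward(loss, optimizer):
        """Hypothetical helper: time backward and optimizer step separately.
        torch.cuda.synchronize() is required because kernels run asynchronously."""
        torch.cuda.synchronize()
        start = time.time()
        loss.backward()
        torch.cuda.synchronize()
        backward_time = time.time() - start

        start = time.time()
        optimizer.step()
        optimizer.zero_grad()
        torch.cuda.synchronize()
        step_time = time.time() - start
        return backward_time, step_time

If the extra time shows up in the backward pass, that would point to the distributed gradient synchronization (e.g. possibly stalling on the parameters of the dropped layers) rather than the forward computation.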

@formiel formiel closed this as completed Jul 26, 2020