
Concern about the implementation of gradient_clipping during fp16 training #5413

Closed
YKX-A opened this issue Sep 20, 2021 · 4 comments
@YKX-A
Contributor

YKX-A commented Sep 20, 2021

Checklist

  • [x] I have verified that the issue exists against the main branch of AllenNLP.
  • [x] I have read the relevant section in the contribution guide on reporting bugs.
  • [x] I have checked the issues list for similar or identical bug reports.
  • [x] I have checked the pull requests list for existing proposed fixes.
  • [x] I have checked the CHANGELOG and the commit log to find out if the bug was already fixed in the main branch.
  • [x] I have included in the "Description" section below a traceback from any exceptions related to this bug.
  • [x] I have included in the "Related issues or possible duplicates" section below all related issues and possible duplicate issues (If there are none, check this box anyway).
  • [ ] I have included in the "Environment" section below the name of the operating system and Python version that I was using when I discovered this bug.
  • [ ] I have included in the "Environment" section below the output of pip freeze.
  • [ ] I have included in the "Steps to reproduce" section below a minimally reproducible example.

Description

I have a question about the implementation of grad_clipping in gradient_descent_trainer (here) when the use_amp option is enabled.
The current implementation registers a hook that clips each gradient immediately as it is computed during back propagation. (here)
However, I think this implementation of gradient clipping is problematic when combined with mixed-precision training.

EXAMPLE:
(SETTING: gradient_clipping=2.0)

  1. Suppose the true gradient of param P is 1.0.
  2. fp16 training scales the loss up before back propagation, say by a factor of 16, so during back propagation the hook immediately clips the gradient of param P from 1.0 * 16 = 16.0 down to 2.0.
  3. Finally, we unscale the fp32 gradient of param P: 2.0 / 16 = 0.125 (but it should be 1.0, right?). The sketch below reproduces this.
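
Here is a minimal standalone sketch (plain PyTorch, not AllenNLP code; the clip value and loss scale are illustrative) showing that a clipping hook sees the scaled gradient, so the clip threshold is effectively divided by the loss scale:

import torch

clip_value = 2.0
scale = 16.0  # stand-in for the GradScaler's loss scale

p = torch.nn.Parameter(torch.tensor([1.0]))
# The hook clips the gradient the moment it is produced during back propagation,
# i.e. while it is still multiplied by the loss scale.
p.register_hook(lambda grad: grad.clamp(-clip_value, clip_value))

loss = p.sum()               # true gradient d(loss)/dp == 1.0
(loss * scale).backward()    # scaled loss, as under amp; raw grad would be 16.0

print(p.grad / scale)        # tensor([0.1250]), but the true gradient is 1.0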

Solution
In short, it might be more proper to use torch.nn.utils.clip_grad_value_ instead of registering a hook. Where to clip the gradient? I think it could follow clip_grad_norm in Trainer, in a similar fashion and position, i.e. after self._scaler.unscale_(self.optimizer).
PyTorch officially gives advice on gradient clipping in fp16 training, see here (though the example there is about clipping the gradient norm).
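
For reference, here is a minimal self-contained sketch of that pattern from the PyTorch AMP docs, with clip_grad_value_ swapped in for clip_grad_norm_ as proposed above (the model, optimizer, data, and clip value are illustrative placeholders, not AllenNLP code):

import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
clip_value = 2.0

x = torch.randn(8, 4, device="cuda")
with torch.cuda.amp.autocast():
    loss = model(x).sum()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # gradients are now unscaled, in fp32
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
scaler.step(optimizer)      # skips the update if the grads contain inf/NaN
scaler.update()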

Python traceback:

Related issues or possible duplicates

  • None

Environment

OS:

Python version:

Output of pip freeze:

Steps to reproduce

Example source:

@YKX-A YKX-A added the bug label Sep 20, 2021
@epwalsh epwalsh self-assigned this Sep 24, 2021
@epwalsh
Member

epwalsh commented Sep 24, 2021

Hi @YKX-A, thanks for the details, I think you are right. Would you like to make a pull request to change the behavior?

@YKX-A
Contributor Author

YKX-A commented Sep 25, 2021

I would like to list my changes below, but I have some trouble determining all the necessary changes (I believe your team is better suited to make these decisions than me 😄):
- handling _ddp_wrapped_model.
- a more consistent way to handle grad_norm and grad_clipping (because both functions need to unscale the gradients).

My changes are listed below:
- I add a new method clip_gradient() and call it after the rescale_gradients() method in Trainer.
- Since _scaler.unscale_ cannot be called twice per optimization step (see here), I detect repeated calls following the approach used in PyTorch itself (here).

# `OptState` is the enum PyTorch's GradScaler uses to track per-optimizer state.
from torch.cuda.amp.grad_scaler import OptState


class GradientDescentTrainer(Trainer):
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        ...
        for batch in batch_group:
            ...
        if len(batch_group_outputs) <= 0:
            ...

        train_loss += batch_loss

        batch_grad_norm = self.rescale_gradients()
        # new line here ==============================
        self.clip_gradient()

    # new method ==============================
    def clip_gradient(self) -> None:
        if isinstance(self._grad_clipping, float):
            # 1. We have to unscale the gradients before clipping them.
            if self._scaler is not None:
                optimizer_state = self._scaler._per_optimizer_states[id(self.optimizer)]
                # 2. `unscale_` shouldn't be called more than once per optimizer per
                #    step, so only call it if it has not already been called.
                if optimizer_state["stage"] is not OptState.UNSCALED:
                    self._scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_value_(
                [p for p in self.model.parameters() if p.grad is not None],
                self._grad_clipping,
            )
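
(For context on the stage check: as far as I can tell, GradScaler.unscale_ raises a RuntimeError if it is called a second time for the same optimizer before update(), and GradScaler.step() itself only calls unscale_ when the optimizer is still in the READY stage, which is the guard mirrored above.)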

@epwalsh
Member

epwalsh commented Sep 29, 2021

Hey @YKX-A, this approach seems fine to me. Please start the PR whenever you're ready and give me a ping when you do.

@github-actions

@epwalsh this is just a friendly ping to make sure you haven't forgotten about this issue 😜

@YKX-A YKX-A closed this as completed Oct 14, 2021