[Trainer] Make sure shown loss in distributed training is correctly averaged over all workers (#13681)

* push

* improve tr loss gather
patrickvonplaten authored Sep 26, 2021
1 parent 044eff5 commit 91df455
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/trainer.py
@@ -1462,7 +1462,10 @@ def _load_state_dict_in_model(self, state_dict):
     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log:
             logs: Dict[str, float] = {}
-            tr_loss_scalar = tr_loss.item()
+
+            # all_gather + mean() to get average loss over all processes
+            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
             # reset tr_loss to zero
             tr_loss -= tr_loss
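The effect of the change can be sketched without a real distributed setup: before the fix, the logged scalar came from `tr_loss.item()`, i.e. only the local worker's loss; after it, losses are gathered from all workers and averaged. A minimal single-process simulation follows; the `nested_gather` helper here is a hypothetical stand-in for `Trainer._nested_gather`, which in actual distributed training performs an all_gather across processes:

```python
# Simulate three workers that each accumulated a different local tr_loss.
local_losses = [1.2, 0.8, 1.0]

def nested_gather(per_worker_losses):
    # Hypothetical stand-in for Trainer._nested_gather: in real
    # distributed training this would be an all_gather over processes.
    return per_worker_losses

# Before the fix: rank 0 logged only its own local loss.
shown_before = local_losses[0]

# After the fix: gather the losses from all workers, then average.
gathered = nested_gather(local_losses)
shown_after = sum(gathered) / len(gathered)

print(shown_before, shown_after)  # 1.2 1.0
```

With the fix, the displayed training loss is the same on every worker and reflects the whole batch, rather than whichever shard happened to live on the logging process.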
