diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py
index 6f282001c0..fb742c7dd9 100644
--- a/fairseq/criterions/adaptive_loss.py
+++ b/fairseq/criterions/adaptive_loss.py
@@ -90,3 +90,12 @@ def aggregate_logging_outputs(logging_outputs):
         if sample_size != ntokens:
             agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.
         return agg_output
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True
diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py
index 1996e9edf3..0fbac41c0c 100644
--- a/fairseq/criterions/cross_entropy.py
+++ b/fairseq/criterions/cross_entropy.py
@@ -65,3 +65,12 @@ def aggregate_logging_outputs(logging_outputs):
         if sample_size != ntokens:
             agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
         return agg_output
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True
diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py
index 9645ec47a6..3cab013b23 100644
--- a/fairseq/criterions/fairseq_criterion.py
+++ b/fairseq/criterions/fairseq_criterion.py
@@ -37,3 +37,12 @@ def forward(self, model, sample, reduce=True):
     def aggregate_logging_outputs(logging_outputs):
         """Aggregate logging outputs from data parallel training."""
         raise NotImplementedError
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return False
diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py
index a54c7516c9..ec50834342 100644
--- a/fairseq/criterions/label_smoothed_cross_entropy.py
+++ b/fairseq/criterions/label_smoothed_cross_entropy.py
@@ -88,3 +88,12 @@ def aggregate_logging_outputs(logging_outputs):
             'nsentences': nsentences,
             'sample_size': sample_size,
         }
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True
diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
index 2cb5621498..1e2cb8e335 100644
--- a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
+++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
@@ -88,3 +88,12 @@ def aggregate_logging_outputs(logging_outputs):
             'nsentences': nsentences,
             'sample_size': sample_size,
         }
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+ """ + return True diff --git a/fairseq/criterions/legacy_masked_lm.py b/fairseq/criterions/legacy_masked_lm.py index fe3c7bf2a6..b24e19ea01 100644 --- a/fairseq/criterions/legacy_masked_lm.py +++ b/fairseq/criterions/legacy_masked_lm.py @@ -145,3 +145,12 @@ def aggregate_logging_outputs(logging_outputs): 'sample_size': sample_size, } return agg_output + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py index eb2fcf3d3a..4a7ea8b3c4 100644 --- a/fairseq/criterions/masked_lm.py +++ b/fairseq/criterions/masked_lm.py @@ -79,3 +79,12 @@ def aggregate_logging_outputs(logging_outputs): 'sample_size': sample_size, } return agg_output + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py index 6c4c81702c..379bed9bdf 100644 --- a/fairseq/criterions/nat_loss.py +++ b/fairseq/criterions/nat_loss.py @@ -167,3 +167,12 @@ def aggregate_logging_outputs(logging_outputs): ) return results + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/sentence_prediction.py b/fairseq/criterions/sentence_prediction.py index 11678c8ca3..d5450ed946 100644 --- a/fairseq/criterions/sentence_prediction.py +++ b/fairseq/criterions/sentence_prediction.py @@ -96,3 +96,12 @@ def aggregate_logging_outputs(logging_outputs): if sample_size != ntokens: agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) return agg_output + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/sentence_ranking.py b/fairseq/criterions/sentence_ranking.py index 5bb2119dba..c90e872e87 100644 --- a/fairseq/criterions/sentence_ranking.py +++ b/fairseq/criterions/sentence_ranking.py @@ -115,3 +115,12 @@ def aggregate_logging_outputs(logging_outputs): if sample_size != ntokens: agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) return agg_output + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. 
+ """ + return True diff --git a/fairseq/options.py b/fairseq/options.py index 005d45d75a..eef572a04f 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -345,8 +345,7 @@ def add_distributed_training_args(parser): help='disable unused parameter detection (not applicable to ' 'no_c10d ddp-backend') group.add_argument('--fast-stat-sync', default=False, action='store_true', - help='Enable fast sync of stats between nodes, this hardcodes to ' - 'sync only some default stats from logging_output.') + help='[deprecated] this is now defined per Criterion') # fmt: on return group diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 34ee1fd3fa..39f57caf5b 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -53,14 +53,19 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=Non self._num_updates = 0 self._optim_history = None self._optimizer = None - self._prev_grad_norm = None self._wrapped_criterion = None self._wrapped_model = None - # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes. - # It is less flexible and syncs only the default stats. - self._all_reduce_list = [0.0] * 6 - self.fast_stat_sync = args.fast_stat_sync + if self.cuda and args.distributed_world_size > 1: + self._grad_norm_buf = torch.cuda.DoubleTensor(args.distributed_world_size) + else: + self._grad_norm_buf = None + + if args.fast_stat_sync: + utils.deprecation_warning( + '--fast-stat-sync is deprecated. If needed, please update your ' + 'Criterion to define the logging_outputs_can_be_summed() method.' + ) self.init_meters(args) @@ -294,7 +299,7 @@ def train_step(self, samples, dummy_batch=False, raise_oom=False): self.meters["train_wall"].start() # forward and backward pass - logging_outputs, sample_sizes, ooms = [], [], 0 + logging_outputs, sample_size, ooms = [], 0, 0 for i, sample in enumerate(samples): sample = self._prepare_sample(sample) if sample is None: @@ -323,22 +328,13 @@ def maybe_no_sync(): try: with maybe_no_sync(): # forward and backward - loss, sample_size, logging_output = self.task.train_step( + loss, sample_size_i, logging_output = self.task.train_step( sample, self.model, self.criterion, self.optimizer, ignore_grad ) if not ignore_grad: logging_outputs.append(logging_output) - sample_sizes.append(sample_size) - - if self.fast_stat_sync: - self._all_reduce_list[0] += sample_size - self._all_reduce_list[1] += logging_output.get( - "nsentences", 0.0 - ) - self._all_reduce_list[2] += logging_output.get("loss", 0.0) - self._all_reduce_list[3] += logging_output.get("nll_loss", 0.0) - self._all_reduce_list[4] += logging_output.get("ntokens", 0.0) + sample_size += sample_size_i except RuntimeError as e: if "out of memory" in str(e): self._log_oom(e) @@ -353,9 +349,6 @@ def maybe_no_sync(): else: raise e - if self.fast_stat_sync: - self._all_reduce_list[5] += ooms - if ooms > 0 and self._oom_batch is not None: self.handle_ooms(ooms) @@ -363,48 +356,10 @@ def maybe_no_sync(): return None # gather logging outputs from all replicas - if self.fast_stat_sync: - # rework all_gather_list - all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list) - if self._sync_stats(): - torch.distributed.all_reduce(all_reduce_list_tensor) - # Normalize loss and nll_loss by "sample_size" - # and convert to log base 2 - all_reduce_list_tensor[2:4].div_( - (all_reduce_list_tensor[0:1] * torch.log(torch.cuda.DoubleTensor([2]))) + if self._sync_stats(): + logging_outputs, sample_size, ooms = self._aggregate_logging_outputs( + logging_outputs, sample_size, 
-                )
-            self._all_reduce_list = all_reduce_list_tensor.tolist()
-            logging_output = {}
-            [
-                sample_size,
-                logging_output["nsentences"],
-                logging_output["loss"],
-                logging_output["nll_loss"],
-                logging_output["ntokens"],
-                ooms,
-            ] = self._all_reduce_list
-        elif self._sync_stats():
-            logging_outputs, sample_sizes, ooms, prev_norms = zip(
-                *distributed_utils.all_gather_list(
-                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
-                    max_size=getattr(self.args, 'all_gather_list_size', 16384),
-                )
-            )
-            logging_outputs = list(chain.from_iterable(logging_outputs))
-            sample_sizes = list(chain.from_iterable(sample_sizes))
-            ooms = sum(ooms)
-
-            if not self.args.use_bmuf:
-                norms = [norm for norm in prev_norms if norm is not None]
-                if not (
-                    all(norm == norms[0] for norm in norms)
-                    or all(math.isnan(norm) or math.isinf(norm) for norm in norms)
-                ):
-                    raise RuntimeError(
-                        "Fatal error: gradients are inconsistent between workers. "
-                        "Try --ddp-backend=no_c10d, which is a more robust but "
-                        "slightly slower DDP implementation."
-                    )
 
         self.meters["oom"].update(ooms, len(samples))
         if ooms == self.args.distributed_world_size * len(samples):
@@ -412,12 +367,10 @@ def maybe_no_sync():
             self.zero_grad()
             return None
 
-        if not self.fast_stat_sync:
-            # aggregate logging outputs and sample sizes
-            logging_output = self.task.aggregate_logging_outputs(
-                logging_outputs, self.get_criterion()
-            )
-            sample_size = sum(sample_sizes)
+        # aggregate logging outputs and sample sizes
+        logging_output = self.task.aggregate_logging_outputs(
+            logging_outputs, self.get_criterion()
+        )
 
         if not all(k in logging_output for k in ["ntokens", "nsentences"]):
             raise Exception(
@@ -442,7 +395,9 @@ def maybe_no_sync():
 
             # clip grads
             grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
-            self._prev_grad_norm = grad_norm
+
+            # check that grad norms are consistent across workers
+            self._check_grad_norms(grad_norm)
 
             # take an optimization step
             self.optimizer.step()
@@ -679,3 +634,56 @@ def _log_oom(self, exc):
         for device_idx in range(torch.cuda.device_count()):
             print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
         sys.stderr.flush()
+
+    def _aggregate_logging_outputs(self, logging_outputs, *extra_stats_to_sum):
+        if self.get_criterion().__class__.logging_outputs_can_be_summed():
+            return self._fast_stat_sync_sum(logging_outputs, *extra_stats_to_sum)
+        else:
+            return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum)
+
+    def _all_gather_list_sync(self, logging_outputs, *extra_stats_to_sum):
+        """
+        Sync logging outputs across workers. all_gather_list_sync is
+        suitable when logging outputs are complex types.
+        """
+        results = list(zip(
+            *distributed_utils.all_gather_list(
+                [logging_outputs] + list(extra_stats_to_sum),
+                max_size=getattr(self.args, 'all_gather_list_size', 16384),
+            )
+        ))
+        logging_outputs, extra_stats_to_sum = results[0], results[1:]
+        logging_outputs = list(chain.from_iterable(logging_outputs))
+        extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
+        return [logging_outputs] + extra_stats_to_sum
+
+    def _fast_stat_sync_sum(self, logging_outputs, *extra_stats_to_sum):
+        """
+        Sync logging outputs across workers. fast_stat_sync_sum is
+        faster than all_gather_list_sync, but is only suitable when
+        logging outputs are scalars and can be summed.
+ """ + sorted_keys = sorted(logging_outputs[0].keys()) + num_extra = len(extra_stats_to_sum) + stats = list(extra_stats_to_sum) + [ + sum(log.get(k, 0) for log in logging_outputs) + for k in sorted_keys + ] + buf = torch.cuda.DoubleTensor(stats) + distributed_utils.all_reduce(buf) + buf = buf.tolist() + extra_stats_to_sum, stats = buf[:num_extra], buf[num_extra:] + stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}] + return [stats] + extra_stats_to_sum + + def _check_grad_norms(self, grad_norm): + """Check that grad norms are consistent across workers.""" + if self._grad_norm_buf is not None: + self._grad_norm_buf.zero_() + self._grad_norm_buf[self.args.distributed_rank] = grad_norm + distributed_utils.all_reduce(self._grad_norm_buf) + if not (self._grad_norm_buf == self._grad_norm_buf[0]).all(): + raise RuntimeError( + "Fatal error: gradients are inconsistent between workers. " + "Try --ddp-backend=no_c10d." + )