Deprecate --fast-stat-sync and replace with Criterion.logging_outputs_can_be_summed

Summary: Pull Request resolved: fairinternal/fairseq-py#980

Differential Revision: D19351116

Pulled By: myleott

fbshipit-source-id: a67b10637f53a80c37b0ce90eb27ced9709871db
myleott authored and facebook-github-bot committed Jan 11, 2020
1 parent c9a7c06 commit fe6c2ed
Showing 12 changed files with 167 additions and 70 deletions.
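
For training setups that previously passed --fast-stat-sync, the opt-in now lives on the criterion, as the file diffs below show. The following is a minimal sketch of such a criterion and is not part of this commit: the name 'my_criterion' and the class MyCriterion are hypothetical, forward() is elided, and only the FairseqCriterion base class and register_criterion decorator from fairseq are assumed.

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('my_criterion')  # hypothetical name, not part of this commit
class MyCriterion(FairseqCriterion):

    # forward(model, sample, reduce=True) is omitted here; it would return
    # (loss, sample_size, logging_output) with only scalar, additive values
    # in logging_output.

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        return {
            'loss': sum(log.get('loss', 0) for log in logging_outputs),
            'ntokens': sum(log.get('ntokens', 0) for log in logging_outputs),
            'nsentences': sum(log.get('nsentences', 0) for log in logging_outputs),
            'sample_size': sum(log.get('sample_size', 0) for log in logging_outputs),
        }

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        # Safe here because every logged value is a scalar that adds across workers.
        return True

Passing --fast-stat-sync on the command line still parses, but as of this commit it only triggers a deprecation warning in Trainer.__init__; the fast path is selected solely by the criterion's logging_outputs_can_be_summed().
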
9 changes: 9 additions & 0 deletions fairseq/criterions/adaptive_loss.py
@@ -90,3 +90,12 @@ def aggregate_logging_outputs(logging_outputs):
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.
return agg_output

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
9 changes: 9 additions & 0 deletions fairseq/criterions/cross_entropy.py
@@ -65,3 +65,12 @@ def aggregate_logging_outputs(logging_outputs):
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
9 changes: 9 additions & 0 deletions fairseq/criterions/fairseq_criterion.py
@@ -37,3 +37,12 @@ def forward(self, model, sample, reduce=True):
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return False
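
A brief illustration, not from the commit, of what "can be summed" means for the base-class default above: when every value in the logging output is a scalar counter, summing the dicts element-wise across workers and then aggregating once gives the same result as gathering every dict and aggregating them all, so a criterion may safely return True. Criteria whose logging outputs carry non-additive values (per-example lists, pre-averaged metrics) should keep the default False so the Trainer falls back to the slower all_gather path. The worker values below are invented.

# Two workers' hypothetical logging outputs; all values are additive scalars.
worker_a = {'loss': 10.0, 'ntokens': 4096, 'nsentences': 128, 'sample_size': 4096}
worker_b = {'loss': 12.0, 'ntokens': 4000, 'nsentences': 125, 'sample_size': 4000}

summed = {k: worker_a[k] + worker_b[k] for k in worker_a}

# Aggregating the single summed dict matches aggregating the gathered dicts,
# e.g. the per-token loss is identical either way.
assert (summed['loss'] / summed['sample_size']
        == (worker_a['loss'] + worker_b['loss'])
        / (worker_a['sample_size'] + worker_b['sample_size']))
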
9 changes: 9 additions & 0 deletions fairseq/criterions/label_smoothed_cross_entropy.py
@@ -88,3 +88,12 @@ def aggregate_logging_outputs(logging_outputs):
'nsentences': nsentences,
'sample_size': sample_size,
}

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
@@ -88,3 +88,12 @@ def aggregate_logging_outputs(logging_outputs):
'nsentences': nsentences,
'sample_size': sample_size,
}

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
9 changes: 9 additions & 0 deletions fairseq/criterions/legacy_masked_lm.py
@@ -145,3 +145,12 @@ def aggregate_logging_outputs(logging_outputs):
'sample_size': sample_size,
}
return agg_output

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
9 changes: 9 additions & 0 deletions fairseq/criterions/masked_lm.py
@@ -79,3 +79,12 @@ def aggregate_logging_outputs(logging_outputs):
'sample_size': sample_size,
}
return agg_output

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
9 changes: 9 additions & 0 deletions fairseq/criterions/nat_loss.py
@@ -167,3 +167,12 @@ def aggregate_logging_outputs(logging_outputs):
)

return results

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
9 changes: 9 additions & 0 deletions fairseq/criterions/sentence_prediction.py
@@ -96,3 +96,12 @@ def aggregate_logging_outputs(logging_outputs):
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
9 changes: 9 additions & 0 deletions fairseq/criterions/sentence_ranking.py
@@ -115,3 +115,12 @@ def aggregate_logging_outputs(logging_outputs):
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improve distributed training speed.
"""
return True
3 changes: 1 addition & 2 deletions fairseq/options.py
@@ -345,8 +345,7 @@ def add_distributed_training_args(parser):
help='disable unused parameter detection (not applicable to '
'no_c10d ddp-backend')
group.add_argument('--fast-stat-sync', default=False, action='store_true',
help='Enable fast sync of stats between nodes, this hardcodes to '
'sync only some default stats from logging_output.')
help='[deprecated] this is now defined per Criterion')
# fmt: on
return group

144 changes: 76 additions & 68 deletions fairseq/trainer.py
@@ -53,14 +53,19 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=Non
self._num_updates = 0
self._optim_history = None
self._optimizer = None
self._prev_grad_norm = None
self._wrapped_criterion = None
self._wrapped_model = None

# Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
# It is less flexible and syncs only the default stats.
self._all_reduce_list = [0.0] * 6
self.fast_stat_sync = args.fast_stat_sync
if self.cuda and args.distributed_world_size > 1:
self._grad_norm_buf = torch.cuda.DoubleTensor(args.distributed_world_size)
else:
self._grad_norm_buf = None

if args.fast_stat_sync:
utils.deprecation_warning(
'--fast-stat-sync is deprecated. If needed, please update your '
'Criterion to define the logging_outputs_can_be_summed() method.'
)

self.init_meters(args)

@@ -294,7 +299,7 @@ def train_step(self, samples, dummy_batch=False, raise_oom=False):
self.meters["train_wall"].start()

# forward and backward pass
logging_outputs, sample_sizes, ooms = [], [], 0
logging_outputs, sample_size, ooms = [], 0, 0
for i, sample in enumerate(samples):
sample = self._prepare_sample(sample)
if sample is None:
@@ -323,22 +328,13 @@ def maybe_no_sync():
try:
with maybe_no_sync():
# forward and backward
loss, sample_size, logging_output = self.task.train_step(
loss, sample_size_i, logging_output = self.task.train_step(
sample, self.model, self.criterion, self.optimizer, ignore_grad
)

if not ignore_grad:
logging_outputs.append(logging_output)
sample_sizes.append(sample_size)

if self.fast_stat_sync:
self._all_reduce_list[0] += sample_size
self._all_reduce_list[1] += logging_output.get(
"nsentences", 0.0
)
self._all_reduce_list[2] += logging_output.get("loss", 0.0)
self._all_reduce_list[3] += logging_output.get("nll_loss", 0.0)
self._all_reduce_list[4] += logging_output.get("ntokens", 0.0)
sample_size += sample_size_i
except RuntimeError as e:
if "out of memory" in str(e):
self._log_oom(e)
@@ -353,71 +349,28 @@ def maybe_no_sync():
else:
raise e

if self.fast_stat_sync:
self._all_reduce_list[5] += ooms

if ooms > 0 and self._oom_batch is not None:
self.handle_ooms(ooms)

if dummy_batch:
return None

# gather logging outputs from all replicas
if self.fast_stat_sync:
# rework all_gather_list
all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
if self._sync_stats():
torch.distributed.all_reduce(all_reduce_list_tensor)
# Normalize loss and nll_loss by "sample_size"
# and convert to log base 2
all_reduce_list_tensor[2:4].div_(
(all_reduce_list_tensor[0:1] * torch.log(torch.cuda.DoubleTensor([2])))
if self._sync_stats():
logging_outputs, sample_size, ooms = self._aggregate_logging_outputs(
logging_outputs, sample_size, ooms,
)
self._all_reduce_list = all_reduce_list_tensor.tolist()
logging_output = {}
[
sample_size,
logging_output["nsentences"],
logging_output["loss"],
logging_output["nll_loss"],
logging_output["ntokens"],
ooms,
] = self._all_reduce_list
elif self._sync_stats():
logging_outputs, sample_sizes, ooms, prev_norms = zip(
*distributed_utils.all_gather_list(
[logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
max_size=getattr(self.args, 'all_gather_list_size', 16384),
)
)
logging_outputs = list(chain.from_iterable(logging_outputs))
sample_sizes = list(chain.from_iterable(sample_sizes))
ooms = sum(ooms)

if not self.args.use_bmuf:
norms = [norm for norm in prev_norms if norm is not None]
if not (
all(norm == norms[0] for norm in norms)
or all(math.isnan(norm) or math.isinf(norm) for norm in norms)
):
raise RuntimeError(
"Fatal error: gradients are inconsistent between workers. "
"Try --ddp-backend=no_c10d, which is a more robust but "
"slightly slower DDP implementation."
)

self.meters["oom"].update(ooms, len(samples))
if ooms == self.args.distributed_world_size * len(samples):
print("| WARNING: OOM in all workers, skipping update")
self.zero_grad()
return None

if not self.fast_stat_sync:
# aggregate logging outputs and sample sizes
logging_output = self.task.aggregate_logging_outputs(
logging_outputs, self.get_criterion()
)
sample_size = sum(sample_sizes)
# aggregate logging outputs and sample sizes
logging_output = self.task.aggregate_logging_outputs(
logging_outputs, self.get_criterion()
)

if not all(k in logging_output for k in ["ntokens", "nsentences"]):
raise Exception(
@@ -442,7 +395,9 @@ def maybe_no_sync():

# clip grads
grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
self._prev_grad_norm = grad_norm

# check that grad norms are consistent across workers
self._check_grad_norms(grad_norm)

# take an optimization step
self.optimizer.step()
@@ -679,3 +634,56 @@ def _log_oom(self, exc):
for device_idx in range(torch.cuda.device_count()):
print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
sys.stderr.flush()

def _aggregate_logging_outputs(self, logging_outputs, *extra_stats_to_sum):
if self.get_criterion().__class__.logging_outputs_can_be_summed():
return self._fast_stat_sync_sum(logging_outputs, *extra_stats_to_sum)
else:
return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum)

def _all_gather_list_sync(self, logging_outputs, *extra_stats_to_sum):
"""
Sync logging outputs across workers. all_gather_list_sync is
suitable when logging outputs are complex types.
"""
results = list(zip(
*distributed_utils.all_gather_list(
[logging_outputs] + list(extra_stats_to_sum),
max_size=getattr(self.args, 'all_gather_list_size', 16384),
)
))
logging_outputs, extra_stats_to_sum = results[0], results[1:]
logging_outputs = list(chain.from_iterable(logging_outputs))
extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
return [logging_outputs] + extra_stats_to_sum

def _fast_stat_sync_sum(self, logging_outputs, *extra_stats_to_sum):
"""
Sync logging outputs across workers. fast_stat_sync_sum is
faster than all_gather_list_sync, but is only suitable when
logging outputs are scalars and can be summed.
"""
sorted_keys = sorted(logging_outputs[0].keys())
num_extra = len(extra_stats_to_sum)
stats = list(extra_stats_to_sum) + [
sum(log.get(k, 0) for log in logging_outputs)
for k in sorted_keys
]
buf = torch.cuda.DoubleTensor(stats)
distributed_utils.all_reduce(buf)
buf = buf.tolist()
extra_stats_to_sum, stats = buf[:num_extra], buf[num_extra:]
stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
return [stats] + extra_stats_to_sum

def _check_grad_norms(self, grad_norm):
"""Check that grad norms are consistent across workers."""
if self._grad_norm_buf is not None:
self._grad_norm_buf.zero_()
self._grad_norm_buf[self.args.distributed_rank] = grad_norm
distributed_utils.all_reduce(self._grad_norm_buf)
if not (self._grad_norm_buf == self._grad_norm_buf[0]).all():
raise RuntimeError(
"Fatal error: gradients are inconsistent between workers. "
"Try --ddp-backend=no_c10d."
)
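
To make the new fast path above concrete, here is a minimal, self-contained sketch (not part of the commit) that mimics what _fast_stat_sync_sum computes, simulating two workers in-process instead of calling torch.distributed.all_reduce; the worker data below is invented.

import torch

def simulated_fast_stat_sync(per_worker_logs, per_worker_extra):
    # Each worker sums its own scalar logging outputs into a flat vector
    # (extra stats such as sample_size and ooms go in front), so that a single
    # element-wise all_reduce over that vector replaces the old all_gather of dicts.
    sorted_keys = sorted(per_worker_logs[0][0].keys())
    num_extra = len(per_worker_extra[0])
    buffers = []
    for logs, extra in zip(per_worker_logs, per_worker_extra):
        stats = list(extra) + [sum(log.get(k, 0) for log in logs) for k in sorted_keys]
        buffers.append(torch.DoubleTensor(stats))
    # torch.distributed.all_reduce would leave this element-wise sum on every worker
    total = torch.stack(buffers).sum(dim=0).tolist()
    extra, stats = total[:num_extra], total[num_extra:]
    return [[dict(zip(sorted_keys, stats))]] + extra

# Hypothetical inputs: two workers, one logging output each, plus (sample_size, ooms).
logs = [[{'loss': 10.0, 'ntokens': 4096}], [{'loss': 12.0, 'ntokens': 4000}]]
extra = [(4096, 0), (4000, 0)]
print(simulated_fast_stat_sync(logs, extra))
# [[{'loss': 22.0, 'ntokens': 8096.0}], 8096.0, 0.0]

The real method differs only in using torch.cuda.DoubleTensor and distributed_utils.all_reduce, and it is chosen only when the criterion reports logging_outputs_can_be_summed(); otherwise the Trainer keeps the all_gather_list path for complex logging outputs.
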
