Support multi-GPU validation in fairseq-validate (#2162)
Summary: Pull Request resolved: #2162

Reviewed By: ngoyal2707

Differential Revision: D21663181

Pulled By: myleott

fbshipit-source-id: d01e64f97482f76bd601cd8b20232c0ef637bb8a
myleott authored and facebook-github-bot committed May 27, 2020
1 parent be5313a commit 2f7e3f3
Showing 3 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion fairseq/criterions/adaptive_loss.py
@@ -23,7 +23,7 @@ def __init__(self, task, sentence_avg):

     @classmethod
     def build_criterion(cls, args, task):
-        if args.ddp_backend == 'c10d':
+        if getattr(args, 'ddp_backend', None) == 'c10d':
             raise Exception(
                 'AdaptiveLoss is not compatible with the c10d '
                 'version of DistributedDataParallel. Please use '
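
The getattr guard makes this check robust to args namespaces built outside of training, which may not define ddp_backend at all. A minimal sketch of the difference, using bare argparse.Namespace objects as stand-ins for fairseq's args:

from argparse import Namespace

train_args = Namespace(ddp_backend='c10d')  # training configs define the flag
valid_args = Namespace()                    # other configs may not

# Direct access, valid_args.ddp_backend, raises AttributeError;
# getattr with a default falls through to None and the check passes.
print(getattr(train_args, 'ddp_backend', None))  # 'c10d'
print(getattr(valid_args, 'ddp_backend', None))  # None
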
1 change: 1 addition & 0 deletions fairseq/options.py
@@ -53,6 +53,7 @@ def get_eval_lm_parser(default_task="language_modeling"):
 def get_validation_parser(default_task=None):
     parser = get_parser("Validation", default_task)
     add_dataset_args(parser, train=True)
+    add_distributed_training_args(parser)
     group = parser.add_argument_group("Evaluation")
     add_common_eval_args(group)
     return parser
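
With add_distributed_training_args wired into get_validation_parser, fairseq-validate now accepts the usual distributed flags. A rough sketch of the effect (placeholder data and checkpoint paths; exact flag sets vary by fairseq version):

from fairseq import options

parser = options.get_validation_parser(default_task='language_modeling')
args = options.parse_args_and_arch(parser, input_args=[
    '/path/to/data',                 # placeholder data directory
    '--path', '/path/to/model.pt',   # placeholder checkpoint
    '--distributed-world-size', '2',
])
print(args.distributed_world_size)  # 2
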
17 changes: 14 additions & 3 deletions fairseq_cli/validate.py
@@ -5,14 +5,15 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from itertools import chain
 import logging
 import sys
 
 import torch
 
 from fairseq import checkpoint_utils, distributed_utils, options, utils
 from fairseq.logging import metrics, progress_bar
-from fairseq.options import add_distributed_training_args
 
 
 logging.basicConfig(
     format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
@@ -32,6 +33,9 @@ def main(args, override_args=None):
     use_fp16 = args.fp16
     use_cuda = torch.cuda.is_available() and not args.cpu
 
+    if use_cuda:
+        torch.cuda.set_device(args.device_id)
+
     if override_args is not None:
         overrides = vars(override_args)
         overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
@@ -80,6 +84,8 @@ def main(args, override_args=None):
             ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
             required_batch_size_multiple=args.required_batch_size_multiple,
             seed=args.seed,
+            num_shards=args.distributed_world_size,
+            shard_id=args.distributed_rank,
             num_workers=args.num_workers,
         ).next_epoch_itr(shuffle=False)
         progress = progress_bar.progress_bar(
@@ -97,6 +103,13 @@ def main(args, override_args=None):
             progress.log(log_output, step=i)
             log_outputs.append(log_output)
 
+        if args.distributed_world_size > 1:
+            log_outputs = distributed_utils.all_gather_list(
+                log_outputs,
+                max_size=getattr(args, 'all_gather_list_size', 16384),
+            )
+            log_outputs = list(chain.from_iterable(log_outputs))
+
         with metrics.aggregate() as agg:
             task.reduce_metrics(log_outputs, criterion)
             log_output = agg.get_smoothed_values()
@@ -106,12 +119,10 @@

 def cli_main():
     parser = options.get_validation_parser()
-    add_distributed_training_args(parser)
     args = options.parse_args_and_arch(parser)
 
     # only override args that are explicitly given on the command line
     override_parser = options.get_validation_parser()
-    add_distributed_training_args(override_parser)
     override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)
 
     distributed_utils.call_main(args, main, override_args=override_args)
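
Three pieces of the validate.py diff carry the multi-GPU behavior. First, each worker pins itself to its own GPU before loading anything, so later 'cuda' allocations land on that rank's device. A minimal sketch, with a literal device index standing in for args.device_id:

import torch

if torch.cuda.is_available():
    device_id = 0  # stand-in for args.device_id (this process's local rank)
    torch.cuda.set_device(device_id)
    # After set_device, unqualified 'cuda' tensors go to this GPU.
    x = torch.zeros(1, device='cuda')
    print(x.device)  # cuda:0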

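Second, the new num_shards/shard_id arguments split the validation batches across ranks. The stand-in below is not fairseq's ShardedIterator, but it shows the round-robin assignment those two arguments imply:

def shard_batches(batches, num_shards, shard_id):
    # Each rank keeps every num_shards-th batch, offset by its shard_id.
    return batches[shard_id::num_shards]

batches = list(range(10))            # stand-in for validation batches
print(shard_batches(batches, 2, 0))  # rank 0 -> [0, 2, 4, 6, 8]
print(shard_batches(batches, 2, 1))  # rank 1 -> [1, 3, 5, 7, 9]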
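Finally, because each rank now sees only its shard, the per-batch logging outputs are all-gathered before task.reduce_metrics, so every rank reduces over the full validation set. all_gather_list returns one list per rank, hence the chain.from_iterable flattening. A sketch with hard-coded stand-ins for the gathered result (keys illustrative, not fairseq's exact ones):

from itertools import chain

rank0_logs = [{'loss': 2.1, 'ntokens': 3000}]  # pretend rank 0's shard
rank1_logs = [{'loss': 2.3, 'ntokens': 2800}]  # pretend rank 1's shard

gathered = [rank0_logs, rank1_logs]  # shape of all_gather_list's return value
log_outputs = list(chain.from_iterable(gathered))
print(log_outputs)
# [{'loss': 2.1, 'ntokens': 3000}, {'loss': 2.3, 'ntokens': 2800}]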