From 46b773a393c423f653887c382e4d55e69627454d Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Wed, 4 Mar 2020 16:41:15 -0800 Subject: [PATCH] refactor namespaces in criterion interface (#1729) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/1672 in part (part 1: [context](https://github.com/pytorch/fairseq/pull/1714#issuecomment-587507040)) ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/1729 Differential Revision: D20049353 Pulled By: myleott fbshipit-source-id: 732077a1cc339c9f7ebe26dae42a7e8d7b5a07b4 --- .../criterions/cross_entropy_acc.py | 7 ++- fairseq/criterions/__init__.py | 2 +- fairseq/criterions/adaptive_loss.py | 10 +++- fairseq/criterions/binary_cross_entropy.py | 10 ++-- fairseq/criterions/composite_loss.py | 10 +++- fairseq/criterions/cross_entropy.py | 7 ++- fairseq/criterions/fairseq_criterion.py | 56 +++++++++++++++++-- .../label_smoothed_cross_entropy.py | 9 +-- ...l_smoothed_cross_entropy_with_alignment.py | 11 ++-- fairseq/criterions/legacy_masked_lm.py | 10 ++-- fairseq/criterions/nat_loss.py | 10 +++- fairseq/criterions/sentence_prediction.py | 4 ++ fairseq/criterions/sentence_ranking.py | 16 +++--- tests/speech_recognition/asr_test_base.py | 2 +- tests/test_label_smoothing.py | 12 ++-- 15 files changed, 123 insertions(+), 53 deletions(-) diff --git 
a/examples/speech_recognition/criterions/cross_entropy_acc.py b/examples/speech_recognition/criterions/cross_entropy_acc.py index f7b46a0aa9..7c4d8ba380 100644 --- a/examples/speech_recognition/criterions/cross_entropy_acc.py +++ b/examples/speech_recognition/criterions/cross_entropy_acc.py @@ -16,8 +16,9 @@ @register_criterion("cross_entropy_acc") class CrossEntropyWithAccCriterion(FairseqCriterion): - def __init__(self, args, task): - super().__init__(args, task) + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg def compute_loss(self, model, net_output, target, reduction, log_probs): # N, T -> N * T @@ -50,7 +51,7 @@ def get_logging_output(self, sample, target, lprobs, loss): ) total = torch.sum(mask) sample_size = ( - sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] ) logging_output = { diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index 618723aeff..1c28780111 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -7,7 +7,7 @@ import os from fairseq import registry -from fairseq.criterions.fairseq_criterion import FairseqCriterion +from fairseq.criterions.fairseq_criterion import FairseqCriterion, LegacyFairseqCriterion build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry( diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index c5c5329e20..33e9317e84 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -17,15 +17,19 @@ class AdaptiveLoss(FairseqCriterion): graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" (http://arxiv.org/abs/1609.04309).""" - def __init__(self, args, task): - super().__init__(args, task) + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = 
sentence_avg + @classmethod + def build_criterion(cls, args, task): if args.ddp_backend == 'c10d': raise Exception( 'AdaptiveLoss is not compatible with the c10d ' 'version of DistributedDataParallel. Please use ' '`--ddp-backend=no_c10d` instead.' ) + return cls(task, args.sentence_avg) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. @@ -64,7 +68,7 @@ def forward(self, model, sample, reduce=True): orig = utils.strip_pad(orig_target, self.padding_idx) ntokens = orig.numel() - sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens + sample_size = sample['target'].size(0) if self.sentence_avg else ntokens logging_output = { 'loss': loss.data, 'ntokens': ntokens, diff --git a/fairseq/criterions/binary_cross_entropy.py b/fairseq/criterions/binary_cross_entropy.py index c7f28e4ecc..285f8ec773 100644 --- a/fairseq/criterions/binary_cross_entropy.py +++ b/fairseq/criterions/binary_cross_entropy.py @@ -15,11 +15,11 @@ @register_criterion('binary_cross_entropy') class BinaryCrossEntropyCriterion(FairseqCriterion): - def __init__(self, args, task): - super().__init__(args, task) - self.infonce = getattr(args, "infonce", False) - self.loss_weights = None if getattr(args, 'loss_weights', None) is None else eval(args.loss_weights) - self.log_keys = [] if getattr(args, 'log_keys', None) is None else eval(args.log_keys) + def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): + super().__init__(task) + self.infonce = infonce + self.loss_weights = None if loss_weights is None else eval(loss_weights) + self.log_keys = [] if log_keys is None else eval(log_keys) @staticmethod def add_args(parser): diff --git a/fairseq/criterions/composite_loss.py b/fairseq/criterions/composite_loss.py index ce61723f79..6671c696e9 100644 --- a/fairseq/criterions/composite_loss.py +++ b/fairseq/criterions/composite_loss.py @@ -14,6 +14,10 @@ class CompositeLoss(FairseqCriterion): """This is a composite loss that, given 
a list of model outputs and a list of targets, computes an average of losses for each output-target pair""" + def __init__(self, task, underlying_criterion): + super().__init__(task) + self.underlying_criterion = underlying_criterion + @staticmethod def add_args(parser): """Add criterion-specific arguments to the parser.""" @@ -58,8 +62,8 @@ def decoder(self): class _CompositeLoss(FairseqCriterion): - def __init__(self, args, task, underlying_criterion): - super().__init__(args, task) + def __init__(self, task, underlying_criterion): + super().__init__(task) self.underlying_criterion = underlying_criterion def forward(self, model, sample, reduce=True): @@ -92,4 +96,4 @@ def aggregate_logging_outputs(logging_outputs): def reduce_metrics(logging_outputs) -> None: underlying_criterion.__class__.reduce_metrics(logging_outputs) - return _CompositeLoss(args, task, underlying_criterion) + return _CompositeLoss(task, underlying_criterion) diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py index dec9131c60..4de5fc5334 100644 --- a/fairseq/criterions/cross_entropy.py +++ b/fairseq/criterions/cross_entropy.py @@ -14,8 +14,9 @@ @register_criterion('cross_entropy') class CrossEntropyCriterion(FairseqCriterion): - def __init__(self, args, task): - super().__init__(args, task) + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
@@ -27,7 +28,7 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample['net_input']) loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens'] logging_output = { 'loss': loss.data, 'ntokens': sample['ntokens'], diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index 9d240b656e..9873574d47 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import inspect from typing import Any, Dict, List from torch.nn.modules.loss import _Loss @@ -12,11 +13,12 @@ class FairseqCriterion(_Loss): - def __init__(self, args, task): + def __init__(self, task): super().__init__() - self.args = args self.task = task - self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100 + if hasattr(task, 'target_dictionary'): + tgt_dict = task.target_dictionary + self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100 @staticmethod def add_args(parser): @@ -25,7 +27,35 @@ def add_args(parser): @classmethod def build_criterion(cls, args, task): - return cls(args, task) + """Construct a criterion from command-line args.""" + # Criterions can override this, but for convenience we also try + # to automatically map argparse.Namespace keys to corresponding + # arguments in the __init__. 
+ init_args = {} + for p in inspect.signature(cls).parameters.values(): + if ( + p.kind == p.POSITIONAL_ONLY + or p.kind == p.VAR_POSITIONAL + or p.kind == p.VAR_KEYWORD + ): + # we haven't implemented inference for these argument types, + # but PRs welcome :) + raise NotImplementedError('{} not supported'.format(p.kind)) + + assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} + + if p.name == 'task': + init_args['task'] = task + elif hasattr(args, p.name): + init_args[p.name] = getattr(args, p.name) + elif p.default != p.empty: + pass # we'll use the default value + else: + raise NotImplementedError( + 'Unable to infer Criterion arguments, please implement ' + '{}.build_criterion'.format(cls.__name__) + ) + return cls(**init_args) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. @@ -69,3 +99,21 @@ def logging_outputs_can_be_summed() -> bool: to True will improves distributed training speed. """ return False + + +class LegacyFairseqCriterion(FairseqCriterion): + + def __init__(self, args, task): + super().__init__(task=task) + self.args = args + + utils.deprecation_warning( + 'Criterions should take explicit arguments instead of an ' + 'argparse.Namespace object, please update your criterion by ' + 'extending FairseqCriterion instead of LegacyFairseqCriterion.' 
+ ) + + @classmethod + def build_criterion(cls, args, task): + """Construct a criterion from command-line args.""" + return cls(args, task) diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py index 6f1ef57cf4..b9587f6cca 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -33,9 +33,10 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T @register_criterion('label_smoothed_cross_entropy') class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): - def __init__(self, args, task): - super().__init__(args, task) - self.eps = args.label_smoothing + def __init__(self, task, sentence_avg, label_smoothing): + super().__init__(task) + self.sentence_avg = sentence_avg + self.eps = label_smoothing @staticmethod def add_args(parser): @@ -55,7 +56,7 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample['net_input']) loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens'] logging_output = { 'loss': loss.data, 'nll_loss': nll_loss.data, diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py index dc75175c79..cfc7e008cd 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py @@ -14,15 +14,14 @@ @register_criterion('label_smoothed_cross_entropy_with_alignment') class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion): - def __init__(self, args, task): - super().__init__(args, task) - self.alignment_lambda = args.alignment_lambda + def __init__(self, task, sentence_avg, 
label_smoothing, alignment_lambda): + super().__init__(task, sentence_avg, label_smoothing) + self.alignment_lambda = alignment_lambda @staticmethod def add_args(parser): """Add criterion-specific arguments to the parser.""" - super(LabelSmoothedCrossEntropyCriterionWithAlignment, - LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser) + LabelSmoothedCrossEntropyCriterion.add_args(parser) parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D', help='weight for the alignment loss') @@ -36,7 +35,7 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample['net_input']) loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens'] logging_output = { 'loss': utils.item(loss.data) if reduce else loss.data, 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, diff --git a/fairseq/criterions/legacy_masked_lm.py b/fairseq/criterions/legacy_masked_lm.py index ef62e5050f..10dea76e4b 100644 --- a/fairseq/criterions/legacy_masked_lm.py +++ b/fairseq/criterions/legacy_masked_lm.py @@ -48,8 +48,10 @@ class LegacyMaskedLmLoss(FairseqCriterion): an argument. 
""" - def __init__(self, args, task): - super().__init__(args, task) + def __init__(self, task, masked_lm_only, nsp_loss_weight): + super().__init__(task) + self.masked_lm_only = masked_lm_only + self.nsp_loss_weight = nsp_loss_weight @staticmethod def add_args(parser): @@ -85,7 +87,7 @@ def forward(self, model, sample, reduce=True): # Compute sentence loss if masked_lm_only is False sentence_loss = None - if not self.args.masked_lm_only: + if not self.masked_lm_only: sentence_logits = output_metadata['sentence_logits'] sentence_targets = sample['sentence_target'].view(-1) # This needs to be recomputed due to some differences between @@ -102,7 +104,7 @@ def forward(self, model, sample, reduce=True): sentence_loss = compute_cross_entropy_loss( sentence_logits, sentence_targets) - loss += self.args.nsp_loss_weight * (sentence_loss / nsentences) + loss += self.nsp_loss_weight * (sentence_loss / nsentences) # NOTE: as we are summing up per token mlm loss and per sentence nsp loss # we don't need to use sample_size as denominator for the gradient diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py index d84cc4d623..3326734d55 100644 --- a/fairseq/criterions/nat_loss.py +++ b/fairseq/criterions/nat_loss.py @@ -15,17 +15,21 @@ @register_criterion("nat_loss") class LabelSmoothedDualImitationCriterion(FairseqCriterion): + + def __init__(self, task, label_smoothing): + super().__init__(task) + self.label_smoothing = label_smoothing + @staticmethod def add_args(parser): """Add criterion-specific arguments to the parser.""" - # fmt: off parser.add_argument( '--label-smoothing', default=0., type=float, metavar='D', - help='epsilon for label smoothing, 0 means no label smoothing') - # fmt: on + help='epsilon for label smoothing, 0 means no label smoothing', + ) def _compute_loss( self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0 diff --git a/fairseq/criterions/sentence_prediction.py 
b/fairseq/criterions/sentence_prediction.py index 933c4e9bc8..8bdf204f57 100644 --- a/fairseq/criterions/sentence_prediction.py +++ b/fairseq/criterions/sentence_prediction.py @@ -15,6 +15,10 @@ @register_criterion('sentence_prediction') class SentencePredictionCriterion(FairseqCriterion): + def __init__(self, task, classification_head_name): + super().__init__(task) + self.classification_head_name = classification_head_name + @staticmethod def add_args(parser): # fmt: off diff --git a/fairseq/criterions/sentence_ranking.py b/fairseq/criterions/sentence_ranking.py index 86cae13cfe..13dcac9f9e 100644 --- a/fairseq/criterions/sentence_ranking.py +++ b/fairseq/criterions/sentence_ranking.py @@ -15,12 +15,14 @@ @register_criterion('sentence_ranking') class SentenceRankingCriterion(FairseqCriterion): - def __init__(self, args, task): - super().__init__(args, task) - if self.args.save_predictions is not None: - self.prediction_h = open(self.args.save_predictions, 'w') + def __init__(self, task, ranking_head_name, save_predictions, num_classes): + super().__init__(task) + self.ranking_head_name = ranking_head_name + if save_predictions is not None: + self.prediction_h = open(save_predictions, 'w') else: self.prediction_h = None + self.num_classes = num_classes def __del__(self): if self.prediction_h is not None: @@ -46,14 +48,14 @@ def forward(self, model, sample, reduce=True): """ assert ( hasattr(model, 'classification_heads') - and self.args.ranking_head_name in model.classification_heads + and self.ranking_head_name in model.classification_heads ), 'model must provide sentence ranking head for --criterion=sentence_ranking' scores = [] - for idx in range(self.args.num_classes): + for idx in range(self.num_classes): score, _ = model( **sample['net_input{idx}'.format(idx=idx+1)], - classification_head_name=self.args.ranking_head_name, + classification_head_name=self.ranking_head_name, ) scores.append(score) diff --git a/tests/speech_recognition/asr_test_base.py 
b/tests/speech_recognition/asr_test_base.py index ffd0133b99..7482858ffc 100644 --- a/tests/speech_recognition/asr_test_base.py +++ b/tests/speech_recognition/asr_test_base.py @@ -516,7 +516,7 @@ def setUpArgs(self): def setUp(self): args = self.setUpArgs() self.model = DummyEncoderModel(encoder=DummyEncoder()) - self.criterion = self.criterion_cls(args=args, task=DummyTask(args)) + self.criterion = self.criterion_cls.build_criterion(args=args, task=DummyTask(args)) def get_src_tokens(self, correct_prediction, aggregate): """ diff --git a/tests/test_label_smoothing.py b/tests/test_label_smoothing.py index 38a627c76c..8432d3c7bf 100644 --- a/tests/test_label_smoothing.py +++ b/tests/test_label_smoothing.py @@ -49,8 +49,8 @@ def setUp(self): def test_nll_loss(self): self.args.label_smoothing = 0.1 - nll_crit = CrossEntropyCriterion(self.args, self.task) - smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) + nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task) + smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task) nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample) smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample) self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6) @@ -58,7 +58,7 @@ def test_nll_loss(self): def test_padding(self): self.args.label_smoothing = 0.1 - crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) + crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task) loss, _, logging_output = crit(self.model, self.sample) def get_one_no_padding(idx): @@ -77,15 +77,15 @@ def get_one_no_padding(idx): def test_reduction(self): self.args.label_smoothing = 0.1 - crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) + crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task) loss, _, logging_output = crit(self.model, self.sample, reduce=True) 
unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False) self.assertAlmostEqual(loss, unreduced_loss.sum()) def test_zero_eps(self): self.args.label_smoothing = 0.0 - nll_crit = CrossEntropyCriterion(self.args, self.task) - smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) + nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task) + smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task) nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample) smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample) self.assertAlmostEqual(nll_loss, smooth_loss)