Skip to content

Commit

Permalink
refactor namespaces in criterion interface (#1729)
Browse files Browse the repository at this point in the history
Summary:
# Before submitting

- [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
Fixes #1672 in part (part 1: [context](#1714 (comment)))

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding �
Pull Request resolved: #1729

Differential Revision: D20049353

Pulled By: myleott

fbshipit-source-id: 732077a1cc339c9f7ebe26dae42a7e8d7b5a07b4
  • Loading branch information
erip authored and facebook-github-bot committed Mar 5, 2020
1 parent aa79bb9 commit 46b773a
Show file tree
Hide file tree
Showing 15 changed files with 123 additions and 53 deletions.
7 changes: 4 additions & 3 deletions examples/speech_recognition/criterions/cross_entropy_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

@register_criterion("cross_entropy_acc")
class CrossEntropyWithAccCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg

def compute_loss(self, model, net_output, target, reduction, log_probs):
# N, T -> N * T
Expand Down Expand Up @@ -50,7 +51,7 @@ def get_logging_output(self, sample, target, lprobs, loss):
)
total = torch.sum(mask)
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)

logging_output = {
Expand Down
2 changes: 1 addition & 1 deletion fairseq/criterions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os

from fairseq import registry
from fairseq.criterions.fairseq_criterion import FairseqCriterion
from fairseq.criterions.fairseq_criterion import FairseqCriterion, LegacyFairseqCriterion


build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry(
Expand Down
10 changes: 7 additions & 3 deletions fairseq/criterions/adaptive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,19 @@ class AdaptiveLoss(FairseqCriterion):
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""

def __init__(self, args, task):
super().__init__(args, task)
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg

@classmethod
def build_criterion(cls, args, task):
if args.ddp_backend == 'c10d':
raise Exception(
'AdaptiveLoss is not compatible with the c10d '
'version of DistributedDataParallel. Please use '
'`--ddp-backend=no_c10d` instead.'
)
return cls(task, args.sentence_avg)

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Expand Down Expand Up @@ -64,7 +68,7 @@ def forward(self, model, sample, reduce=True):

orig = utils.strip_pad(orig_target, self.padding_idx)
ntokens = orig.numel()
sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens
sample_size = sample['target'].size(0) if self.sentence_avg else ntokens
logging_output = {
'loss': loss.data,
'ntokens': ntokens,
Expand Down
10 changes: 5 additions & 5 deletions fairseq/criterions/binary_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
@register_criterion('binary_cross_entropy')
class BinaryCrossEntropyCriterion(FairseqCriterion):

def __init__(self, args, task):
super().__init__(args, task)
self.infonce = getattr(args, "infonce", False)
self.loss_weights = None if getattr(args, 'loss_weights', None) is None else eval(args.loss_weights)
self.log_keys = [] if getattr(args, 'log_keys', None) is None else eval(args.log_keys)
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
super().__init__(task)
self.infonce = infonce
self.loss_weights = None if loss_weights is None else eval(loss_weights)
self.log_keys = [] if log_keys is None else eval(log_keys)

@staticmethod
def add_args(parser):
Expand Down
10 changes: 7 additions & 3 deletions fairseq/criterions/composite_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ class CompositeLoss(FairseqCriterion):
"""This is a composite loss that, given a list of model outputs and a list of targets,
computes an average of losses for each output-target pair"""

def __init__(self, task, underlying_criterion):
super().__init__(task)
self.underlying_criterion = underlying_criterion

@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
Expand Down Expand Up @@ -58,8 +62,8 @@ def decoder(self):

class _CompositeLoss(FairseqCriterion):

def __init__(self, args, task, underlying_criterion):
super().__init__(args, task)
def __init__(self, task, underlying_criterion):
super().__init__(task)
self.underlying_criterion = underlying_criterion

def forward(self, model, sample, reduce=True):
Expand Down Expand Up @@ -92,4 +96,4 @@ def aggregate_logging_outputs(logging_outputs):
def reduce_metrics(logging_outputs) -> None:
underlying_criterion.__class__.reduce_metrics(logging_outputs)

return _CompositeLoss(args, task, underlying_criterion)
return _CompositeLoss(task, underlying_criterion)
7 changes: 4 additions & 3 deletions fairseq/criterions/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
@register_criterion('cross_entropy')
class CrossEntropyCriterion(FairseqCriterion):

def __init__(self, args, task):
super().__init__(args, task)
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Expand All @@ -27,7 +28,7 @@ def forward(self, model, sample, reduce=True):
"""
net_output = model(**sample['net_input'])
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
logging_output = {
'loss': loss.data,
'ntokens': sample['ntokens'],
Expand Down
56 changes: 52 additions & 4 deletions fairseq/criterions/fairseq_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Dict, List

from torch.nn.modules.loss import _Loss
Expand All @@ -12,11 +13,12 @@

class FairseqCriterion(_Loss):

def __init__(self, args, task):
def __init__(self, task):
super().__init__()
self.args = args
self.task = task
self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100
if hasattr(task, 'target_dictionary'):
tgt_dict = task.target_dictionary
self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

@staticmethod
def add_args(parser):
Expand All @@ -25,7 +27,35 @@ def add_args(parser):

@classmethod
def build_criterion(cls, args, task):
return cls(args, task)
"""Construct a criterion from command-line args."""
# Criterions can override this, but for convenience we also try
# to automatically map argparse.Namespace keys to corresponding
# arguments in the __init__.
init_args = {}
for p in inspect.signature(cls).parameters.values():
if (
p.kind == p.POSITIONAL_ONLY
or p.kind == p.VAR_POSITIONAL
or p.kind == p.VAR_KEYWORD
):
# we haven't implemented inference for these argument types,
# but PRs welcome :)
raise NotImplementedError('{} not supported'.format(p.kind))

assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}

if p.name == 'task':
init_args['task'] = task
elif hasattr(args, p.name):
init_args[p.name] = getattr(args, p.name)
elif p.default != p.empty:
pass # we'll use the default value
else:
raise NotImplementedError(
'Unable to infer Criterion arguments, please implement '
'{}.build_criterion'.format(cls.__name__)
)
return cls(**init_args)

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Expand Down Expand Up @@ -69,3 +99,21 @@ def logging_outputs_can_be_summed() -> bool:
to True will improves distributed training speed.
"""
return False


class LegacyFairseqCriterion(FairseqCriterion):

def __init__(self, args, task):
super().__init__(task=task)
self.args = args

utils.deprecation_warning(
'Criterions should take explicit arguments instead of an '
'argparse.Namespace object, please update your criterion by '
'extending FairseqCriterion instead of LegacyFairseqCriterion.'
)

@classmethod
def build_criterion(cls, args, task):
"""Construct a criterion from command-line args."""
return cls(args, task)
9 changes: 5 additions & 4 deletions fairseq/criterions/label_smoothed_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):

def __init__(self, args, task):
super().__init__(args, task)
self.eps = args.label_smoothing
def __init__(self, task, sentence_avg, label_smoothing):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = label_smoothing

@staticmethod
def add_args(parser):
Expand All @@ -55,7 +56,7 @@ def forward(self, model, sample, reduce=True):
"""
net_output = model(**sample['net_input'])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
logging_output = {
'loss': loss.data,
'nll_loss': nll_loss.data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
@register_criterion('label_smoothed_cross_entropy_with_alignment')
class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):

def __init__(self, args, task):
super().__init__(args, task)
self.alignment_lambda = args.alignment_lambda
def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
super().__init__(task, sentence_avg, label_smoothing)
self.alignment_lambda = alignment_lambda

@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
super(LabelSmoothedCrossEntropyCriterionWithAlignment,
LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser)
LabelSmoothedCrossEntropyCriterion.add_args(parser)
parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
help='weight for the alignment loss')

Expand All @@ -36,7 +35,7 @@ def forward(self, model, sample, reduce=True):
"""
net_output = model(**sample['net_input'])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
Expand Down
10 changes: 6 additions & 4 deletions fairseq/criterions/legacy_masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ class LegacyMaskedLmLoss(FairseqCriterion):
an argument.
"""

def __init__(self, args, task):
super().__init__(args, task)
def __init__(self, task, masked_lm_only, nsp_loss_weight):
super().__init__(task)
self.masked_lm_only = masked_lm_only
self.nsp_loss_weight = nsp_loss_weight

@staticmethod
def add_args(parser):
Expand Down Expand Up @@ -85,7 +87,7 @@ def forward(self, model, sample, reduce=True):

# Compute sentence loss if masked_lm_only is False
sentence_loss = None
if not self.args.masked_lm_only:
if not self.masked_lm_only:
sentence_logits = output_metadata['sentence_logits']
sentence_targets = sample['sentence_target'].view(-1)
# This needs to be recomputed due to some differences between
Expand All @@ -102,7 +104,7 @@ def forward(self, model, sample, reduce=True):
sentence_loss = compute_cross_entropy_loss(
sentence_logits, sentence_targets)

loss += self.args.nsp_loss_weight * (sentence_loss / nsentences)
loss += self.nsp_loss_weight * (sentence_loss / nsentences)

# NOTE: as we are summing up per token mlm loss and per sentence nsp loss
# we don't need to use sample_size as denominator for the gradient
Expand Down
10 changes: 7 additions & 3 deletions fairseq/criterions/nat_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,21 @@

@register_criterion("nat_loss")
class LabelSmoothedDualImitationCriterion(FairseqCriterion):

def __init__(self, task, label_smoothing):
super().__init__(task)
self.label_smoothing = label_smoothing

@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument(
'--label-smoothing',
default=0.,
type=float,
metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
# fmt: on
help='epsilon for label smoothing, 0 means no label smoothing',
)

def _compute_loss(
self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
Expand Down
4 changes: 4 additions & 0 deletions fairseq/criterions/sentence_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
@register_criterion('sentence_prediction')
class SentencePredictionCriterion(FairseqCriterion):

def __init__(self, task, classification_head_name):
super().__init__(task)
self.classification_head_name = classification_head_name

@staticmethod
def add_args(parser):
# fmt: off
Expand Down
16 changes: 9 additions & 7 deletions fairseq/criterions/sentence_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
@register_criterion('sentence_ranking')
class SentenceRankingCriterion(FairseqCriterion):

def __init__(self, args, task):
super().__init__(args, task)
if self.args.save_predictions is not None:
self.prediction_h = open(self.args.save_predictions, 'w')
def __init__(self, task, ranking_head_name, save_predictions, num_classes):
super().__init__(task)
self.ranking_head_name = ranking_head_name
if save_predictions is not None:
self.prediction_h = open(save_predictions, 'w')
else:
self.prediction_h = None
self.num_classes = num_classes

def __del__(self):
if self.prediction_h is not None:
Expand All @@ -46,14 +48,14 @@ def forward(self, model, sample, reduce=True):
"""
assert (
hasattr(model, 'classification_heads')
and self.args.ranking_head_name in model.classification_heads
and self.ranking_head_name in model.classification_heads
), 'model must provide sentence ranking head for --criterion=sentence_ranking'

scores = []
for idx in range(self.args.num_classes):
for idx in range(self.num_classes):
score, _ = model(
**sample['net_input{idx}'.format(idx=idx+1)],
classification_head_name=self.args.ranking_head_name,
classification_head_name=self.ranking_head_name,
)
scores.append(score)

Expand Down
2 changes: 1 addition & 1 deletion tests/speech_recognition/asr_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def setUpArgs(self):
def setUp(self):
args = self.setUpArgs()
self.model = DummyEncoderModel(encoder=DummyEncoder())
self.criterion = self.criterion_cls(args=args, task=DummyTask(args))
self.criterion = self.criterion_cls.build_criterion(args=args, task=DummyTask(args))

def get_src_tokens(self, correct_prediction, aggregate):
"""
Expand Down
Loading

0 comments on commit 46b773a

Please sign in to comment.