Merge internal changes (#654)
Summary:
- Add --add-bos-token option to LM task
- Cleanup utils.py and options.py
Pull Request resolved: facebookresearch/fairseq#654

Differential Revision: D15041794

Pulled By: myleott

fbshipit-source-id: 3ad00007769d5f48308052cfd40de39c5ffa1a6e
myleott authored and yzpang committed Feb 19, 2021
1 parent 534b343 commit 2ac15cc
Showing 12 changed files with 57 additions and 48 deletions.
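
Most of the call-site changes below follow the same migration: model loading moves from utils.load_ensemble_for_inference to checkpoint_utils.load_model_ensemble. A minimal sketch of the new call shape, assuming args comes from one of the standard option parsers and the task has already been set up (not a drop-in script, just the pattern):

from fairseq import checkpoint_utils, options, tasks

parser = options.get_generation_parser()
args = options.parse_args_and_arch(parser)

task = tasks.setup_task(args)
models, _model_args = checkpoint_utils.load_model_ensemble(
    args.path.split(':'),                      # --path, colon-separated checkpoint paths
    arg_overrides=eval(args.model_overrides),  # optional dict-literal string of overrides
    task=task,
)
model = models[0]
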
6 changes: 6 additions & 0 deletions docs/optim.rst
@@ -15,9 +15,15 @@ Optimizers update the Model parameters based on the gradients.
:members:
:undoc-members:

.. autoclass:: fairseq.optim.adadelta.Adadelta
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adagrad.Adagrad
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adam.FairseqAdam
:members:
:undoc-members:
3 changes: 2 additions & 1 deletion docs/overview.rst
@@ -28,11 +28,12 @@ fairseq implements the following high-level training flow::
lr_scheduler.step_update(num_updates)
lr_scheduler.step(epoch)

where the default implementation for ``train.train_step`` is roughly::
where the default implementation for ``task.train_step`` is roughly::

def train_step(self, batch, model, criterion, optimizer):
loss = criterion(model, batch)
optimizer.backward(loss)
return loss

**Registering new plug-ins**

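A rough sketch of how a registered task plugs into the ``task.train_step`` hook shown above, mirroring the simplified signature from the doc (the task name and class are illustrative and not part of this diff):

from fairseq.tasks import FairseqTask, register_task

@register_task('illustrative_task')
class IllustrativeTask(FairseqTask):

    def train_step(self, batch, model, criterion, optimizer):
        # Same contract as the default shown above: compute the loss,
        # backpropagate through the wrapped optimizer, return the loss.
        loss = criterion(model, batch)
        optimizer.backward(loss)
        return loss
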
4 changes: 2 additions & 2 deletions docs/tutorial_classifying_names.rst
@@ -354,7 +354,7 @@ The model files should appear in the :file:`checkpoints/` directory.
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::

from fairseq import data, options, tasks, utils
from fairseq import checkpoint_utils, data, options, tasks

# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
@@ -365,7 +365,7 @@ a new file named :file:`eval_classifier.py` with the following contents::

# Load model
print('| loading model from {}'.format(args.path))
models, _model_args = utils.load_ensemble_for_inference([args.path], task)
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]

while True:
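checkpoint_utils.load_model_ensemble returns a list of models (one per checkpoint path) together with the args stored in the checkpoint, so the rest of the tutorial continues unchanged. A compressed sketch of the load step, assuming args and task are set up as in the surrounding script:

models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]
model.eval()  # plain nn.Module call; disables dropout for evaluation
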
21 changes: 15 additions & 6 deletions eval_lm.py
@@ -13,11 +13,10 @@
import numpy as np
import torch

from fairseq import options, progress_bar, tasks, utils
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module


class WordStat(object):
@@ -49,7 +48,7 @@ def __str__(self):
def main(parsed_args):
assert parsed_args.path is not None, '--path required for evaluation!'

import_user_module(parsed_args)
utils.import_user_module(parsed_args)

print(parsed_args)

@@ -59,12 +58,17 @@ def main(parsed_args):

# Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(
parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
models, args = checkpoint_utils.load_model_ensemble(
parsed_args.path.split(':'),
arg_overrides=eval(parsed_args.model_overrides),
task=task,
)

for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
if arg not in {
'self_target', 'future_target', 'past_target', 'tokens_per_sample',
'output_size_dictionary', 'add_bos_token',
}:
setattr(args, arg, getattr(parsed_args, arg))

# reduce tokens per sample by the required context window size
@@ -151,6 +155,11 @@ def main(parsed_args):
tgt_len = tokens.numel()
pos_scores = hypo['positional_scores'].float()

if args.add_bos_token:
assert hypo['tokens'][0].item() == task.target_dictionary.bos()
tokens = tokens[1:]
pos_scores = pos_scores[1:]

skipped_toks = 0
if bpe_toks is not None:
for i in range(tgt_len - 1):
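The --add-bos-token handling above checks that the hypothesis starts with the dictionary's BOS symbol and then drops that leading position from both the token sequence and the positional scores, so it is excluded from the word-level statistics. A toy illustration of the bookkeeping with made-up numbers (not fairseq code):

import torch

# hypothetical per-position log-probabilities for [<bos>, w1, w2, w3]
pos_scores = torch.tensor([0.0, -1.2, -0.7, -2.1])
tokens = torch.tensor([0, 11, 12, 13])            # 0 standing in for the BOS index

tokens, pos_scores = tokens[1:], pos_scores[1:]   # mirror the slicing in eval_lm.py
avg_nll = -pos_scores.mean()                      # aggregate over the remaining positions only
print(float(avg_nll))                             # ~1.33 with these made-up scores
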
2 changes: 1 addition & 1 deletion examples/language_model/README.md
@@ -39,7 +39,7 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
# Evaluate:
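For context on the --update-freq change above: update-freq accumulates gradients over that many batches before each optimizer step, so the effective number of tokens per update scales with it. A rough back-of-the-envelope, with the GPU count assumed rather than taken from the diff:

max_tokens, update_freq, num_gpus = 3072, 3, 1    # values from the command above; single GPU assumed
tokens_per_update = max_tokens * update_freq * num_gpus
print(tokens_per_update)                          # 9216
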
2 changes: 1 addition & 1 deletion fairseq/models/fconv.py
@@ -158,7 +158,7 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_lm_architecture(args)

if hasattr(args, 'max_target_positions'):
if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
args.tokens_per_sample = args.max_target_positions

decoder = FConvDecoder(
8 changes: 8 additions & 0 deletions fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
@@ -67,6 +67,14 @@ def add_args(parser):
help='initial learning rate during warmup phase; default is args.lr')
# fmt: on

@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
# fmt: on

def state_dict(self):
"""Return the LR scheduler state dict."""
return {
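The --lr-shrink argument added above follows the rule in its help string, lr_new = lr * lr_shrink, applied when the monitored validation loss stops improving. A minimal illustration of that annealing rule (illustrative only, not the scheduler's actual code):

lr, lr_shrink, patience = 0.25, 0.1, 1
best, bad_epochs = float('inf'), 0
for val_loss in [1.9, 1.5, 1.5, 1.5]:        # made-up validation losses
    if val_loss < best:
        best, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs > patience:
            lr *= lr_shrink                  # lr_new = lr * lr_shrink
            bad_epochs = 0
print(lr)                                    # ~0.025 after one shrink
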
18 changes: 0 additions & 18 deletions fairseq/tasks/translation.py
@@ -198,24 +198,6 @@ def add_args(parser):
help='print sample generations during validation')
# fmt: on

@staticmethod
def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
model = utils.load_checkpoint_to_cpu(path)
args = model['args']
state_dict = model['model']
args = utils.override_model_args(args, arg_overrides)
src_dict = Dictionary.load(src_dict_path)
tgt_dict = Dictionary.load(tgt_dict_path)
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()

task = TranslationTask(args, src_dict, tgt_dict)
model = task.build_model(args)
model.upgrade_state_dict(state_dict)
model.load_state_dict(state_dict, strict=True)
return model

def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
self.src_dict = src_dict
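The deleted TranslationTask.load_pretrained_model loaded dictionaries from explicit paths and rebuilt the model itself; with the generic loading path used elsewhere in this commit, the task supplies the dictionaries instead. A rough equivalent sketch, with an illustrative checkpoint path and an `args` namespace assumed to come from one of the option parsers (as in generate.py below):

from fairseq import checkpoint_utils, tasks

# assumption: `args` points at a data directory whose source/target
# dictionaries match the checkpoint being loaded
task = tasks.setup_task(args)
models, _model_args = checkpoint_utils.load_model_ensemble(
    ['checkpoints/checkpoint_best.pt'],   # illustrative checkpoint path
    task=task,
)
model = models[0]
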
12 changes: 6 additions & 6 deletions generate.py
@@ -11,9 +11,8 @@

import torch

from fairseq import bleu, options, progress_bar, tasks, utils
from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module


def main(args):
@@ -23,7 +22,7 @@ def main(args):
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'

import_user_module(args)
utils.import_user_module(args)

if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000
@@ -34,7 +33,6 @@ def main(args):
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

# Set dictionaries
try:
@@ -45,8 +43,10 @@

# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(':'),
arg_overrides=eval(args.model_overrides),
task=task,
)

# Optimize ensemble for generation
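One detail shared by these call sites: --model-overrides is passed through eval(), so it is a Python dict literal whose entries are applied on top of the args saved inside the checkpoint (the old path did the same via utils.override_model_args). A hypothetical example of the shape, with an illustrative key:

model_overrides = "{'some_stored_model_arg': 123}"   # what would arrive via --model-overrides
arg_overrides = eval(model_overrides)                # -> {'some_stored_model_arg': 123}
# checkpoint_utils.load_model_ensemble(..., arg_overrides=arg_overrides, ...) then layers
# these values over the args stored in the checkpoint.
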
11 changes: 6 additions & 5 deletions interactive.py
@@ -15,9 +15,8 @@

import torch

from fairseq import options, tasks, utils
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.sequence_generator import SequenceGenerator
from fairseq.utils import import_user_module

Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
@@ -56,7 +55,7 @@ def make_batches(lines, args, task, max_positions):


def main(args):
import_user_module(args)
utils.import_user_module(args)

if args.buffer_size < 1:
args.buffer_size = 1
@@ -77,8 +76,10 @@

# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(':'),
arg_overrides=eval(args.model_overrides),
task=task,
)

# Set dictionaries
5 changes: 2 additions & 3 deletions preprocess.py
@@ -12,18 +12,17 @@
from collections import Counter
from itertools import zip_longest

from fairseq import options, tasks
from fairseq import options, tasks, utils
from fairseq.data import indexed_dataset
from fairseq.binarizer import Binarizer
from fairseq.utils import import_user_module
from multiprocessing import Pool

import os
import shutil


def main(args):
import_user_module(args)
utils.import_user_module(args)

print(args)

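utils.import_user_module (previously imported directly from fairseq.utils) is what makes a --user-dir of custom tasks and models visible before the arguments are acted on. A minimal sketch of the idea, not fairseq's implementation, assuming args may or may not carry a user_dir:

import importlib
import os
import sys

def import_user_module_sketch(args):
    # Rough idea only: put the directory's parent on sys.path and import the
    # package by name so its @register_task / @register_model decorators run.
    user_dir = getattr(args, 'user_dir', None)
    if user_dir is None:
        return
    user_dir = os.path.abspath(user_dir)
    parent, module_name = os.path.split(user_dir)
    if module_name not in sys.modules:
        sys.path.insert(0, parent)
        importlib.import_module(module_name)
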
13 changes: 8 additions & 5 deletions train.py
@@ -17,15 +17,14 @@

import torch

from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import import_user_module


def main(args, init_distributed=False):
import_user_module(args)
utils.import_user_module(args)

if args.max_tokens is None:
args.max_tokens = 6000
@@ -326,14 +325,18 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):

if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
checkpoints = checkpoint_utils.checkpoint_paths(
args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_interval_updates:]:
if os.path.lexists(old_chk):
os.remove(old_chk)

if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint(\d+)\.pt')
checkpoints = checkpoint_utils.checkpoint_paths(
args.save_dir, pattern=r'checkpoint(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
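checkpoint_utils.checkpoint_paths (moved out of utils by this commit) lists the checkpoints in a directory that match the given pattern, newest first, so the slices above can delete everything beyond --keep-interval-updates / --keep-last-epochs. A simplified sketch of that behaviour, with the caveat that the real helper may differ in detail:

import os
import re

def checkpoint_paths_sketch(save_dir, pattern=r'checkpoint(\d+)\.pt'):
    # Collect files matching `pattern` and sort them by the captured group,
    # descending, so the most recent checkpoints come first.
    prog = re.compile(pattern)
    matched = []
    for fname in os.listdir(save_dir):
        m = prog.fullmatch(fname)
        if m is not None:
            matched.append((int(m.group(1)), fname))
    return [os.path.join(save_dir, f) for _, f in sorted(matched, reverse=True)]
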
