Merge internal changes #654

Closed
wants to merge 1 commit
6 changes: 6 additions & 0 deletions docs/optim.rst
@@ -15,9 +15,15 @@ Optimizers update the Model parameters based on the gradients.
:members:
:undoc-members:

.. autoclass:: fairseq.optim.adadelta.Adadelta
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adagrad.Adagrad
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adam.FairseqAdam
:members:
:undoc-members:
3 changes: 2 additions & 1 deletion docs/overview.rst
@@ -28,11 +28,12 @@ fairseq implements the following high-level training flow::
lr_scheduler.step_update(num_updates)
lr_scheduler.step(epoch)

where the default implementation for ``train.train_step`` is roughly::
where the default implementation for ``task.train_step`` is roughly::

def train_step(self, batch, model, criterion, optimizer):
loss = criterion(model, batch)
optimizer.backward(loss)
return loss

**Registering new plug-ins**

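
A rough sketch of the registration pattern this section describes, using a task as the example plug-in. The ``register_task`` decorator and ``FairseqTask`` base class are fairseq's real hooks; the task class itself is only a placeholder::

    from fairseq.tasks import FairseqTask, register_task

    @register_task('my_dummy_task')
    class MyDummyTask(FairseqTask):
        """Placeholder task, shown only to illustrate plug-in registration."""

        @classmethod
        def setup_task(cls, args, **kwargs):
            # A real task would load dictionaries and validate args here.
            return cls(args)

        def load_dataset(self, split, **kwargs):
            # A real task would populate self.datasets[split] here.
            raise NotImplementedError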
4 changes: 2 additions & 2 deletions docs/tutorial_classifying_names.rst
@@ -354,7 +354,7 @@ The model files should appear in the :file:`checkpoints/` directory.
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::

from fairseq import data, options, tasks, utils
from fairseq import checkpoint_utils, data, options, tasks

# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
@@ -365,7 +365,7 @@ a new file named :file:`eval_classifier.py` with the following contents::

# Load model
print('| loading model from {}'.format(args.path))
models, _model_args = utils.load_ensemble_for_inference([args.path], task)
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]

while True:
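
For readers updating older evaluation scripts, a minimal sketch of the new loading call in isolation (the checkpoint path is hypothetical; the signature matches ``load_model_ensemble`` as added in ``fairseq/checkpoint_utils.py`` below)::

    from fairseq import checkpoint_utils

    # Hypothetical checkpoint path, for illustration only. When no task is
    # passed, load_model_ensemble sets one up from the args stored in the
    # checkpoint itself.
    models, model_args = checkpoint_utils.load_model_ensemble(
        ['checkpoints/checkpoint_best.pt'],
    )
    model = models[0]
    model.eval()  # disable dropout for evaluation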
21 changes: 15 additions & 6 deletions eval_lm.py
@@ -13,11 +13,10 @@
import numpy as np
import torch

from fairseq import options, progress_bar, tasks, utils
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module


class WordStat(object):
@@ -49,7 +48,7 @@ def __str__(self):
def main(parsed_args):
assert parsed_args.path is not None, '--path required for evaluation!'

import_user_module(parsed_args)
utils.import_user_module(parsed_args)

print(parsed_args)

@@ -59,12 +58,17 @@

# Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(
parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
models, args = checkpoint_utils.load_model_ensemble(
parsed_args.path.split(':'),
arg_overrides=eval(parsed_args.model_overrides),
task=task,
)

for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
if arg not in {
'self_target', 'future_target', 'past_target', 'tokens_per_sample',
'output_size_dictionary', 'add_bos_token',
}:
setattr(args, arg, getattr(parsed_args, arg))

# reduce tokens per sample by the required context window size
@@ -151,6 +155,11 @@ def main(parsed_args):
tgt_len = tokens.numel()
pos_scores = hypo['positional_scores'].float()

if args.add_bos_token:
assert hypo['tokens'][0].item() == task.target_dictionary.bos()
tokens = tokens[1:]
pos_scores = pos_scores[1:]

skipped_toks = 0
if bpe_toks is not None:
for i in range(tgt_len - 1):
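
The new ``add_bos_token`` branch above drops the score attached to the ``<s>`` position before statistics are accumulated. A toy, self-contained illustration of that effect on an average negative log-likelihood (the numbers are invented and this is not eval_lm's exact perplexity computation)::

    import math

    import torch

    # Hypothetical positional log-probabilities for '<s> the cat sat'.
    pos_scores = torch.tensor([-0.05, -2.30, -1.60, -0.90])

    # Mirror the change above: skip the <s> position so that only real
    # target tokens contribute to the average.
    pos_scores = pos_scores[1:]
    avg_nll = -pos_scores.mean().item()
    print('avg nll: {:.3f} | ppl: {:.2f}'.format(avg_nll, math.exp(avg_nll)))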
2 changes: 1 addition & 1 deletion examples/language_model/README.md
@@ -39,7 +39,7 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d

# Evaluate:
176 changes: 176 additions & 0 deletions fairseq/checkpoint_utils.py
@@ -0,0 +1,176 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import OrderedDict
import logging
import os
import re
import traceback

import torch
from torch.serialization import default_restore_location

from fairseq import tasks


def load_checkpoint_to_cpu(path):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    state = torch.load(
        path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
    )
    state = _upgrade_state_dict(state)
    return state


def load_model_ensemble(filenames, arg_overrides=None, task=None):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    ensemble = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = load_checkpoint_to_cpu(filename)

        args = state['args']
        if arg_overrides is not None:
            for arg_name, arg_val in arg_overrides.items():
                setattr(args, arg_name, arg_val)

        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    return ensemble, args


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]


def torch_persistent_save(*args, **kwargs):
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                logging.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict


def save_state(
    filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
    num_updates, optim_history=None, extra_state=None,
):
    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        'args': args,
        'model': model_state_dict if model_state_dict else {},
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
                'optimizer_name': optimizer.__class__.__name__,
                'lr_scheduler_state': lr_scheduler.state_dict(),
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
        }
    return state
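
A brief usage sketch of the helpers above (the save directory is hypothetical; ``checkpoint_paths`` sorts matches by their numeric group in descending order, so the first entry is the newest numbered checkpoint)::

    from fairseq import checkpoint_utils

    # Hypothetical save directory, for illustration only.
    save_dir = 'checkpoints/transformer_wikitext-103'

    # Pattern defaults to r'checkpoint(\d+)\.pt'; newest checkpoint comes first.
    paths = checkpoint_utils.checkpoint_paths(save_dir)
    if paths:
        state = checkpoint_utils.load_checkpoint_to_cpu(paths[0])
        print('model args:', state['args'])
        print('updates so far:', state['optimizer_history'][-1]['num_updates'])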
9 changes: 6 additions & 3 deletions fairseq/data/dictionary.py
@@ -18,13 +18,12 @@

class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
# dictionary indexing starts at 1 for consistency with Lua
self.add_symbol('<Lua heritage>')
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
@@ -143,6 +142,10 @@ def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
self.symbols = list(new_symbols)
self.indices = new_indices

def bos(self):
"""Helper to get index of beginning-of-sentence symbol"""
return self.bos_index

def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
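
The removal of the Lua-heritage placeholder and the new leading ``<s>`` symbol shift the special-symbol indices; a small sketch of the resulting layout, which follows directly from the constructor above::

    from fairseq.data.dictionary import Dictionary

    d = Dictionary()
    assert d.bos() == 0      # '<s>', added first by the new constructor
    assert d.pad_index == 1  # '<pad>'
    assert d.eos_index == 2  # '</s>'
    assert d.unk_index == 3  # '<unk>'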
12 changes: 11 additions & 1 deletion fairseq/data/monolingual_dataset.py
@@ -62,13 +62,14 @@ class MonolingualDataset(FairseqDataset):
"""

def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
targets=None):
targets=None, add_bos_token=False):
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = src_vocab
self.tgt_vocab = tgt_vocab
self.add_eos_for_other_targets = add_eos_for_other_targets
self.shuffle = shuffle
self.add_bos_token = add_bos_token

assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
"targets must be none or one of 'self', 'future', 'past'"
@@ -83,6 +84,7 @@ def __getitem__(self, index):
else:
source = self.dataset[index]
target = None
source, target = self._maybe_add_bos(source, target)
return {'id': index, 'source': source, 'target': target}

def __len__(self):
@@ -121,6 +123,13 @@ def _make_source_target(self, source, future_target, past_target):

return source, self._filter_vocab(target)

def _maybe_add_bos(self, source, target):
if self.add_bos_token:
source = torch.cat([source.new([self.vocab.bos()]), source])
if target is not None:
target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
return source, target

def _filter_vocab(self, target):
if len(self.tgt_vocab) != len(self.vocab):
def _filter(target):
@@ -165,6 +174,7 @@ def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
target = self.vocab.dummy_sentence(tgt_len + 2)
source, past_target, future_target = target[1:-1], target[2:], target[:-2]
source, target = self._make_source_target(source, past_target, future_target)
source, target = self._maybe_add_bos(source, target)

return self.collater([
{'id': i, 'source': source, 'target': target}
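
A tiny sketch of what ``_maybe_add_bos`` does to a source tensor (the token ids are invented; the prepend pattern via ``tensor.new`` and ``torch.cat`` is the one used above)::

    import torch

    bos_index = 0  # '<s>' is index 0 under the updated Dictionary
    source = torch.tensor([11, 12, 13])

    # Same prepend pattern as MonolingualDataset._maybe_add_bos:
    source = torch.cat([source.new([bos_index]), source])
    print(source)  # tensor([ 0, 11, 12, 13])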
2 changes: 1 addition & 1 deletion fairseq/models/fconv.py
@@ -141,7 +141,7 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_lm_architecture(args)

if hasattr(args, 'max_target_positions'):
if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
args.tokens_per_sample = args.max_target_positions

decoder = FConvDecoder(