Commit d45db80

Summary:
- Add --add-bos-token option to LM task
- Cleanup utils.py and options.py

Pull Request resolved: #654
Differential Revision: D15041794
Pulled By: myleott
fbshipit-source-id: 3ad00007769d5f48308052cfd40de39c5ffa1a6e

1 parent: 89a6961

Showing 34 changed files with 368 additions and 278 deletions.
@@ -0,0 +1,177 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import OrderedDict
import logging
import os
import re
import traceback

import torch
from torch.serialization import default_restore_location

from fairseq import tasks


def load_checkpoint_to_cpu(path):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    state = torch.load(
        path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
    )
    state = _upgrade_state_dict(state)
    return state
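
A minimal usage sketch (illustrative, not part of this commit), assuming the functions above are in scope and that a checkpoint exists at the placeholder path:

# Hypothetical path; any checkpoint written by save_state() below would do.
state = load_checkpoint_to_cpu('checkpoints/checkpoint_best.pt')

# After _upgrade_state_dict, old and new checkpoints share the same layout.
print(sorted(state.keys()))
# e.g. ['args', 'extra_state', 'last_optimizer_state', 'model', 'optimizer_history']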


def load_model_ensemble(filenames, arg_overrides=None, task=None):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    ensemble = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = load_checkpoint_to_cpu(filename)

        args = state['args']
        if arg_overrides is not None:
            for arg_name, arg_val in arg_overrides.items():
                setattr(args, arg_name, arg_val)

        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    return ensemble, args
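
For evaluation, an ensemble can be loaded and switched to inference mode along these lines (a sketch only; the paths and the arg_overrides key are placeholders, not options introduced by this commit):

models, args = load_model_ensemble(
    ['checkpoints/checkpoint1.pt', 'checkpoints/checkpoint2.pt'],
    arg_overrides={'data': '/path/to/data-bin'},   # hypothetical override of a stored training arg
)
for model in models:
    model.eval()   # disable dropout for inference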


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
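
Because the returned list is sorted newest-first, a keep-only-the-last-N cleanup can be sketched on top of it (keep_last and the directory are assumptions for illustration, and os is the module imported above):

keep_last = 2   # hypothetical retention policy
for old_checkpoint in checkpoint_paths('checkpoints', pattern=r'checkpoint(\d+)\.pt')[keep_last:]:
    os.remove(old_checkpoint)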


def torch_persistent_save(*args, **kwargs):
    # Retry torch.save up to three times; log the traceback only if the final
    # attempt also fails.
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                logging.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    # Recursively cast every tensor in a (possibly nested) state dict to `ttype`;
    # non-tensor values are returned unchanged. Note that `ttype` is passed down
    # the recursion so nested containers are cast consistently.
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v, ttype)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v, ttype) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict
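
A small, self-contained sketch of the cast helper, here converting a made-up half-precision snapshot back to the default FloatTensor type:

snapshot = {'step': 10, 'weights': [torch.randn(3).half(), torch.randn(2).half()]}
restored = convert_state_dict_type(snapshot)
print(restored['weights'][0].dtype)   # torch.float32
print(restored['step'])               # non-tensor values pass through unchanged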


def save_state(
    filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
    num_updates, optim_history=None, extra_state=None,
):
    """Saves a checkpoint containing the model weights, the latest optimizer
    state, and a compact history of past criterion/optimizer configurations."""
    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        'args': args,
        'model': model_state_dict if model_state_dict else {},
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
                'optimizer_name': optimizer.__class__.__name__,
                'lr_scheduler_state': lr_scheduler.state_dict(),
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)
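
A sketch of the expected call shape using plain PyTorch objects as stand-ins (in fairseq the trainer passes its own criterion, optimizer, and lr_scheduler wrappers; every value below is fabricated for illustration):

import argparse
import torch.nn as nn

model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

save_state(
    'checkpoint_demo.pt',                    # hypothetical output path
    argparse.Namespace(arch='demo'),         # stand-in for the stored training args
    model.state_dict(),
    criterion, optimizer, lr_scheduler,
    num_updates=100,
    extra_state={'epoch': 1, 'val_loss': 2.5},
)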


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
        }
    return state
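
Finally, a runnable sketch of the upgrade path on a fabricated old-style checkpoint dict (only the key layout matters; the values are made up):

import argparse

old_state = {
    'args': argparse.Namespace(max_positions=1024),
    'model': {},
    'optimizer': {'state': {}, 'param_groups': []},
    'best_loss': 3.2,
    'epoch': 5,
    'batch_offset': 0,
    'val_loss': 3.4,
}
new_state = _upgrade_state_dict(old_state)

print(new_state['optimizer_history'][-1]['optimizer_name'])    # 'FairseqNAG'
print(new_state['args'].max_source_positions)                  # 1024, copied from max_positions
print(new_state['extra_state']['train_iterator'])              # {'epoch': 5, 'iterations_in_epoch': 0}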