Parameterized criterions (#808)
Summary:
Support criterions with parameters, such as the AutoSegmentationCriterion (ASG) used in wav2letter, which has a transition matrix parameter. This is needed to integrate wav2letter's ASG into PySpeech.

With this diff, parameters in criterions will be:
(1) updated by optimizers, with a configurable learning rate;
(2) saved to and loaded from checkpoints, preserving backward compatibility for criterions without parameters;
(3) synchronized across nodes in distributed training.
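
For illustration, a minimal, hypothetical sketch of a parameterized criterion (not part of this diff; the class name, the toy scoring, and the arguments are invented here, and real fairseq criterions extend FairseqCriterion rather than plain nn.Module). The point is only that the criterion owns an nn.Parameter, so after this change it flows through the optimizer, checkpointing, and distributed paths touched below:

import torch
import torch.nn as nn


class ToyTransitionCriterion(nn.Module):
    """Hypothetical ASG-style criterion with a learnable transition matrix."""

    def __init__(self, num_labels):
        super().__init__()
        # Learnable parameter: exposed via criterion.parameters(), so with this
        # diff it is updated by the optimizer, stored in checkpoints, and
        # synchronized across workers in distributed training.
        self.transitions = nn.Parameter(torch.zeros(num_labels, num_labels))

    def forward(self, log_probs, targets):
        # log_probs: (batch, time, num_labels); targets: (batch, time) label ids.
        # Stand-in scoring only, not the real ASG recursion.
        emission_score = log_probs.gather(-1, targets.unsqueeze(-1)).sum()
        transition_score = self.transitions[targets[:, :-1], targets[:, 1:]].sum()
        return -(emission_score + transition_score)

With a criterion like this, utils.has_parameters(criterion) returns True, so the trainer chains criterion.parameters() into the optimizer, save_state() adds a 'criterion' entry to the checkpoint, and the criterion is wrapped in DistributedFairseqModel when distributed_world_size > 1.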
Pull Request resolved: fairinternal/fairseq-py#808

Reviewed By: jcai1

Differential Revision: D16934097

Pulled By: okhonko

fbshipit-source-id: 121ec9382459385c6f9cbef3a8274bec1a434038
0xjc authored and facebook-github-bot committed Aug 21, 2019
1 parent a2f5361 commit ba5f829
Showing 15 changed files with 79 additions and 39 deletions.
3 changes: 3 additions & 0 deletions fairseq/checkpoint_utils.py
@@ -222,6 +222,7 @@ def save_state(
filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None,
):
from fairseq import utils
if optim_history is None:
optim_history = []
if extra_state is None:
@@ -239,6 +240,8 @@
],
'extra_state': extra_state,
}
if utils.has_parameters(criterion):
state_dict['criterion'] = criterion.state_dict()
if not args.no_save_optimizer_state:
state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
torch_persistent_save(state_dict, filename)
6 changes: 3 additions & 3 deletions fairseq/models/distributed_fairseq_model.py
@@ -5,7 +5,7 @@

import inspect

from torch.nn import parallel
import torch.nn as nn

from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
from fairseq.models import BaseFairseqModel
@@ -25,9 +25,9 @@ def DistributedFairseqModel(args, model):
model (BaseFairseqModel): model to wrap
"""
# determine which DDP class to extend
assert isinstance(model, BaseFairseqModel)
assert isinstance(model, nn.Module)
if args.ddp_backend == 'c10d':
ddp_class = parallel.DistributedDataParallel
ddp_class = nn.parallel.DistributedDataParallel
init_kwargs = dict(
module=model,
device_ids=[args.device_id],
7 changes: 1 addition & 6 deletions fairseq/optim/__init__.py
@@ -19,18 +19,13 @@
]


_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
'--optimizer',
base_class=FairseqOptimizer,
default='nag',
)


def build_optimizer(args, params, *extra_args, **extra_kwargs):
params = list(filter(lambda p: p.requires_grad, params))
return _build_optimizer(args, params, *extra_args, **extra_kwargs)


# automatically import any Python files in the optim/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
2 changes: 1 addition & 1 deletion fairseq/optim/adadelta.py
@@ -11,7 +11,7 @@
@register_optimizer('adadelta')
class Adadelta(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
super().__init__(args)
self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)

@staticmethod
2 changes: 1 addition & 1 deletion fairseq/optim/adafactor.py
@@ -13,7 +13,7 @@
@register_optimizer('adafactor')
class FairseqAdafactor(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
super().__init__(args)
self._optimizer = Adafactor(params, **self.optimizer_config)

@staticmethod
2 changes: 1 addition & 1 deletion fairseq/optim/adagrad.py
@@ -11,7 +11,7 @@
@register_optimizer('adagrad')
class Adagrad(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
super().__init__(args)
self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)

@staticmethod
2 changes: 1 addition & 1 deletion fairseq/optim/adam.py
@@ -16,7 +16,7 @@
class FairseqAdam(FairseqOptimizer):

def __init__(self, args, params):
super().__init__(args, params)
super().__init__(args)
if torch.cuda.is_available():
try:
from apex.optimizers import FusedAdam as _FusedAdam # noqa
2 changes: 1 addition & 1 deletion fairseq/optim/adamax.py
@@ -12,7 +12,7 @@
@register_optimizer('adamax')
class FairseqAdamax(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
super().__init__(args)
self._optimizer = Adamax(params, **self.optimizer_config)

@staticmethod
5 changes: 2 additions & 3 deletions fairseq/optim/bmuf.py
@@ -19,11 +19,10 @@ class FairseqBMUF(FairseqOptimizer):
model-update filtering
"""

def __init__(self, args, params, optimizer):
def __init__(self, args, optimizer):

super().__init__(args, params)
super().__init__(args)
self._optimizer = optimizer
self.params = params
self._num_updates = 0
self.sync_iter = self.args.global_sync_iter
self.block_momentum = self.args.block_momentum
15 changes: 10 additions & 5 deletions fairseq/optim/fairseq_optimizer.py
@@ -10,10 +10,9 @@

class FairseqOptimizer(object):

def __init__(self, args, params):
def __init__(self, args):
super().__init__()
self.args = args
self.params = list(params)

@staticmethod
def add_args(parser):
@@ -39,6 +38,13 @@ def optimizer_config(self):
"""
raise NotImplementedError

@property
def params(self):
"""Return an iterable of the parameters held by the optimizer."""
for param_group in self.optimizer.param_groups:
for p in param_group['params']:
yield p

def __getstate__(self):
return self._optimizer.__getstate__()

@@ -93,9 +99,8 @@ def step(self, closure=None):

def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
for group in self.optimizer.param_groups:
for p in group['params']:
p.grad = None
for p in self.params:
p.grad = None
self.optimizer.zero_grad()

@property
11 changes: 6 additions & 5 deletions fairseq/optim/fp16_optimizer.py
@@ -60,7 +60,8 @@ class FP16Optimizer(optim.FairseqOptimizer):
"""

def __init__(self, args, params, fp32_optimizer, fp32_params):
super().__init__(args, params)
super().__init__(args)
self.fp16_params = params
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params

@@ -149,7 +150,7 @@ def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
if self._needs_sync:
# copy FP16 grads to FP32
offset = 0
for p in self.params:
for p in self.fp16_params:
if not p.requires_grad:
continue
grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
@@ -196,7 +197,7 @@ def step(self, closure=None):

# copy FP32 params back into FP16 model
offset = 0
for p in self.params:
for p in self.fp16_params:
if not p.requires_grad:
continue
numel = p.data.numel()
@@ -205,7 +206,7 @@

def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
for p in self.params:
for p in self.fp16_params:
p.grad = None
self._needs_sync = False

@@ -232,7 +233,7 @@ def __init__(self, args, params, optimizer):
'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
)

super().__init__(args, params)
super().__init__(args)
self.wrapped_optimizer = optimizer

if getattr(args, 'fp16_scale_window', None) is None:
2 changes: 1 addition & 1 deletion fairseq/optim/nag.py
@@ -12,7 +12,7 @@
@register_optimizer('nag')
class FairseqNAG(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
super().__init__(args)
self._optimizer = NAG(params, **self.optimizer_config)

@staticmethod
2 changes: 1 addition & 1 deletion fairseq/optim/sgd.py
@@ -11,7 +11,7 @@
@register_optimizer('sgd')
class SGD(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
super().__init__(args)
self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

@staticmethod
49 changes: 39 additions & 10 deletions fairseq/trainer.py
@@ -36,13 +36,14 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=Non
self.task = task

# copy model and criterion to current device
self.criterion = criterion
self._criterion = criterion
self._model = model
self.cuda = torch.cuda.is_available() and not args.cpu
if args.fp16:
self._criterion = self._criterion.half()
self._model = self._model.half()
if self.cuda:
self.criterion = self.criterion.cuda()
self._criterion = self._criterion.cuda()
self._model = self._model.cuda()

self._dummy_batch = dummy_batch
@@ -53,6 +54,7 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=Non
self._optim_history = None
self._optimizer = None
self._prev_grad_norm = None
self._wrapped_criterion = None
self._wrapped_model = None

self.init_meters(args)
@@ -75,6 +77,21 @@ def init_meters(self, args):
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds

@property
def criterion(self):
if self._wrapped_criterion is None:
if (
utils.has_parameters(self._criterion)
and self.args.distributed_world_size > 1
and not self.args.use_bmuf
):
self._wrapped_criterion = models.DistributedFairseqModel(
self.args, self._criterion
)
else:
self._wrapped_criterion = self._criterion
return self._wrapped_criterion

@property
def model(self):
if self._wrapped_model is None:
@@ -99,7 +116,13 @@ def lr_scheduler(self):
return self._lr_scheduler

def _build_optimizer(self):
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
params = list(
filter(
lambda p: p.requires_grad,
chain(self.model.parameters(), self.criterion.parameters()),
)
)

if self.args.fp16:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16, '
@@ -114,7 +137,7 @@ def _build_optimizer(self):
self._optimizer = optim.build_optimizer(self.args, params)

if self.args.use_bmuf:
self._optimizer = optim.FairseqBMUF(self.args, params, self._optimizer)
self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)

# We should initialize the learning rate scheduler immediately after
# building the optimizer, so that the initial learning rate is set.
@@ -126,7 +149,7 @@ def save_checkpoint(self, filename, extra_state):
if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters
checkpoint_utils.save_state(
filename, self.args, self.get_model().state_dict(), self.criterion,
filename, self.args, self.get_model().state_dict(), self.get_criterion(),
self.optimizer, self.lr_scheduler, self.get_num_updates(),
self._optim_history, extra_state,
)
@@ -148,6 +171,8 @@ def load_checkpoint(
# load model parameters
try:
self.get_model().load_state_dict(state['model'], strict=True)
if utils.has_parameters(self.get_criterion()):
self.get_criterion().load_state_dict(state['criterion'], strict=True)
except Exception:
raise Exception(
'Cannot load model parameters from checkpoint {}; '
@@ -164,7 +189,7 @@

# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
'Criterion does not match; please reset the optimizer (--reset-optimizer).'
assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
@@ -322,9 +347,9 @@ def maybe_no_sync():

# aggregate logging outputs and sample sizes
logging_output = self.task.aggregate_logging_outputs(
logging_outputs, self.criterion
logging_outputs, self.get_criterion()
)
sample_size = self.task.grad_denom(sample_sizes, self.criterion)
sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

if not all(k in logging_output for k in ['ntokens', 'nsentences']):
raise Exception((
@@ -424,10 +449,10 @@ def valid_step(self, sample, raise_oom=False):

# aggregate logging outputs and sample sizes
logging_output = self.task.aggregate_logging_outputs(
logging_output, self.criterion
logging_output, self.get_criterion()
)
sample_size = self.task.grad_denom(
sample_size, self.criterion
sample_size, self.get_criterion()
)

# update meters for validation
@@ -477,6 +502,10 @@ def get_model(self):
"""Get the (non-wrapped) model instance."""
return self._model

def get_criterion(self):
"""Get the (non-wrapped) criterion instance."""
return self._criterion

def get_meter(self, name):
"""Get a specific meter by name."""
if name not in self.meters:
8 changes: 8 additions & 0 deletions fairseq/utils.py
@@ -351,3 +351,11 @@ def eval(model)
model.eval()
yield
model.train(is_training)


def has_parameters(module):
try:
next(module.parameters())
return True
except StopIteration:
return False
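
As an illustration of the new helper (not part of the diff): any module that owns at least one parameter, such as a parameterized criterion, returns True, while a parameter-free module returns False.

import torch.nn as nn

from fairseq import utils

print(utils.has_parameters(nn.Linear(4, 4)))  # True: Linear owns weight and bias parameters
print(utils.has_parameters(nn.ReLU()))        # False: ReLU holds no parameters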
