Skip to content

Commit

Permalink
Parametrize fit hook test with different precision plugins (#8070)
Browse files — browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
carmocca and awaelchli committed Jul 5, 2021
1 parent 7b6d0a8 commit ea88105
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 24 deletions.
5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,14 @@ def backward(
) -> Tensor:
if is_overridden('backward', model):
warning_cache.warn(
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
"backward logic outside of the LightningModule"
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
" the backward logic internally."
)
# todo: hack around for deepspeed engine to call backward
deepspeed_engine = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)
# once backward has been applied, release graph
closure_loss = closure_loss.detach()

return closure_loss

def clip_gradients(
Expand Down
72 changes: 52 additions & 20 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ def call(hook, fn, *args, **kwargs):
d = {'name': hook}
if args:
d['args'] = args
elif hook == 'train':
# DeepSpeed calls `train(mode)` but we do not. Standardize
# https://github.com/microsoft/DeepSpeed/pull/571
d['args'] = (True, )
if kwargs:
d['kwargs'] = kwargs
called.append(d)
Expand All @@ -283,12 +287,13 @@ def test_epoch_end(self, *args, **kwargs):
pass

@staticmethod
def _train_batch(trainer, model, batches, current_epoch=0):
def _train_batch(trainer, model, batches, device=torch.device('cpu'), current_epoch=0, **kwargs):
using_native_amp = kwargs.get('amp_backend') == 'native'
out = []
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
# TODO: `on_batch_{start,end}`
dict(name='Callback.on_batch_start', args=(trainer, model)),
Expand All @@ -301,14 +306,15 @@ def _train_batch(trainer, model, batches, current_epoch=0):
dict(name='on_before_zero_grad', args=(ANY, )),
dict(name='optimizer_zero_grad', args=(current_epoch, i, ANY, 0)),
# TODO: `on_before_backward`
dict(name='backward', args=(ANY, ANY, 0)),
# DeepSpeed handles backward internally
*([dict(name='backward', args=(ANY, ANY, 0))] if kwargs.get('plugins') != 'deepspeed' else []),
dict(name='Callback.on_after_backward', args=(trainer, model)),
dict(name='on_after_backward'),
# TODO: `on_before_optimizer_step`
dict(
name='optimizer_step',
args=(current_epoch, i, ANY, 0, ANY),
kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False)
kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp)
),
dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)),
Expand All @@ -317,14 +323,14 @@ def _train_batch(trainer, model, batches, current_epoch=0):
return out

@staticmethod
def _eval_epoch(fn, trainer, model, batches, key):
def _eval_epoch(fn, trainer, model, batches, key, device=torch.device('cpu')):
outputs = {key: ANY}
return [
dict(name='Callback.on_epoch_start', args=(trainer, model)),
dict(name='on_epoch_start'),
dict(name=f'Callback.on_{fn}_epoch_start', args=(trainer, model)),
dict(name=f'on_{fn}_epoch_start'),
*HookedModel._eval_batch(fn, trainer, model, batches, key),
*HookedModel._eval_batch(fn, trainer, model, batches, key, device=device),
dict(name=f'{fn}_epoch_end', args=([outputs] * batches, )),
dict(name=f'Callback.on_{fn}_epoch_end', args=(trainer, model)),
dict(name=f'on_{fn}_epoch_end'),
Expand All @@ -333,13 +339,13 @@ def _eval_epoch(fn, trainer, model, batches, key):
]

@staticmethod
def _eval_batch(fn, trainer, model, batches, key):
def _eval_batch(fn, trainer, model, batches, key, device=torch.device('cpu')):
out = []
outputs = {key: ANY}
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
# TODO: `{,Callback}.on_batch_{start,end}`
dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)),
Expand All @@ -357,10 +363,10 @@ def _predict_batch(trainer, model, batches):
out = []
for i in range(batches):
out.extend([
# TODO: `{,Callback}.on_batch_{start,end}`
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
# TODO: `{,Callback}.on_batch_{start,end}`
dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_predict_batch_start', args=(ANY, i, 0)),
dict(name='forward', args=(ANY, )),
Expand All @@ -372,7 +378,17 @@ def _predict_batch(trainer, model, batches):
return out


def test_trainer_model_hook_system_fit(tmpdir):
@pytest.mark.parametrize(
'kwargs',
[
{},
# these precision plugins modify the optimization flow, so testing them explicitly
pytest.param(dict(gpus=1, precision=16, plugins='deepspeed'), marks=RunIf(deepspeed=True, min_gpus=1)),
pytest.param(dict(gpus=1, precision=16, amp_backend='native'), marks=RunIf(amp_native=True, min_gpus=1)),
pytest.param(dict(gpus=1, precision=16, amp_backend='apex'), marks=RunIf(amp_apex=True, min_gpus=1)),
]
)
def test_trainer_model_hook_system_fit(tmpdir, kwargs):
called = []
model = HookedModel(called)
callback = HookedCallback(called)
Expand All @@ -385,13 +401,17 @@ def test_trainer_model_hook_system_fit(tmpdir):
limit_val_batches=val_batches,
progress_bar_refresh_rate=0,
weights_summary=None,
callbacks=[callback]
callbacks=[callback],
**kwargs,
)

assert called == [
dict(name='Callback.on_init_start', args=(trainer, )),
dict(name='Callback.on_init_end', args=(trainer, )),
]

trainer.fit(model)

saved_ckpt = {
'callbacks': ANY,
'epoch': 1,
Expand All @@ -401,19 +421,31 @@ def test_trainer_model_hook_system_fit(tmpdir):
'pytorch-lightning_version': __version__,
'state_dict': ANY,
}
if kwargs.get('amp_backend') == 'native':
saved_ckpt['native_amp_scaling_state'] = ANY
elif kwargs.get('amp_backend') == 'apex':
saved_ckpt['amp_scaling_state'] = ANY
device = torch.device('cuda:0' if 'gpus' in kwargs else 'cpu')

expected = [
dict(name='Callback.on_init_start', args=(trainer, )),
dict(name='Callback.on_init_end', args=(trainer, )),
dict(name='prepare_data'),
dict(name='configure_callbacks'),
dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
# DeepSpeed needs the batch size to figure out throughput logging
*([dict(name='train_dataloader')] if kwargs.get('plugins') == 'deepspeed' else []),
dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
dict(name='setup', kwargs=dict(stage='fit')),
dict(name='configure_sharded_model'),
dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
dict(name='configure_optimizers'),
# DeepSpeed skips initializing optimizers here as they are handled via config
*([dict(name='configure_optimizers')] if kwargs.get('plugins') != 'deepspeed' else []),
dict(name='Callback.on_fit_start', args=(trainer, model)),
dict(name='on_fit_start'),
# TODO: explore whether DeepSpeed can have the same flow for optimizers
# DeepSpeed did not find any optimizer in the config so they are loaded here
*([dict(name='configure_optimizers')] if kwargs.get('plugins') == 'deepspeed' else []),
dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
dict(name='on_pretrain_routine_start'),
dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
Expand All @@ -426,14 +458,14 @@ def test_trainer_model_hook_system_fit(tmpdir):
dict(name='zero_grad'),
dict(name='Callback.on_validation_start', args=(trainer, model)),
dict(name='on_validation_start'),
*model._eval_epoch('validation', trainer, model, val_batches, 'x'),
*model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
dict(name='Callback.on_validation_end', args=(trainer, model)),
dict(name='on_validation_end'),
dict(name='train'),
dict(name='train', args=(True, )),
dict(name='on_validation_model_train'),
dict(name='Callback.on_sanity_check_end', args=(trainer, model)),
# duplicate `train` because `_run_train` calls it again in case validation wasn't run
dict(name='train'),
dict(name='train', args=(True, )),
dict(name='on_train_dataloader'),
dict(name='train_dataloader'),
dict(name='Callback.on_train_start', args=(trainer, model)),
Expand All @@ -442,19 +474,19 @@ def test_trainer_model_hook_system_fit(tmpdir):
dict(name='on_epoch_start'),
dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
dict(name='on_train_epoch_start'),
*model._train_batch(trainer, model, train_batches),
*model._train_batch(trainer, model, train_batches, device=device, **kwargs),
dict(name='train', args=(False, )),
dict(name='on_validation_model_eval'),
dict(name='zero_grad'),
dict(name='Callback.on_validation_start', args=(trainer, model)),
dict(name='on_validation_start'),
*model._eval_epoch('validation', trainer, model, val_batches, 'x'),
*model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
dict(name='Callback.on_validation_end', args=(trainer, model)),
# `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
dict(name='on_save_checkpoint', args=(saved_ckpt, )),
dict(name='on_validation_end'),
dict(name='train'),
dict(name='train', args=(True, )),
dict(name='on_validation_model_train'),
dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
Expand Down Expand Up @@ -542,7 +574,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
dict(name='on_pretrain_routine_start'),
dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
dict(name='on_pretrain_routine_end'),
dict(name='train'),
dict(name='train', args=(True, )),
dict(name='on_train_dataloader'),
dict(name='train_dataloader'),
# even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
Expand Down Expand Up @@ -610,7 +642,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
*model._eval_epoch(noun, trainer, model, batches, key),
dict(name=f'Callback.on_{noun}_end', args=(trainer, model)),
dict(name=f'on_{noun}_end'),
dict(name='train'),
dict(name='train', args=(True, )),
dict(name=f'on_{noun}_model_train'),
]
expected = [
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
gpus=1,
precision=16,
)
with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
with pytest.warns(UserWarning, match='will be ignored since DeepSpeed handles the backward'):
trainer.fit(model)


Expand Down

0 comments on commit ea88105

Please sign in to comment.