diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index f05fd4d54b811..3edcc6866e219 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -64,15 +64,14 @@ def backward(
     ) -> Tensor:
         if is_overridden('backward', model):
             warning_cache.warn(
-                "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
-                "backward logic outside of the LightningModule"
+                "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
+                " the backward logic internally."
             )
         # todo: hack around for deepspeed engine to call backward
         deepspeed_engine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
-
         return closure_loss
 
     def clip_gradients(
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 13c9b9f13ec23..789959e38908a 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -265,6 +265,10 @@ def call(hook, fn, *args, **kwargs):
             d = {'name': hook}
             if args:
                 d['args'] = args
+            elif hook == 'train':
+                # DeepSpeed calls `train(mode)` but we do not. Standardize
+                # https://github.com/microsoft/DeepSpeed/pull/571
+                d['args'] = (True, )
             if kwargs:
                 d['kwargs'] = kwargs
             called.append(d)
@@ -283,12 +287,13 @@ def test_epoch_end(self, *args, **kwargs):
         pass
 
     @staticmethod
-    def _train_batch(trainer, model, batches, current_epoch=0):
+    def _train_batch(trainer, model, batches, device=torch.device('cpu'), current_epoch=0, **kwargs):
+        using_native_amp = kwargs.get('amp_backend') == 'native'
         out = []
         for i in range(batches):
             out.extend([
                 dict(name='on_before_batch_transfer', args=(ANY, 0)),
-                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+                dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
                 dict(name='on_after_batch_transfer', args=(ANY, 0)),
                 # TODO: `on_batch_{start,end}`
                 dict(name='Callback.on_batch_start', args=(trainer, model)),
@@ -301,14 +306,15 @@ def _train_batch(trainer, model, batches, current_epoch=0):
                 dict(name='on_before_zero_grad', args=(ANY, )),
                 dict(name='optimizer_zero_grad', args=(current_epoch, i, ANY, 0)),
                 # TODO: `on_before_backward`
-                dict(name='backward', args=(ANY, ANY, 0)),
+                # DeepSpeed handles backward internally
+                *([dict(name='backward', args=(ANY, ANY, 0))] if kwargs.get('plugins') != 'deepspeed' else []),
                 dict(name='Callback.on_after_backward', args=(trainer, model)),
                 dict(name='on_after_backward'),
                 # TODO: `on_before_optimizer_step`
                 dict(
                     name='optimizer_step',
                     args=(current_epoch, i, ANY, 0, ANY),
-                    kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False)
+                    kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp)
                 ),
                 dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
                 dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)),
@@ -317,14 +323,14 @@ def _train_batch(trainer, model, batches, current_epoch=0):
         return out
 
     @staticmethod
-    def _eval_epoch(fn, trainer, model, batches, key):
+    def _eval_epoch(fn, trainer, model, batches, key, device=torch.device('cpu')):
         outputs = {key: ANY}
         return [
             dict(name='Callback.on_epoch_start', args=(trainer, model)),
             dict(name='on_epoch_start'),
             dict(name=f'Callback.on_{fn}_epoch_start', args=(trainer, model)),
             dict(name=f'on_{fn}_epoch_start'),
-            *HookedModel._eval_batch(fn, trainer, model, batches, key),
+            *HookedModel._eval_batch(fn, trainer, model, batches, key, device=device),
             dict(name=f'{fn}_epoch_end', args=([outputs] * batches, )),
             dict(name=f'Callback.on_{fn}_epoch_end', args=(trainer, model)),
             dict(name=f'on_{fn}_epoch_end'),
@@ -333,13 +339,13 @@ def _eval_epoch(fn, trainer, model, batches, key):
         ]
 
     @staticmethod
-    def _eval_batch(fn, trainer, model, batches, key):
+    def _eval_batch(fn, trainer, model, batches, key, device=torch.device('cpu')):
         out = []
         outputs = {key: ANY}
         for i in range(batches):
             out.extend([
                 dict(name='on_before_batch_transfer', args=(ANY, 0)),
-                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+                dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
                 dict(name='on_after_batch_transfer', args=(ANY, 0)),
                 # TODO: `{,Callback}.on_batch_{start,end}`
                 dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)),
@@ -357,10 +363,10 @@ def _predict_batch(trainer, model, batches):
         out = []
         for i in range(batches):
             out.extend([
-                # TODO: `{,Callback}.on_batch_{start,end}`
                 dict(name='on_before_batch_transfer', args=(ANY, 0)),
                 dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
                 dict(name='on_after_batch_transfer', args=(ANY, 0)),
+                # TODO: `{,Callback}.on_batch_{start,end}`
                 dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name='on_predict_batch_start', args=(ANY, i, 0)),
                 dict(name='forward', args=(ANY, )),
@@ -372,7 +378,17 @@ def _predict_batch(trainer, model, batches):
         return out
 
 
-def test_trainer_model_hook_system_fit(tmpdir):
+@pytest.mark.parametrize(
+    'kwargs',
+    [
+        {},
+        # these precision plugins modify the optimization flow, so testing them explicitly
+        pytest.param(dict(gpus=1, precision=16, plugins='deepspeed'), marks=RunIf(deepspeed=True, min_gpus=1)),
+        pytest.param(dict(gpus=1, precision=16, amp_backend='native'), marks=RunIf(amp_native=True, min_gpus=1)),
+        pytest.param(dict(gpus=1, precision=16, amp_backend='apex'), marks=RunIf(amp_apex=True, min_gpus=1)),
+    ]
+)
+def test_trainer_model_hook_system_fit(tmpdir, kwargs):
     called = []
     model = HookedModel(called)
     callback = HookedCallback(called)
@@ -385,13 +401,17 @@ def test_trainer_model_hook_system_fit(tmpdir):
         limit_val_batches=val_batches,
         progress_bar_refresh_rate=0,
         weights_summary=None,
-        callbacks=[callback]
+        callbacks=[callback],
+        **kwargs,
     )
+
     assert called == [
         dict(name='Callback.on_init_start', args=(trainer, )),
         dict(name='Callback.on_init_end', args=(trainer, )),
     ]
+
     trainer.fit(model)
+
     saved_ckpt = {
         'callbacks': ANY,
         'epoch': 1,
@@ -401,19 +421,31 @@ def test_trainer_model_hook_system_fit(tmpdir):
         'pytorch-lightning_version': __version__,
         'state_dict': ANY,
     }
+    if kwargs.get('amp_backend') == 'native':
+        saved_ckpt['native_amp_scaling_state'] = ANY
+    elif kwargs.get('amp_backend') == 'apex':
+        saved_ckpt['amp_scaling_state'] = ANY
+    device = torch.device('cuda:0' if 'gpus' in kwargs else 'cpu')
+
     expected = [
         dict(name='Callback.on_init_start', args=(trainer, )),
         dict(name='Callback.on_init_end', args=(trainer, )),
         dict(name='prepare_data'),
         dict(name='configure_callbacks'),
         dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
+        # DeepSpeed needs the batch size to figure out throughput logging
+        *([dict(name='train_dataloader')] if kwargs.get('plugins') == 'deepspeed' else []),
         dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
         dict(name='setup', kwargs=dict(stage='fit')),
         dict(name='configure_sharded_model'),
         dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
-        dict(name='configure_optimizers'),
+        # DeepSpeed skips initializing optimizers here as they are handled via config
+        *([dict(name='configure_optimizers')] if kwargs.get('plugins') != 'deepspeed' else []),
         dict(name='Callback.on_fit_start', args=(trainer, model)),
         dict(name='on_fit_start'),
+        # TODO: explore whether DeepSpeed can have the same flow for optimizers
+        # DeepSpeed did not find any optimizer in the config so they are loaded here
+        *([dict(name='configure_optimizers')] if kwargs.get('plugins') == 'deepspeed' else []),
         dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
         dict(name='on_pretrain_routine_start'),
         dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
@@ -426,14 +458,14 @@ def test_trainer_model_hook_system_fit(tmpdir):
         dict(name='zero_grad'),
         dict(name='Callback.on_validation_start', args=(trainer, model)),
         dict(name='on_validation_start'),
-        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
         dict(name='Callback.on_validation_end', args=(trainer, model)),
         dict(name='on_validation_end'),
-        dict(name='train'),
+        dict(name='train', args=(True, )),
        dict(name='on_validation_model_train'),
         dict(name='Callback.on_sanity_check_end', args=(trainer, model)),
         # duplicate `train` because `_run_train` calls it again in case validation wasn't run
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name='on_train_dataloader'),
         dict(name='train_dataloader'),
         dict(name='Callback.on_train_start', args=(trainer, model)),
@@ -442,19 +474,19 @@ def test_trainer_model_hook_system_fit(tmpdir):
         dict(name='on_epoch_start'),
         dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
         dict(name='on_train_epoch_start'),
-        *model._train_batch(trainer, model, train_batches),
+        *model._train_batch(trainer, model, train_batches, device=device, **kwargs),
         dict(name='train', args=(False, )),
         dict(name='on_validation_model_eval'),
         dict(name='zero_grad'),
         dict(name='Callback.on_validation_start', args=(trainer, model)),
         dict(name='on_validation_start'),
-        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
         dict(name='Callback.on_validation_end', args=(trainer, model)),
         # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
         dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
         dict(name='on_save_checkpoint', args=(saved_ckpt, )),
         dict(name='on_validation_end'),
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name='on_validation_model_train'),
         dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
         dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
@@ -542,7 +574,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         dict(name='on_pretrain_routine_start'),
         dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
         dict(name='on_pretrain_routine_end'),
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name='on_train_dataloader'),
         dict(name='train_dataloader'),
         # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
@@ -610,7 +642,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
         *model._eval_epoch(noun, trainer, model, batches, key),
         dict(name=f'Callback.on_{noun}_end', args=(trainer, model)),
         dict(name=f'on_{noun}_end'),
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name=f'on_{noun}_model_train'),
     ]
     expected = [
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index efe8da981c9eb..dcb4ff00b219b 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -256,7 +256,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
         gpus=1,
         precision=16,
     )
-    with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
+    with pytest.warns(UserWarning, match='will be ignored since DeepSpeed handles the backward'):
         trainer.fit(model)