diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index 8c7d473a76738..3c6e34df8d5e3 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -17,7 +17,6 @@
 
 import torch
 from torch.optim import Adam, Optimizer, SGD
-from torch.optim.optimizer import _RequiredParameter
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -308,31 +307,17 @@ def configure_optimizers(self):
     assert zero_grad.call_count == max_iter
 
 
-required = _RequiredParameter()
-
-
 class OptimizerWithHooks(Optimizer):
-    def __init__(self, model, lr=required, u0=required):
-        if lr is not required and lr < 0.0:
-            raise ValueError("Invalid learning rate: {}".format(lr))
-
-        defaults = dict(lr=lr)
-        self.steps = 0
-
-        self.params = []
-
+    def __init__(self, model):
         self._fwd_handles = []
         self._bwd_handles = []
-
-        self.model = model
-
-        for _, mod in model.named_modules():  # iterates over modules of model
+        self.params = []
+        for _, mod in model.named_modules():
             mod_class = mod.__class__.__name__
-            if mod_class not in ['Linear']:  # silently skips other layers
+            if mod_class != 'Linear':
                 continue
 
-            # save the inputs and gradients for the kfac matrix computation
             handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
             self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
             handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
@@ -347,21 +332,21 @@ def __init__(self, model, lr=required, u0=required):
             d = {'params': params, 'mod': mod, 'layer_type': mod_class}
             self.params.append(d)
 
-        super(OptimizerWithHooks, self).__init__(self.params, defaults)
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
 
     def _save_input(self, mod, i):
         """Saves input of layer"""
         if mod.training:
             self.state[mod]['x'] = i[0]
 
-    def _save_grad_output(self, mod, grad_input, grad_output):
+    def _save_grad_output(self, mod, _, grad_output):
         """
         Saves grad on output of layer to
         grad is scaled with batch_size since gradient is spread over samples in mini batch
         """
-        bs = grad_output[0].shape[0]  # batch_size
+        batch_size = grad_output[0].shape[0]
         if mod.training:
-            self.state[mod]['grad'] = grad_output[0] * bs
+            self.state[mod]['grad'] = grad_output[0] * batch_size
 
     def step(self, closure=None):
         closure()
@@ -371,14 +356,11 @@ def step(self, closure=None):
         return True
 
 
-def test_lightning_optimizer_dont_delete_wrapped_optimizer(tmpdir):
+def test_lightning_optimizer_keeps_hooks(tmpdir):
     class TestModel(BoringModel):
-
-        def __init__(self):
-            super().__init__()
-            self.count_on_train_batch_start = 0
-            self.count_on_train_batch_end = 0
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
 
         def configure_optimizers(self):
             return OptimizerWithHooks(self)
 
@@ -390,15 +372,11 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
 
         def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
             self.count_on_train_batch_end += 1
-            # delete the lightning_optimizers
-            self.trainer._lightning_optimizers = None
-            gc.collect()
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
 
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
     model = TestModel()
-    # Initialize a trainer
-    trainer = Trainer(limit_train_batches=4, limit_val_batches=1, max_epochs=1)
-
-    # Train the model ⚡
     trainer.fit(model)
     assert model.count_on_train_batch_start == 4
     assert model.count_on_train_batch_end == 4
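For reference, below is a minimal standalone sketch (not part of this patch) of the hook mechanism `OptimizerWithHooks` relies on: a forward pre-hook captures the layer input, and a backward hook captures the gradient with respect to the layer output, scaled by the batch size. The sketch uses `register_full_backward_hook`, which recent PyTorch releases recommend over the deprecated `Module.register_backward_hook` used in the test; the module, shapes, and `state` dict here are purely illustrative.

```python
import torch
from torch import nn

state = {}


def save_input(mod, inputs):
    # forward pre-hook: receives the positional inputs before forward() runs
    state["x"] = inputs[0].detach()


def save_grad_output(mod, grad_input, grad_output):
    # backward hook: grad_output[0] is dLoss/d(layer output); scale by the batch
    # size because a mean-reduced loss spreads the gradient over the mini-batch
    state["grad"] = grad_output[0].detach() * grad_output[0].shape[0]


layer = nn.Linear(4, 2)
layer.register_forward_pre_hook(save_input)
layer.register_full_backward_hook(save_grad_output)

x = torch.randn(8, 4, requires_grad=True)
layer(x).mean().backward()
assert state["x"].shape == (8, 4) and state["grad"].shape == (8, 2)
```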