Skip to content

Commit

Permalink
[bugfix] Check LightningOptimizer doesn't delete optimizer hooks (#6305)
Browse files Browse the repository at this point in the history
* update

* resolve bug
  • Loading branch information
tchaton authored and lexierule committed Mar 9, 2021
1 parent 896470e commit a39cb52
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Ensure we check deepspeed/sharded in multinode DDP ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)


- Check `LightningOptimizer` doesn't delete optimizer hooks ([#6305](https://github.com/PyTorchLightning/pytorch-lightning/pull/6305)


## [1.2.2] - 2021-03-02

### Added
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class LightningOptimizer:

def __init__(self, optimizer: Optimizer):

self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k != 'step'}
self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', "__del__")}

# For Horovod
if hasattr(optimizer, "skip_synchronize"):
Expand Down
82 changes: 81 additions & 1 deletion tests/core/test_lightning_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch, DEFAULT
import gc
from typing import Any
from unittest.mock import DEFAULT, patch

import torch
from torch.optim import Adam, Optimizer, SGD
Expand Down Expand Up @@ -188,6 +190,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
"""
Test overriding zero_grad works in automatic_optimization
"""

class TestModel(BoringModel):

def training_step(self, batch, batch_idx, optimizer_idx=None):
Expand Down Expand Up @@ -281,7 +284,9 @@ def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir):
Test zero_grad is called the same number of times as LBFGS requires
for reevaluation of the loss in automatic_optimization.
"""

class TestModel(BoringModel):

def configure_optimizers(self):
return torch.optim.LBFGS(self.parameters())

Expand All @@ -300,3 +305,78 @@ def configure_optimizers(self):
lbfgs = model.optimizers()
max_iter = lbfgs.param_groups[0]["max_iter"]
assert zero_grad.call_count == max_iter


class OptimizerWithHooks(Optimizer):

def __init__(self, model):
self._fwd_handles = []
self._bwd_handles = []
self.params = []
for _, mod in model.named_modules():
mod_class = mod.__class__.__name__
if mod_class != 'Linear':
continue

handle = mod.register_forward_pre_hook(self._save_input) # save the inputs
self._fwd_handles.append(handle) # collect forward-save-input hooks in list
handle = mod.register_backward_hook(self._save_grad_output) # save the gradients
self._bwd_handles.append(handle) # collect backward-save-grad hook in list

# save the parameters
params = [mod.weight]
if mod.bias is not None:
params.append(mod.bias)

# save a param_group for each module
d = {'params': params, 'mod': mod, 'layer_type': mod_class}
self.params.append(d)

super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})

def _save_input(self, mod, i):
"""Saves input of layer"""
if mod.training:
self.state[mod]['x'] = i[0]

def _save_grad_output(self, mod, _, grad_output):
"""
Saves grad on output of layer to
grad is scaled with batch_size since gradient is spread over samples in mini batch
"""
batch_size = grad_output[0].shape[0]
if mod.training:
self.state[mod]['grad'] = grad_output[0] * batch_size

def step(self, closure=None):
closure()
for group in self.param_groups:
_ = self.state[group['mod']]['x']
_ = self.state[group['mod']]['grad']
return True


def test_lightning_optimizer_keeps_hooks(tmpdir):

class TestModel(BoringModel):
count_on_train_batch_start = 0
count_on_train_batch_end = 0

def configure_optimizers(self):
return OptimizerWithHooks(self)

def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self.count_on_train_batch_start += 1
optimizer = self.optimizers(use_pl_optimizer=False)
assert len(optimizer._fwd_handles) == 1

def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self.count_on_train_batch_end += 1
del self.trainer._lightning_optimizers
gc.collect() # not necessary, just in case

trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
model = TestModel()
trainer.fit(model)
assert model.count_on_train_batch_start == 4
assert model.count_on_train_batch_end == 4

0 comments on commit a39cb52

Please sign in to comment.