Skip to content

Commit

Permalink
Fix summary hook handles not getting removed (#2298)
Browse files Browse the repository at this point in the history
* detach hooks after completion

* detach hook

* update docs

* add test

* docs

* changelog
  • Loading branch information
awaelchli authored Jun 20, 2020
1 parent c7f8367 commit f972ab3
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))


## [0.8.1] - 2020-06-19

Expand Down
29 changes: 23 additions & 6 deletions pytorch_lightning/core/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
Expand Down Expand Up @@ -54,28 +55,42 @@ def __init__(self, module: nn.Module):
self._in_size = None
self._out_size = None

def _register_hook(self):
def __del__(self):
self.detach_hook()

def _register_hook(self) -> RemovableHandle:
"""
Registers a hook on the module that computes the input- and output size(s)
on the first forward pass. The hook will remove itself from the module, meaning that
Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
If the hook is called, it will remove itself from the from the module, meaning that
recursive models will only record their input- and output shapes once.
Return:
A handle for the installed hook.
"""

def hook(module, inp, out):
if len(inp) == 1:
inp = inp[0]
self._in_size = parse_batch_shape(inp)
self._out_size = parse_batch_shape(out)
self._hook_handle.remove() # hook detaches itself from module
self._hook_handle.remove()

return self._module.register_forward_hook(hook)

def detach_hook(self):
"""
Removes the forward hook if it was not already removed in the forward pass.
Will be called after the summary is created.
"""
if self._hook_handle is not None:
self._hook_handle.remove()

@property
def in_size(self):
def in_size(self) -> Union[str, List]:
return self._in_size or UNKNOWN_SIZE

@property
def out_size(self):
def out_size(self) -> Union[str, List]:
return self._out_size or UNKNOWN_SIZE

@property
Expand Down Expand Up @@ -180,6 +195,8 @@ def summarize(self) -> Dict[str, LayerSummary]:
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
if self._model.example_input_array is not None:
self._forward_example_input()
for layer in summary.values():
layer.detach_hook()
return summary

def _forward_example_input(self) -> None:
Expand Down
14 changes: 14 additions & 0 deletions tests/core/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ def test_linear_model_summary_shapes(device, dtype, mode):
assert model.device == device


@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
])
def test_hooks_removed_after_summarize(mode):
""" Test that all hooks were properly removed after summary, even ones that were not run. """
model = UnorderedModel()
summary = ModelSummary(model, mode=mode)
# hooks should be removed
for _, layer in summary.summarize().items():
handle = layer._hook_handle
assert handle.id not in handle.hooks_dict_ref()


@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
Expand Down

0 comments on commit f972ab3

Please sign in to comment.