Handle torch.jit scripted modules in layer summary (#6511)
awaelchli authored Mar 15, 2021
1 parent 0544efd commit 02fa32b
Showing 3 changed files with 34 additions and 5 deletions.
CHANGELOG.md (3 additions, 0 deletions)
@@ -140,6 +140,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))


+ - Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))


## [1.2.3] - 2021-03-09

### Fixed
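For context, a minimal sketch of the failure this changelog entry refers to. It assumes only torch; on the PyTorch releases current at the time of this commit, scripted modules rejected Python forward hooks, and the exact exception and message vary by version:

```python
import torch
from torch import nn

# A scripted submodule, as produced by torch.jit.script
scripted = torch.jit.script(nn.Linear(5, 3))

try:
    # Before this fix the summary registered its hook unconditionally;
    # on PyTorch releases of this era, ScriptModule rejected Python
    # forward hooks with a RuntimeError, crashing the whole summary.
    scripted.register_forward_hook(lambda mod, inp, out: None)
except RuntimeError as err:
    print(f"forward hooks rejected on ScriptModule: {err}")
```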
pytorch_lightning/core/memory.py (8 additions, 4 deletions)
@@ -16,7 +16,7 @@
import shutil
import subprocess
from collections import OrderedDict
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -71,14 +71,15 @@ def __init__(self, module: nn.Module):
def __del__(self):
self.detach_hook()

- def _register_hook(self) -> RemovableHandle:
+ def _register_hook(self) -> Optional[RemovableHandle]:
"""
Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
If the hook is called, it will remove itself from the module, meaning that
recursive models will only record their input- and output shapes once.
+ Registering hooks on :class:`~torch.jit.ScriptModule` is not supported.
Return:
- A handle for the installed hook.
+ A handle for the installed hook, or ``None`` if registering the hook is not possible.
"""

def hook(module, inp, out):
@@ -88,7 +89,10 @@ def hook(module, inp, out):
self._out_size = parse_batch_shape(out)
self._hook_handle.remove()

- return self._module.register_forward_hook(hook)
+ handle = None
+ if not isinstance(self._module, torch.jit.ScriptModule):
+     handle = self._module.register_forward_hook(hook)
+ return handle

def detach_hook(self):
"""
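A standalone sketch of the guard pattern the diff above introduces. The `attach_shape_hook` helper and `record_shapes` hook are hypothetical stand-ins for the library's LayerSummary internals, not the library code itself:

```python
import torch
from torch import nn

def attach_shape_hook(module: nn.Module):
    """Attach a shape-recording hook; return None for scripted modules."""

    def record_shapes(mod, inp, out):  # stand-in for the summary's real hook
        print(type(mod).__name__, "->", tuple(out.shape))

    # Scripted modules did not accept Python forward hooks on the PyTorch
    # versions this fix targets, so skip them instead of raising.
    if isinstance(module, torch.jit.ScriptModule):
        return None
    return module.register_forward_hook(record_shapes)

scripted = torch.jit.script(nn.Linear(5, 3))  # a RecursiveScriptModule
plain = nn.Linear(3, 2)

assert attach_shape_hook(scripted) is None    # skipped, no exception
handle = attach_shape_hook(plain)             # a RemovableHandle
plain(torch.rand(2, 3))                       # prints: Linear -> (2, 2)
handle.remove()
```

Returning `None` instead of raising lets the summary keep iterating over the remaining layers; the scripted layer simply reports unknown shapes.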
tests/core/test_memory.py (23 additions, 1 deletion)
@@ -88,6 +88,19 @@ def forward(self, x):
return self.reduce(self.embed(x))


+ class PartialScriptModel(LightningModule):
+     """ A model which contains scripted layers. """
+
+     def __init__(self):
+         super().__init__()
+         self.layer1 = torch.jit.script(nn.Linear(5, 3))
+         self.layer2 = nn.Linear(3, 2)
+         self.example_input_array = torch.rand(2, 5)
+
+     def forward(self, x):
+         return self.layer2(self.layer1(x))


def test_invalid_weights_summmary():
""" Test that invalid value for weights_summary raises an error. """
with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
@@ -214,6 +227,15 @@ def test_summary_layer_types(mode):
]


+ @pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+ def test_summary_with_scripted_modules(mode):
+     model = PartialScriptModel()
+     summary = model.summarize(mode=mode)
+     assert summary.layer_types == ["RecursiveScriptModule", "Linear"]
+     assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]]
+     assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize(['example_input', 'expected_size'], [
pytest.param([], UNKNOWN_SIZE),
@@ -265,7 +287,7 @@ def test_empty_model_size(mode):


@RunIf(min_gpus=1, amp_native=True)
- def test_model_size_precision(monkeypatch, tmpdir):
+ def test_model_size_precision(tmpdir):
""" Test model size for half and full precision. """
model = PreCalculatedModel()

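As a usage sketch, assuming a pytorch-lightning build that contains this change: summarizing the mixed model from the new test reports the scripted layer without shape information, with `UNKNOWN_SIZE` rendering as the "?" placeholder:

```python
import torch
from torch import nn
from pytorch_lightning import LightningModule

class PartialScriptModel(LightningModule):
    """Same mixed model as in the new test above."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.jit.script(nn.Linear(5, 3))  # hooks skipped -> sizes "?"
        self.layer2 = nn.Linear(3, 2)                    # hooks attach -> real sizes
        self.example_input_array = torch.rand(2, 5)

    def forward(self, x):
        return self.layer2(self.layer1(x))

summary = PartialScriptModel().summarize(mode="top")
print(summary)
# layer1 is listed as RecursiveScriptModule with "?" in/out sizes;
# layer2 reports in size [2, 3] and out size [2, 2].
```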
