Add max_depth parameter to ModelSummary #8062

Merged: 42 commits, merged on Jul 1, 2021

Commits
aaef74d
Add max_depth parameter to ModelSummary
ManuelPalermo Jun 21, 2021
9a97169
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
2752821
Better test coverage + fix pep8 on ModelSummary max_depth
ManuelPalermo Jun 21, 2021
52d1c23
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
f5919dc
Apply suggestions from code review
Borda Jun 21, 2021
8771707
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
9d05d7f
Removed mode parameter inside ModelSummary
ManuelPalermo Jun 21, 2021
af2aaa4
Removed mode parameter inside ModelSummary
ManuelPalermo Jun 21, 2021
658140f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
1afd722
Added deprecation warning to PL.summarize(mode=)
ManuelPalermo Jun 21, 2021
cfc27b4
Merge remote-tracking branch 'origin/master'
ManuelPalermo Jun 21, 2021
a1ec6cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
2b67abb
Improve deprecation of `mode` param in ModelSummary
ManuelPalermo Jun 24, 2021
bd9f4e6
Merge remote-tracking branch 'origin/master'
ManuelPalermo Jun 24, 2021
a1c48c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
fc78dc7
Add missing import
ManuelPalermo Jun 24, 2021
72760a9
Merge remote-tracking branch 'origin/master'
ManuelPalermo Jun 24, 2021
32df3e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
4750c3e
Merge branch 'master' into feature/max_depth
awaelchli Jun 29, 2021
a22c78b
Merge branch 'master' into feature/max_depth
awaelchli Jun 29, 2021
2f9e770
update changelog
awaelchli Jun 29, 2021
e0b1d54
update deprecation message
awaelchli Jun 29, 2021
6af47fd
update docs and deprecation
awaelchli Jun 29, 2021
d3f49eb
update test
awaelchli Jun 29, 2021
21f222f
add deprecation test
awaelchli Jun 29, 2021
f0a07a6
rm notebooks
awaelchli Jun 29, 2021
32518a5
Change ModelSummary max_depth idx logic:
ManuelPalermo Jun 29, 2021
5551f54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2021
55188a8
fix pep8 issue
ManuelPalermo Jun 29, 2021
5dc7632
Merge branch 'master' into master
ManuelPalermo Jun 29, 2021
e0d8b7b
improve deprecation warnings + default values
ManuelPalermo Jun 29, 2021
62a580c
pep8 fix
ManuelPalermo Jun 29, 2021
6792f3b
added missing model_summary log info
ManuelPalermo Jun 30, 2021
05377bf
Merge branch 'master' into master
ManuelPalermo Jun 30, 2021
30f2bb7
move deprecation test
awaelchli Jun 30, 2021
4d638ec
compact deprecation message
awaelchli Jun 30, 2021
5cf8f05
move log.info to the end of method
awaelchli Jun 30, 2021
ad01312
update changelog
awaelchli Jun 30, 2021
7dea235
Revert "compact deprecation message"
awaelchli Jun 30, 2021
224ef45
compact deprecation message
awaelchli Jul 1, 2021
14d4d33
update deprecated tag
awaelchli Jul 1, 2021
a479cfa
handle max_depth validation at init
awaelchli Jul 1, 2021
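Before the file-by-file diffs, a minimal sketch of the interface this PR introduces. The `LitModel` below is adapted from the doctest in `pytorch_lightning/core/memory.py` further down; the `max_depth`/`mode` equivalences follow the new `MODES = dict(top=1, full=-1)` mapping, so treat this as an illustration rather than official documentation.

```python
import torch
import torch.nn as nn

from pytorch_lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary


class LitModel(LightningModule):
    """Toy module, adapted from the ModelSummary doctest in this PR."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
        self.example_input_array = torch.zeros(10, 256)

    def forward(self, x):
        return self.net(x)


model = LitModel()

# New interface added by this PR:
print(ModelSummary(model, max_depth=1))   # top-level children only (old mode="top")
print(ModelSummary(model, max_depth=-1))  # all submodules, any depth (old mode="full")
print(ModelSummary(model, max_depth=0))   # no per-layer rows at all

# Old interface: still accepted, but now emits a deprecation warning
# and is scheduled for removal in v1.6.
print(ModelSummary(model, mode="top"))
```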
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -123,6 +123,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))


- Added `max_depth` parameter in `ModelSummary` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))


### Changed


@@ -267,6 +270,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#8025](https://github.com/PyTorchLightning/pytorch-lightning/pull/8025))


- Deprecated `mode` parameter in `ModelSummary` in favor of `max_depth` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))


### Removed

- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
21 changes: 15 additions & 6 deletions pytorch_lightning/core/lightning.py
@@ -1642,15 +1642,24 @@ def tbptt_split_batch(self, batch, split_size):

return splits

def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None) -> Optional[ModelSummary]:
model_summary = None

if mode in ModelSummary.MODES:
model_summary = ModelSummary(self, mode=mode)
log.info("\n" + str(model_summary))
elif mode is not None:
raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
# temporary mapping from mode to max_depth
if max_depth is None:
if mode in ModelSummary.MODES:
max_depth = ModelSummary.MODES[mode]
rank_zero_deprecation(
f"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior."
)
model_summary = ModelSummary(self, max_depth=max_depth)
elif mode is not None:
raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
else:
model_summary = ModelSummary(self, max_depth=max_depth)

log.info("\n" + str(model_summary))
return model_summary

def freeze(self) -> None:
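In practice, the updated `summarize` resolves its two arguments as sketched below: a hedged illustration of the diff above, reusing the hypothetical `LitModel` from the first example.

```python
model = LitModel()

# New-style call: no warning, max_depth is forwarded to ModelSummary.
summary = model.summarize(max_depth=2)

# Deprecated path: mode is mapped through ModelSummary.MODES and warns, e.g.
# "Argument `mode` in `LightningModule.summarize` is deprecated in v1.4 and
#  will be removed in v1.6. Use `max_depth=-1` to replicate `mode=full` behavior."
summary = model.summarize(mode="full")

# An unrecognized, non-None mode still raises MisconfigurationException.
model.summarize(mode="invalid")
```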
58 changes: 42 additions & 16 deletions pytorch_lightning/core/memory.py
@@ -131,11 +131,17 @@ class ModelSummary(object):
Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

Args:
model: The model to summarize (also referred to as the root module)
model: The model to summarize (also referred to as the root module).
mode: Can be one of

- `top` (default): only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module
- `top` (default): only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module

.. deprecated:: v1.4
This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6.

max_depth: Maximum depth of modules to show. Use -1 to show all modules or 0 to show no
summary. Defaults to 1.

The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.
@@ -160,7 +166,7 @@ class ModelSummary(object):
... return self.net(x)
...
>>> model = LitModel()
>>> ModelSummary(model, mode='top') # doctest: +NORMALIZE_WHITESPACE
>>> ModelSummary(model, max_depth=1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | In sizes | Out sizes
------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
@@ -169,7 +175,7 @@
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
>>> ModelSummary(model, mode='full') # doctest: +NORMALIZE_WHITESPACE
>>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
@@ -182,14 +188,28 @@
0.530 Total estimated model params size (MB)
"""

MODE_TOP = "top"
MODE_FULL = "full"
MODE_DEFAULT = MODE_TOP
MODES = [MODE_FULL, MODE_TOP]
MODES = dict(top=1, full=-1) # TODO: remove in v1.6

def __init__(self, model, mode: str = MODE_DEFAULT):
def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1):
self._model = model
self._mode = mode

# temporary mapping from mode to max_depth
if max_depth is None or mode is not None:
if mode in ModelSummary.MODES:
max_depth = ModelSummary.MODES[mode]
from pytorch_lightning.utilities import rank_zero_deprecation
rank_zero_deprecation(
f"Argument `mode` in `ModelSummary` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
)
else:
from pytorch_lightning.utilities.exceptions import MisconfigurationException
raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")

if not isinstance(max_depth, int) or max_depth < -1:
raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")

self._max_depth = max_depth
self._layer_summary = self.summarize()
# 1 byte -> 8 bits
# TODO: how do we compute precision_megabytes in case of mixed precision?
@@ -198,14 +218,14 @@ def __init__(self, model, mode: str = MODE_DEFAULT):

@property
def named_modules(self) -> List[Tuple[str, nn.Module]]:
if self._mode == ModelSummary.MODE_FULL:
mods = self._model.named_modules()
mods = list(mods)[1:] # do not include root module (LightningModule)
elif self._mode == ModelSummary.MODE_TOP:
if self._max_depth == 0:
mods = []
elif self._max_depth == 1:
# the children are the top-level modules
mods = self._model.named_children()
else:
mods = []
mods = self._model.named_modules()
mods = list(mods)[1:] # do not include root module (LightningModule)
return list(mods)

@property
@@ -249,6 +269,12 @@ def summarize(self) -> Dict[str, LayerSummary]:
self._forward_example_input()
for layer in summary.values():
layer.detach_hook()

if self._max_depth >= 1:
# remove summary entries with depth > max_depth
for k in [k for k in summary if k.count(".") >= self._max_depth]:
del summary[k]

return summary

def _forward_example_input(self) -> None:
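The depth cut-off above keys off the number of dots in a module's qualified name, as returned by `nn.Module.named_modules()`. A small standalone illustration of that rule (plain PyTorch, no Lightning required):

```python
import torch.nn as nn

# A small nested module tree.
net = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))

# named_modules() yields dot-qualified names; the first entry is the root
# module itself ("") and is skipped by ModelSummary, so a layer's depth
# below the root is name.count(".") + 1.
for name, _ in list(net.named_modules())[1:]:
    print(f"depth={name.count('.') + 1}: {name}")
# depth=1: 0
# depth=1: 1
# depth=2: 1.0
# depth=2: 1.1

# ModelSummary.summarize() deletes entries with name.count(".") >= max_depth,
# so max_depth=1 keeps only the top-level children '0' and '1'.
max_depth = 1
kept = [n for n, _ in list(net.named_modules())[1:] if n.count(".") < max_depth]
assert kept == ["0", "1"]
```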
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -941,7 +941,8 @@ def _pre_training_routine(self):

# print model summary
if self.is_global_zero and self.weights_summary is not None and not self.testing:
ref_model.summarize(mode=self.weights_summary)
max_depth = ModelSummary.MODES[self.weights_summary]
ref_model.summarize(max_depth=max_depth)

# on pretrain routine end
self.on_pretrain_routine_end()
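The `Trainer`-facing behavior is unchanged: the existing `weights_summary` string is now translated through the same `MODES` dict before calling `summarize`. A sketch of the equivalence, assuming the v1.4 API:

```python
from pytorch_lightning.core.memory import ModelSummary

# Per the diff above, the pre-training routine now runs the equivalent of:
#     max_depth = ModelSummary.MODES[self.weights_summary]  # "top" -> 1, "full" -> -1
#     ref_model.summarize(max_depth=max_depth)
# instead of ref_model.summarize(mode=self.weights_summary), so
# Trainer(weights_summary="top") and Trainer(weights_summary="full")
# keep their previous output.
assert ModelSummary.MODES == {"top": 1, "full": -1}
```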
78 changes: 65 additions & 13 deletions tests/core/test_memory.py
@@ -114,6 +114,29 @@ def forward(self, inp):
return self.layer2(self.layer1(inp))


class DeepNestedModel(LightningModule):
""" A model with deep nested layers. """

def __init__(self):
super().__init__()
self.branch1 = nn.Sequential(
nn.Linear(5, 5),
nn.Sequential(
nn.Linear(5, 5),
nn.Sequential(
nn.Linear(5, 5),
nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 3))))
)
)
)
self.branch2 = nn.Linear(5, 10)
self.head = UnorderedModel()
self.example_input_array = torch.rand(2, 5)

def forward(self, inp):
return self.head(self.branch1(inp), self.branch2(inp))


def test_invalid_weights_summmary():
""" Test that invalid value for weights_summary raises an error. """
with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
@@ -123,8 +146,8 @@ def test_invalid_weights_summmary():
Trainer(weights_summary='temp')


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
def test_empty_model_summary_shapes(mode: ModelSummary):
@pytest.mark.parametrize('mode', ["full", "top"])
def test_empty_model_summary_shapes(mode: str):
""" Test that the summary works for models that have no submodules. """
model = EmptyModule()
summary = model.summarize(mode=mode)
Expand All @@ -134,7 +157,7 @@ def test_empty_model_summary_shapes(mode: ModelSummary):


@RunIf(min_gpus=1)
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
@pytest.mark.parametrize(['device'], [
pytest.param(torch.device('cpu')),
pytest.param(torch.device('cuda', 0)),
@@ -177,18 +200,18 @@
]


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
def test_hooks_removed_after_summarize(mode):
@pytest.mark.parametrize('max_depth', [-1, 0])
def test_hooks_removed_after_summarize(max_depth):
""" Test that all hooks were properly removed after summary, even ones that were not run. """
model = UnorderedModel()
summary = ModelSummary(model, mode=mode)
summary = ModelSummary(model, max_depth=max_depth)
# hooks should be removed
for _, layer in summary.summarize().items():
handle = layer._hook_handle
assert handle.id not in handle.hooks_dict_ref()


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_rnn_summary_shapes(mode):
""" Test that the model summary works for RNNs. """
model = ParityModuleRNN()
Expand All @@ -212,7 +235,7 @@ def test_rnn_summary_shapes(mode):
]


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_summary_parameter_count(mode):
""" Test that the summary counts the number of parameters in every submodule. """
model = UnorderedModel()
Expand All @@ -226,7 +249,7 @@ def test_summary_parameter_count(mode):
]


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_summary_layer_types(mode):
""" Test that the summary displays the layer names correctly. """
model = UnorderedModel()
Expand All @@ -240,7 +263,7 @@ def test_summary_layer_types(mode):
]


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_summary_with_scripted_modules(mode):
model = PartialScriptModel()
summary = model.summarize(mode=mode)
Expand All @@ -249,7 +272,7 @@ def test_summary_with_scripted_modules(mode):
assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
@pytest.mark.parametrize(['example_input', 'expected_size'], [
pytest.param([], UNKNOWN_SIZE),
pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3),
Expand Down Expand Up @@ -283,15 +306,15 @@ def forward(self, *args, **kwargs):
assert summary.in_sizes == [expected_size]


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_model_size(mode):
""" Test model size is calculated correctly. """
model = PreCalculatedModel()
summary = model.summarize(mode=mode)
assert model.pre_calculated_model_size == summary.model_size


@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_empty_model_size(mode):
""" Test empty model size is zero. """
model = EmptyModule()
@@ -336,3 +359,32 @@ def test_lazy_model_summary():
# https://github.com/pytorch/pytorch/issues/58350
assert summary.total_parameters == 7
assert summary.trainable_parameters == 7


def test_max_depth_equals_mode_interface():
"""Test model.summarize(full/top) interface mapping matches max_depth"""
model = DeepNestedModel()

summary_top = model.summarize(mode="top")
summary_0 = model.summarize(max_depth=1)
assert str(summary_top) == str(summary_0)

summary_full = model.summarize(mode="full")
summary_minus1 = model.summarize(max_depth=-1)
assert str(summary_full) == str(summary_minus1)


@pytest.mark.parametrize('max_depth', [-1, 0, 1, 3, 999])
def test_max_depth_param(max_depth):
"""Test that only the modules up to the desired depth are shown"""
model = DeepNestedModel()
summary = ModelSummary(model, max_depth=max_depth)
for lname in summary.layer_names:
if max_depth >= 0:
assert lname.count(".") < max_depth


@pytest.mark.parametrize('max_depth', [-99, -2, "invalid"])
def test_raise_invalid_max_depth_value(max_depth):
with pytest.raises(ValueError, match=f"`max_depth` can be -1, 0 or > 0, got {max_depth}"):
DeepNestedModel().summarize(max_depth=max_depth)
10 changes: 10 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
@@ -16,6 +16,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -249,3 +250,12 @@ def test_v1_6_0_ddp_plugin_task_idx():
plugin = DDPPlugin()
with pytest.deprecated_call(match='Use `DDPPlugin.local_rank` instead'):
_ = plugin.task_idx


def test_v1_6_0_deprecated_model_summary_mode(tmpdir):
model = BoringModel()
with pytest.deprecated_call(match="Argument `mode` in `ModelSummary` is deprecated in v1.4"):
ModelSummary(model, mode="top")

with pytest.deprecated_call(match="Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"):
model.summarize(mode="top")