
Fix ModelPruning(make_pruning_permanent=True) buffers getting removed when saved during training (#6073)

Co-authored-by: chaton <thomas@grid.ai>
carmocca and tchaton authored Mar 3, 2021
1 parent dcec4ef commit 4a8422c
Showing 3 changed files with 81 additions and 22 deletions.
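The gist of the fix: `ModelPruning.on_save_checkpoint` used to call `make_pruning_permanent()` on the live model, which stripped the `weight_orig`/`weight_mask` buffers that later epochs still need. With this commit, the callback instead prunes a CPU deep copy of the module and writes that copy's `state_dict` into the checkpoint, leaving the training model untouched. A minimal sketch of the idea in plain PyTorch (the `nn.Linear` model and `checkpoint` dict below are illustrative, not from the commit):

from copy import deepcopy

import torch.nn.utils.prune as pytorch_prune
from torch import nn

model = nn.Linear(4, 2)
pytorch_prune.l1_unstructured(model, name="weight", amount=0.5)
assert hasattr(model, "weight_orig")  # pruning re-parametrization is active

# Prune a copy so the original keeps its pruning buffers and training
# can keep updating the mask in later epochs.
copy = deepcopy(model)
pytorch_prune.remove(copy, name="weight")  # folds the mask into `weight`
checkpoint = {"state_dict": copy.state_dict()}

assert "weight_orig" not in checkpoint["state_dict"]  # permanent on disk
assert hasattr(model, "weight_orig")  # still prunable in memory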
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -77,6 +77,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


+- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
+
+
 - Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))


43 changes: 27 additions & 16 deletions pytorch_lightning/callbacks/pruning.py
@@ -19,15 +19,15 @@
 import logging
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union, Dict

 import torch
 import torch.nn.utils.prune as pytorch_prune
 from torch import nn

 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 log = logging.getLogger(__name__)
@@ -248,14 +248,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
     def _wrap_pruning_fn(pruning_fn, **kwargs):
         return partial(pruning_fn, **kwargs)

-    def make_pruning_permanent(self):
-        """ Makes ``parameters_to_prune`` current pruning permanent. """
-        for module, param_name in self._parameters_to_prune:
-            try:
-                pytorch_prune.remove(module, param_name)
-            except ValueError:
-                # pruning already made permanent
-                pass
+    def make_pruning_permanent(self, pl_module: LightningModule):
+        """
+        Removes pruning buffers from any pruned modules
+
+        Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
+        """
+        for _, module in pl_module.named_modules():
+            for k in list(module._forward_pre_hooks):
+                hook = module._forward_pre_hooks[k]
+                if isinstance(hook, pytorch_prune.BasePruningMethod):
+                    hook.remove(module)
+                    del module._forward_pre_hooks[k]

     def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
         trained = getattr(module, tensor_name)
@@ -353,7 +357,7 @@ def _log_sparsity_stats(
             f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
         )

-    def on_before_accelerator_backend_setup(self, trainer, pl_module):
+    def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )
@@ -369,7 +373,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
             self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
             self._original_layers[id_]["names"].append((i, name))

-    def on_train_epoch_end(self, trainer, pl_module, *args):
+    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
         current_epoch = trainer.current_epoch
         prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
         amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
@@ -383,13 +387,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
         ):
             self.apply_lottery_ticket_hypothesis()

-    def on_train_end(self, *args):
+    def on_train_end(self, trainer, pl_module: LightningModule):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
+            self.make_pruning_permanent(pl_module)

-    def on_save_checkpoint(self, *args):
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
+            prev_device = pl_module.device
+            # prune a copy so training can continue with the same buffers
+            copy = deepcopy(pl_module.to("cpu"))
+            self.make_pruning_permanent(copy)
+            checkpoint["state_dict"] = copy.state_dict()
+            pl_module.to(prev_device)

     @staticmethod
     def sanitize_parameters_to_prune(
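Background for the new `make_pruning_permanent`: PyTorch's pruning utilities register each active pruning method as a forward pre-hook on the pruned module, so walking `named_modules()` and removing every `BasePruningMethod` hook makes pruning permanent across the whole model, without depending on the callback's `_parameters_to_prune` list. A short sketch of that mechanism, assuming a plain `nn.Linear` (illustrative, not from the commit):

import torch.nn.utils.prune as pytorch_prune
from torch import nn

layer = nn.Linear(3, 3)
pytorch_prune.random_unstructured(layer, name="weight", amount=0.5)
assert hasattr(layer, "weight_orig")  # hook re-parametrizes `weight`

# The same loop the callback now runs for every submodule:
for k in list(layer._forward_pre_hooks):
    hook = layer._forward_pre_hooks[k]
    if isinstance(hook, pytorch_prune.BasePruningMethod):
        hook.remove(layer)  # bakes the mask in, drops `weight_orig`/`weight_mask`
        del layer._forward_pre_hooks[k]

assert not hasattr(layer, "weight_orig")
assert isinstance(layer.weight, nn.Parameter)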
57 changes: 51 additions & 6 deletions tests/callbacks/test_pruning.py
@@ -14,7 +14,6 @@
 import os
 from collections import OrderedDict
 from logging import INFO
-from unittest import mock

 import pytest
 import torch
@@ -23,7 +22,7 @@
 from torch.nn import Sequential

 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning
+from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -42,6 +41,10 @@ def __init__(self):
             ])
         )

+    def training_step(self, batch, batch_idx):
+        self.log("test", -batch_idx)
+        return super().training_step(batch, batch_idx)
+

 class TestPruningMethod(pytorch_prune.BasePruningMethod):
     PRUNING_TYPE = "unstructured"
@@ -216,7 +219,6 @@ def apply_lottery_ticket_hypothesis(self):


 @pytest.mark.parametrize("make_pruning_permanent", (False, True))
-@mock.patch.dict(os.environ, {}, clear=True)
 def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     seed_everything(0)
     model = TestModel()
@@ -241,8 +243,9 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     with caplog.at_level(INFO):
         trainer.fit(model)

-    actual = [m.strip() for m in caplog.messages[-9:]]
-    expected = [
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
         "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)",  # noqa: E501
@@ -253,11 +256,53 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)",  # noqa: E501
     ]
-    assert actual == expected

     filepath = str(tmpdir / "foo.ckpt")
     trainer.save_checkpoint(filepath)

     model.load_from_checkpoint(filepath, strict=False)
     has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
     assert not has_pruning if make_pruning_permanent else has_pruning
+
+
+def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
+    """
+    When a model is saved multiple times and make_permanent=True, we need to
+    make sure a copy is pruned and not the trained model if we want to continue
+    with the same pruning buffers.
+    """
+    seed_everything(0)
+
+    class TestPruning(ModelPruning):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            super().on_save_checkpoint(trainer, pl_module, checkpoint)
+            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
+            assert hasattr(pl_module.layer.mlp_3, "weight_orig")
+
+    model = TestModel()
+    pruning_callback = TestPruning(
+        "random_unstructured",
+        parameters_to_prune=[(model.layer.mlp_3, "weight")],
+        verbose=1,
+        make_pruning_permanent=True
+    )
+    ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
+    trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0)
+    with caplog.at_level(INFO):
+        trainer.fit(model)
+
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
+        "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
+        "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
+        "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
+    ]
+
+    # removed on_train_end
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
+
+    model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
+    model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
+    assert not hasattr(model.layer.mlp_3, "weight_orig")

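As a user-level sanity check of the behavior the new test locks in, a checkpoint written during training should contain no pruning buffers even though the in-memory model keeps them. A hypothetical snippet (the checkpoint path is illustrative):

import torch

ckpt = torch.load("example.ckpt", map_location="cpu")  # path is illustrative
pruning_keys = [k for k in ckpt["state_dict"] if k.endswith(("weight_orig", "weight_mask"))]
assert not pruning_keys  # weights in the checkpoint are permanently pruned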