Fix finetuning so that complex models unfreeze correctly. (#6880)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
scart97 and carmocca authored Apr 8, 2021
1 parent 968ac09 commit eb15abc
Showing 3 changed files with 44 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -212,6 +212,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
 
 
+- Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879))
+
+
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
 
 
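For context on the `BaseFinetuning.flatten_modules()` entry above: a minimal standalone sketch (not part of this commit) of how filtering out only the built-in container types lets a custom wrapper module slip through, so each of its parameters is reachable both through the wrapper and through a child module. The `Block` class is illustrative only.

```python
from torch import nn

# A custom wrapper module, similar to the ConvBlock used in the new test below.
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)

block = Block()

# Filtering out only the built-in container types (roughly what the old code
# did) keeps the custom Block itself alongside its children ...
flattened = [
    m for m in block.modules()
    if not isinstance(m, (nn.Sequential, nn.ModuleList, nn.ModuleDict))
]
assert [type(m).__name__ for m in flattened] == ["Block", "Conv2d", "BatchNorm2d"]

# ... so the 4 real parameters are counted 8 times when collected per returned module.
assert len(list(block.parameters())) == 4
assert sum(len(list(m.parameters())) for m in flattened) == 8
```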
8 changes: 2 additions & 6 deletions pytorch_lightning/callbacks/finetuning.py
@@ -22,7 +22,6 @@
 import torch
 from torch.nn import Module
 from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.modules.container import Container, ModuleDict, ModuleList, Sequential
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning.callbacks.base import Callback
@@ -102,11 +101,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         else:
             _modules = modules.modules()
 
-        return list(
-            filter(
-                lambda m: not isinstance(m, (Container, Sequential, ModuleDict, ModuleList, LightningModule)), _modules
-            )
-        )
+        # Leaf nodes in the graph have no children, so we use that to filter
+        return [m for m in _modules if not list(m.children())]
 
     @staticmethod
     def filter_params(
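A standalone sketch (not part of the patch) of what the new leaf-node check returns for a nested custom module; the `Block` class and the model layout here are illustrative.

```python
from torch import nn

# Illustrative nested model: a custom block wrapped in nn.Sequential.
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)

model = nn.Sequential(Block(), nn.ReLU())

# The new check keeps only leaf modules (modules with no children), so neither
# the nn.Sequential wrapper nor the custom Block is returned.
leaves = [m for m in model.modules() if not list(m.children())]
assert [type(m).__name__ for m in leaves] == ["Conv2d", "BatchNorm2d", "ReLU"]

# Each parameter now belongs to exactly one returned module, which lets
# freeze(..., train_bn=True) skip the BatchNorm leaf on its own.
assert sum(len(list(m.parameters())) for m in leaves) == len(list(model.parameters()))
```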
39 changes: 39 additions & 0 deletions tests/callbacks/test_finetuning_callback.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
+
 import pytest
 import torch
 from torch import nn
@@ -244,3 +246,40 @@ def configure_optimizers(self):
 
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
     trainer.fit(model)
+
+
+def test_deep_nested_model():
+
+    class ConvBlock(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    model = nn.Sequential(
+        OrderedDict([
+            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
+            ("decoder", ConvBlock(128, 10)),
+        ])
+    )
+
+    # There are 9 leaf layers in that model
+    assert len(BaseFinetuning.flatten_modules(model)) == 9
+
+    BaseFinetuning.freeze(model.encoder, train_bn=True)
+    assert not model.encoder[0].conv.weight.requires_grad
+    assert model.encoder[0].bn.weight.requires_grad
+
+    BaseFinetuning.make_trainable(model)
+    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
+    # The 8 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
+    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
+    assert len(encoder_params) == 8
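
For completeness, a sketch of how these `BaseFinetuning` helpers are typically wired into a user-defined callback; the class name, the `feature_extractor` attribute, and the unfreeze epoch are illustrative assumptions, not part of this commit.

```python
from pytorch_lightning.callbacks import BaseFinetuning

class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Illustrative callback: freeze a nested backbone at the start of training
    and unfreeze it (adding its parameters to the optimizer) at a given epoch."""

    def __init__(self, unfreeze_at_epoch: int = 10):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # With the leaf-based flatten_modules, this reaches every leaf layer
        # inside nested custom blocks while keeping BatchNorm trainable.
        self.freeze(pl_module.feature_extractor, train_bn=True)

    def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        if current_epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.feature_extractor,
                optimizer=optimizer,
                train_bn=True,
            )
```

The new test above exercises this kind of setup on a model whose encoder is built from nested custom blocks.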
