
[Feature] Add connector components and FitNet #207

Merged: 6 commits, Jul 28, 2022

Changes from 1 commit
48 changes: 48 additions & 0 deletions configs/distill/mmcls/fitnet/README.md
@@ -0,0 +1,48 @@
# FitNets

> [FitNets: Hints for Thin Deep Nets](https://arxiv.org/abs/1412.6550)
<!-- [ALGORITHM] -->

## Abstract

While depth tends to improve network performances, it also makes gradient-based
training more difficult since deeper networks tend to be more non-linear. The recently
proposed knowledge distillation approach is aimed at obtaining small and fast-to-execute
models, and it has shown that a student network could imitate the soft output of a larger
teacher network or ensemble of networks. In this paper, we extend this idea to allow the
training of a student that is deeper and thinner than the teacher, using not only the outputs
but also the intermediate representations learned by the teacher as hints to improve the
training process and final performance of the student. Because the student intermediate hidden
layer will generally be smaller than the teacher's intermediate hidden layer, additional parameters
are introduced to map the student hidden layer to the prediction of the teacher hidden layer. This
allows one to train deeper students that can generalize better or run faster, a trade-off that is
controlled by the chosen student capacity. For example, on CIFAR-10, a deep student network with
almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher network.

![pipeline](/docs/en/imgs/model_zoo/fitnet/pipeline.png)
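
The hint training described above needs only two ingredients: a pair of intermediate features and a small regressor (the "connector" added in this PR) that maps the thinner student feature onto the teacher's channel dimension before an L2 loss is applied. The sketch below is purely illustrative; the tensor shapes and variable names are assumptions, not the mmrazor implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed shapes: a student feature with 256 channels vs. a teacher "hint"
# feature with 1024 channels at the same spatial resolution.
student_feat = torch.randn(8, 256, 14, 14, requires_grad=True)  # guided layer
teacher_feat = torch.randn(8, 1024, 14, 14)                     # hint layer (fixed)

# The regressor / connector aligns the channel counts so the features can be compared.
regressor = nn.Conv2d(256, 1024, kernel_size=1)

hint_loss = F.mse_loss(regressor(student_feat), teacher_feat)
hint_loss.backward()  # gradients reach the student feature and the regressor only
```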

## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) |

## Citation

```latex
@inproceedings{DBLP:journals/corr/RomeroBKCGB14,
author = {Adriana Romero and Nicolas Ballas and Samira Ebrahimi Kahou and Antoine Chassang and Carlo Gatta and Yoshua Bengio},
editor = {Yoshua Bengio and Yann LeCun},
title = {FitNets: Hints for Thin Deep Nets},
booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015,
San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings},
year = {2015},
url = {http://arxiv.org/abs/1412.6550},
timestamp = {Thu, 25 Jul 2019 14:25:38 +0200},
biburl = {https://dblp.org/rec/journals/corr/RomeroBKCGB14.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
59 changes: 59 additions & 0 deletions configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py
@@ -0,0 +1,59 @@
_base_ = [
'mmcls::_base_/datasets/imagenet_bs32.py',
'mmcls::_base_/schedules/imagenet_bs256.py',
'mmcls::_base_/default_runtime.py'
]

model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth',
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
feat_4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
feat_3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu'),
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
feat_4=dict(type='ModuleOutputs', source='backbone.layer4.2.relu'),
feat_3=dict(type='ModuleOutputs', source='backbone.layer3.5.relu'),
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_f4=dict(type='L2Loss', loss_weight=10),
loss_f3=dict(type='L2Loss', loss_weight=10),
loss_kl=dict(
type='KLDivergence', tau=6, loss_weight=10, reduction='mean')),
student_connectors=dict(
loss_f4=dict(
type='ReLUConnector', in_channel=512, out_channel=2048),
loss_f3=dict(
type='ReLUConnector', in_channel=256, out_channel=1024)),
loss_forward_mappings=dict(
loss_f4=dict(
s_feature=dict(
from_student=True, recorder='feat_4', record_idx=1),
t_feature=dict(
from_student=False, recorder='feat_4', record_idx=2)),
loss_f3=dict(
s_feature=dict(
from_student=True, recorder='feat_3', record_idx=1),
t_feature=dict(
from_student=False, recorder='feat_3', record_idx=2)),
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
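
To make the wiring above easier to follow: the ModuleOutputs recorders capture intermediate features from both networks, the ReLUConnector registered under loss_f4 lifts the student feature from 512 to 2048 channels, and loss_forward_mappings routes the recorded pair into the distill loss. The snippet below is only a rough, framework-free illustration of the loss_f4 branch; it approximates mmrazor's L2Loss with an MSE and uses made-up tensors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-ins for the recorded module outputs named `feat_4` above.
s_feat4 = torch.randn(2, 512, 7, 7)    # student backbone.layer4 output (ResNet-18)
t_feat4 = torch.randn(2, 2048, 7, 7)   # teacher backbone.layer4 output (ResNet-50)

# Rough equivalent of the ReLUConnector configured for loss_f4 (1x1 conv + BN + ReLU).
connector = nn.Sequential(
    nn.Conv2d(512, 2048, kernel_size=1),
    nn.BatchNorm2d(2048),
    nn.ReLU(inplace=True))

# loss_forward_mappings feeds (s_feature, t_feature) into the loss; loss_weight=10
# corresponds to the scalar factor below.
loss_f4 = 10 * F.mse_loss(connector(s_feat4), t_feat4)
```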
Binary file added docs/en/imgs/model_zoo/fitnet/pipeline.png
1 change: 1 addition & 0 deletions mmrazor/models/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .algorithms import * # noqa: F401,F403
from .architectures import * # noqa: F401,F403
from .connectors import * # noqa: F401,F403
from .distillers import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .mutables import * # noqa: F401,F403
@@ -5,6 +5,7 @@
from mmcv.runner import load_checkpoint
from mmengine import BaseDataElement
from mmengine.model import BaseModel
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.utils import add_prefix
@@ -26,6 +27,8 @@ class SingleTeacherDistill(BaseAlgorithm):
teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
mode, namely, freeze running stats (mean and var). Note: Effect on
Batch Norm and its variants only. Defaults to True.
student_trainable (bool): Whether the student is trainable. Defaults
to True.
"""

def __init__(self,
@@ -34,7 +37,8 @@ def __init__(self,
teacher_ckpt: Optional[str] = None,
teacher_trainable: bool = False,
teacher_norm_eval: bool = True,
-                 **kwargs):
+                 student_trainable: bool = True,
+                 **kwargs) -> None:
super().__init__(**kwargs)

self.distiller = MODELS.build(distiller)
@@ -55,13 +59,17 @@ def __init__(self,
self.teacher_trainable = teacher_trainable
self.teacher_norm_eval = teacher_norm_eval

        # The student model will not calculate gradients or update parameters
        # during some pretraining stages.
self.student_trainable = student_trainable

        # In ``ConfigurableDistiller``, the recorder manager is just
# constructed, but not really initialized yet.
self.distiller.prepare_from_student(self.student)
self.distiller.prepare_from_teacher(self.teacher)

@property
-    def student(self):
+    def student(self) -> nn.Module:
"""Alias for ``architecture``."""
return self.architecture

@@ -86,16 +94,25 @@ def loss(
else:
with self.distiller.teacher_recorders, self.distiller.deliveries:
with torch.no_grad():

_ = self.teacher(batch_inputs, data_samples, mode='loss')

# If the `override_data` of a delivery is True, the delivery will
# override the origin data with the recorded data.
self.distiller.set_deliveries_override(True)
-        with self.distiller.student_recorders, self.distiller.deliveries:
-            student_losses = self.student(
-                batch_inputs, data_samples, mode='loss')
-        losses.update(add_prefix(student_losses, 'student'))
+        # The original task loss is not used during some pretraining stages.
+        if self.distiller.calculate_student_loss:
+            with self.distiller.student_recorders, self.distiller.deliveries:
+                student_losses = self.student(
+                    batch_inputs, data_samples, mode='loss')
+            losses.update(add_prefix(student_losses, 'student'))
+        else:
+            with self.distiller.student_recorders, self.distiller.deliveries:
+                if self.student_trainable:
+                    _ = self.student(batch_inputs, data_samples, mode='loss')
+                else:
+                    with torch.no_grad():
+                        _ = self.student(
+                            batch_inputs, data_samples, mode='loss')

# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
@@ -104,7 +121,7 @@

return losses

-    def train(self, mode=True):
+    def train(self, mode: bool = True) -> None:
"""Set distiller's forward mode."""
super().train(mode)
if mode and self.teacher_norm_eval:
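
The new student_trainable flag, combined with the distiller's calculate_student_loss option added in base_distiller.py below, enables connector-only pre-training. The fragment below is a hypothetical sketch of such a setup, assuming ConfigurableDistiller forwards calculate_student_loss to BaseDistiller; it is not part of this PR.

```python
# Hypothetical config fragment: freeze the student and skip its original task
# loss, e.g. for a FitNets-style stage-1 "hint" pre-training of the connectors.
model = dict(
    type='SingleTeacherDistill',
    student_trainable=False,            # student forward runs under torch.no_grad()
    distiller=dict(
        type='ConfigurableDistiller',
        calculate_student_loss=False))  # only the distill losses produce gradients
```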
4 changes: 4 additions & 0 deletions mmrazor/models/connectors/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .general_connector import BNConnector, ReLUConnector, SingleConvConnector

__all__ = ['BNConnector', 'ReLUConnector', 'SingleConvConnector']
37 changes: 37 additions & 0 deletions mmrazor/models/connectors/base_connector.py
@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn


class BaseConnector(nn.Module, metaclass=ABCMeta):
Review comment (Contributor): It's better to inherit from BaseModule? Then, self.init_parameters() could be rewritten by init_weights() and init_cfg.

"""Base class of connectors.

Connector is mainly used for distill, it usually converts the channel
Review comment (Contributor): distill -> distillation

number of input feature to align features of student and teacher.

All subclasses should implement the following APIs:

- ``forward_train()``
"""

def __init__(self) -> None:
super().__init__()

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
"""Forward computation.

Args:
feature (torch.Tensor): Input feature.
"""
return self.forward_train(feature)

@abstractmethod
    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
"""Abstract train computation.

Args:
feature (torch.Tensor): Input feature.
"""
pass
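
As a rough sketch of the review suggestion above: inheriting from mmengine's BaseModule would let weight initialization be driven by init_cfg and init_weights() instead of a hand-written init_parameters(). The class below is an assumption about what such a refactor could look like, not code from this PR.

```python
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Optional, Union

import torch
from mmengine.model import BaseModule


class BaseConnector(BaseModule, metaclass=ABCMeta):
    """Sketch: a connector base class whose weight init is handled via ``init_cfg``."""

    def __init__(self,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Run the training-time transform by default."""
        return self.forward_train(feature)

    @abstractmethod
    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Subclasses implement the actual feature transform."""
```

A subclass could then request, for example, init_cfg=dict(type='Kaiming', layer='Conv2d') and rely on init_weights() to apply it.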
110 changes: 110 additions & 0 deletions mmrazor/models/connectors/general_connector.py
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


@MODELS.register_module()
class SingleConvConnector(BaseConnector):
"""General connector which only contains a conv layer.

Args:
in_channel (int): The input channel of the connector.
out_channel (int): The output channel of the connector.
"""

def __init__(
self,
in_channel: int,
out_channel: int,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1)
Review comment (Contributor): Use build_conv_layer to build a conv layer.

self.init_parameters()

def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
"""Forward computation.

Args:
feature (torch.Tensor): Input feature.
"""
return self.conv(feature)

def init_parameters(self) -> None:
"""Init parameters."""
with torch.no_grad():
for m in self.modules():
if isinstance(m, nn.Conv2d):
device = m.weight.device
                    # Conv2d weights are (out_channels, in_channels, kH, kW).
                    out_channels, _, k1, k2 = m.weight.shape
                    m.weight[:] = torch.randn(
                        m.weight.shape, device=device) / np.sqrt(
                            k1 * k2 * out_channels) * 1e-4
                    if hasattr(m, 'bias') and m.bias is not None:
                        nn.init.zeros_(m.bias)


@MODELS.register_module()
class BNConnector(BaseConnector):
Review comment (Contributor): -> ConvBNConnector

"""General connector which contains a conv layer with BN.

Args:
in_channel (int): The input channels of the connector.
out_channel (int): The output channels of the connector.
"""

def __init__(
self,
in_channel: int,
out_channel: int,
) -> None:
super().__init__()
self.conv = nn.Conv2d(
Review comment (Contributor): Use build_conv_layer to build a conv layer and build_norm_layer to build a norm layer.

in_channel,
out_channel,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.bn = nn.BatchNorm2d(out_channel)

def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
"""Forward computation.

Args:
feature (torch.Tensor): Input feature.
"""
return self.bn(self.conv(feature))


@MODELS.register_module()
class ReLUConnector(BaseConnector):
Review comment (Contributor): -> ConvBNReLUConnector?

"""General connector which contains a conv layer with BN and ReLU.

Args:
in_channel (int): The input channels of the connector.
out_channel (int): The output channels of the connector.
"""

def __init__(
self,
in_channel: int,
out_channel: int,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)
Review comment (Contributor): The same (use build_conv_layer / build_norm_layer here as well).

self.bn = nn.BatchNorm2d(out_channel)
self.relu = nn.ReLU(inplace=True)

def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
"""Forward computation.

Args:
feature (torch.Tensor): Input feature.
"""
return self.relu(self.bn(self.conv(feature)))
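
Following the review comments in this file, the conv/BN pairs could be assembled with mmcv's layer builders so the connectors stay configurable, and the class could take the suggested ConvBNConnector name. The sketch below shows one possible shape of that change; it is an assumption, not the code merged here.

```python
from typing import Dict, Optional

import torch
from mmcv.cnn import build_conv_layer, build_norm_layer

from .base_connector import BaseConnector


class ConvBNConnector(BaseConnector):
    """Sketch: 1x1 conv followed by a norm layer, both built from cfg dicts."""

    def __init__(self,
                 in_channel: int,
                 out_channel: int,
                 conv_cfg: Optional[Dict] = None,
                 norm_cfg: Dict = dict(type='BN')) -> None:
        super().__init__()
        # build_conv_layer falls back to nn.Conv2d when conv_cfg is None.
        self.conv = build_conv_layer(
            conv_cfg, in_channel, out_channel, kernel_size=1, bias=False)
        # build_norm_layer returns a (name, module) tuple; keep the module.
        _, self.bn = build_norm_layer(norm_cfg, out_channel)

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        return self.bn(self.conv(feature))
```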
14 changes: 12 additions & 2 deletions mmrazor/models/distillers/base_distiller.py
@@ -1,16 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Dict, Optional

from mmengine.model import BaseModule

from ..algorithms.base import LossResults


class BaseDistiller(BaseModule, ABC):
"""Base class for distiller."""
"""Base class for distiller.

def __init__(self, init_cfg=None):
Args:
calculate_student_loss (bool): Whether to calculate student loss
(original task loss) to update student model. Defaults to True.
init_cfg (dict, optional): Config for distiller. Default to None.
"""

def __init__(self,
calculate_student_loss: bool = True,
init_cfg: Optional[Dict] = None):
super().__init__(init_cfg)
self.calculate_student_loss = calculate_student_loss

@abstractmethod
def compute_distill_losses(self) -> LossResults: