Merge pull request #207 from wilxy/dev-1.x
[Feature] Add connector components and FitNet
sunnyxiaohu authored Jul 28, 2022
2 parents 6987511 + d75d120 commit 63d46b5
Showing 14 changed files with 530 additions and 24 deletions.
48 changes: 48 additions & 0 deletions configs/distill/mmcls/fitnet/README.md
@@ -0,0 +1,48 @@
# FitNets

> [FitNets: Hints for Thin Deep Nets](https://arxiv.org/abs/1412.6550)
<!-- [ALGORITHM] -->

## Abstract

While depth tends to improve network performance, it also makes gradient-based
training more difficult since deeper networks tend to be more non-linear. The recently
proposed knowledge distillation approach is aimed at obtaining small and fast-to-execute
models, and it has shown that a student network could imitate the soft output of a larger
teacher network or ensemble of networks. In this paper, we extend this idea to allow the
training of a student that is deeper and thinner than the teacher, using not only the outputs
but also the intermediate representations learned by the teacher as hints to improve the
training process and final performance of the student. Because the student intermediate hidden
layer will generally be smaller than the teacher's intermediate hidden layer, additional parameters
are introduced to map the student hidden layer to the prediction of the teacher hidden layer. This
allows one to train deeper students that can generalize better or run faster, a trade-off that is
controlled by the chosen student capacity. For example, on CIFAR-10, a deep student network with
almost 10.4 times fewer parameters outperforms a larger, state-of-the-art teacher network.

![pipeline](/docs/en/imgs/model_zoo/fitnet/pipeline.png)
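
The hint-based training described in the abstract can be summarized in a few lines of PyTorch. The sketch below is illustrative only: the tensor shapes and the 1x1-convolution regressor are assumptions made for the example, not the exact modules used in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative shapes: a thin student feature (256 channels) and the
# corresponding teacher hint (1024 channels) at the same spatial size.
student_feat = torch.randn(2, 256, 14, 14)
teacher_feat = torch.randn(2, 1024, 14, 14)

# The "additional parameters" mentioned in the abstract: a regressor that maps
# the student's hidden layer into the space of the teacher's hidden layer.
regressor = nn.Conv2d(256, 1024, kernel_size=1)

# Hint loss: L2 distance between the mapped student feature and the teacher hint.
hint_loss = F.mse_loss(regressor(student_feat), teacher_feat.detach())
hint_loss.backward()  # gradients reach the regressor (and the student network
                      # when the feature comes from a real student forward pass)
```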

## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) |

## Citation

```latex
@inproceedings{DBLP:journals/corr/RomeroBKCGB14,
author = {Adriana Romero and Nicolas Ballas and Samira Ebrahimi Kahou and Antoine Chassang and Carlo Gatta and Yoshua Bengio},
editor = {Yoshua Bengio and Yann LeCun},
title = {FitNets: Hints for Thin Deep Nets},
booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015,
San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings},
year = {2015},
url = {http://arxiv.org/abs/1412.6550},
timestamp = {Thu, 25 Jul 2019 14:25:38 +0200},
biburl = {https://dblp.org/rec/journals/corr/RomeroBKCGB14.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
71 changes: 71 additions & 0 deletions configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py
@@ -0,0 +1,71 @@
_base_ = [
'mmcls::_base_/datasets/imagenet_bs32.py',
'mmcls::_base_/schedules/imagenet_bs256.py',
'mmcls::_base_/default_runtime.py'
]

model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth',
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu'),
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.2.relu'),
bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.5.relu'),
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_s4=dict(type='L2Loss', loss_weight=10),
loss_s3=dict(type='L2Loss', loss_weight=10),
loss_kl=dict(
type='KLDivergence', tau=6, loss_weight=10, reduction='mean')),
connectors=dict(
loss_s4_sfeat=dict(
type='ConvBNReLUConnector',
in_channel=512,
out_channel=2048,
norm_cfg=dict(type='BN')),
loss_s3_sfeat=dict(
type='ConvBNReLUConnector',
in_channel=256,
out_channel=1024,
norm_cfg=dict(type='BN'))),
loss_forward_mappings=dict(
loss_s4=dict(
s_feature=dict(
from_student=True,
recorder='bb_s4',
record_idx=1,
connector='loss_s4_sfeat'),
t_feature=dict(
from_student=False, recorder='bb_s4', record_idx=2)),
loss_s3=dict(
s_feature=dict(
from_student=True,
recorder='bb_s3',
record_idx=1,
connector='loss_s3_sfeat'),
t_feature=dict(
from_student=False, recorder='bb_s3', record_idx=2)),
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
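
Each entry in `loss_forward_mappings` above tells the `ConfigurableDistiller` where to fetch every argument of a distill loss: from the student or teacher recorders, optionally indexed by `record_idx` when the recorded module is called more than once, and optionally passed through a connector to align channel numbers. The function below is a simplified sketch of that resolution, not the actual mmrazor implementation; the helper name and the recorder method are assumptions.

```python
# Simplified sketch: resolve one mapping (e.g. ``loss_s4``) into a loss value.
def compute_one_distill_loss(loss_module, mapping, student_recorders,
                             teacher_recorders, connectors):
    kwargs = {}
    for arg_name, src in mapping.items():
        recorders = student_recorders if src['from_student'] else teacher_recorders
        # Fetch the tensor recorded from the hooked module output, e.g. ``bb_s4``.
        feature = recorders[src['recorder']].get_record_data(src.get('record_idx', 0))
        # Connectors align student channels with teacher channels (e.g. 512 -> 2048).
        if 'connector' in src:
            feature = connectors[src['connector']](feature)
        kwargs[arg_name] = feature
    # e.g. L2Loss(loss_weight=10)(s_feature=..., t_feature=...)
    return loss_module(**kwargs)
```
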
Binary file added docs/en/imgs/model_zoo/fitnet/pipeline.png
mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
@@ -5,6 +5,7 @@
from mmcv.runner import load_checkpoint
from mmengine import BaseDataElement
from mmengine.model import BaseModel
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.utils import add_prefix
@@ -18,6 +19,7 @@ class SingleTeacherDistill(BaseAlgorithm):
only use one teacher.
Args:
distiller (dict): The config dict for built distiller.
teacher (dict | BaseModel): The config dict for teacher model or built
teacher model.
teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None.
@@ -26,6 +28,10 @@ class SingleTeacherDistill(BaseAlgorithm):
teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
mode, namely, freeze running stats (mean and var). Note: Effect on
Batch Norm and its variants only. Defaults to True.
student_trainable (bool): Whether the student is trainable. Defaults
to True.
calculate_student_loss (bool): Whether to calculate student loss
(original task loss) to update student model. Defaults to True.
"""

def __init__(self,
@@ -34,7 +40,9 @@ def __init__(self,
teacher_ckpt: Optional[str] = None,
teacher_trainable: bool = False,
teacher_norm_eval: bool = True,
**kwargs):
student_trainable: bool = True,
calculate_student_loss: bool = True,
**kwargs) -> None:
super().__init__(**kwargs)

self.distiller = MODELS.build(distiller)
@@ -55,13 +63,21 @@ def __init__(self,
self.teacher_trainable = teacher_trainable
self.teacher_norm_eval = teacher_norm_eval

# In some pretraining processes, the student model does not calculate
# gradients or update its parameters.
self.student_trainable = student_trainable

# In some pretraining processes, the student loss is not added to
# ``losses``.
self.calculate_student_loss = calculate_student_loss

# In ``ConfigurableDistiller``, the recorder manager is just
# constructed, but not really initialized yet.
self.distiller.prepare_from_student(self.student)
self.distiller.prepare_from_teacher(self.teacher)

@property
def student(self):
def student(self) -> nn.Module:
"""Alias for ``architecture``."""
return self.architecture

@@ -86,16 +102,25 @@ def loss(
else:
with self.distiller.teacher_recorders, self.distiller.deliveries:
with torch.no_grad():

_ = self.teacher(batch_inputs, data_samples, mode='loss')

# If the `override_data` of a delivery is True, the delivery will
# override the original data with the recorded data.
self.distiller.set_deliveries_override(True)
with self.distiller.student_recorders, self.distiller.deliveries:
student_losses = self.student(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))
# The original task loss is not used during some pretraining processes.
if self.calculate_student_loss:
with self.distiller.student_recorders, self.distiller.deliveries:
student_losses = self.student(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))
else:
with self.distiller.student_recorders, self.distiller.deliveries:
if self.student_trainable:
_ = self.student(batch_inputs, data_samples, mode='loss')
else:
with torch.no_grad():
_ = self.student(
batch_inputs, data_samples, mode='loss')

# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
@@ -104,7 +129,7 @@ def loss(

return losses

def train(self, mode=True):
def train(self, mode: bool = True) -> None:
"""Set distiller's forward mode."""
super().train(mode)
if mode and self.teacher_norm_eval:
1 change: 1 addition & 0 deletions mmrazor/models/architectures/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .connectors import * # noqa: F401,F403
from .dynamic_op import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
5 changes: 5 additions & 0 deletions mmrazor/models/architectures/connectors/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .general_connector import (ConvBNConnector, ConvBNReLUConnector,
SingleConvConnector)

__all__ = ['ConvBNConnector', 'ConvBNReLUConnector', 'SingleConvConnector']
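
These connectors are registered in mmrazor's model registry, so they can be built directly from config dicts such as the `connectors` entries in the FitNet config above. A minimal sketch, assuming the standard `mmrazor.registry.MODELS` registry and a connector that keeps the spatial size unchanged:

```python
import torch

from mmrazor.registry import MODELS

# Mirrors the ``loss_s4_sfeat`` connector entry from the FitNet config.
connector = MODELS.build(
    dict(
        type='ConvBNReLUConnector',
        in_channel=512,
        out_channel=2048,
        norm_cfg=dict(type='BN')))

# Map a ResNet-18 stage-4 feature (512 channels) into ResNet-50's 2048-channel space.
aligned = connector(torch.randn(2, 512, 7, 7))
print(aligned.shape)  # expected: torch.Size([2, 2048, 7, 7]) for a 1x1-conv connector
```
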
41 changes: 41 additions & 0 deletions mmrazor/models/architectures/connectors/base_connector.py
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, Optional

import torch
from mmcv.runner import BaseModule


class BaseConnector(BaseModule, metaclass=ABCMeta):
"""Base class of connectors.
A connector is mainly used for distillation; it usually converts the channel
number of the input feature to align the features of the student and teacher.
All subclasses should implement the following APIs:
- ``forward_train()``
Args:
init_cfg (dict, optional): The config to control the initialization.
"""

def __init__(self, init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg=init_cfg)

def forward(self, feature: torch.Tensor) -> torch.Tensor:
"""Forward computation.
Args:
feature (torch.Tensor): Input feature.
"""
return self.forward_train(feature)

@abstractmethod
def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
"""Abstract train computation.
Args:
feature (torch.Tensor): Input feature.
"""
pass
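
A concrete connector only has to implement `forward_train()`. The class below is an illustrative sketch in the spirit of the connectors added by this PR; the real `ConvBNReLUConnector` lives in `general_connector.py` (not shown in this view) and may differ in detail.

```python
import torch
import torch.nn as nn

from mmrazor.models.architectures.connectors.base_connector import BaseConnector


class TinyConvConnector(BaseConnector):
    """Illustrative connector: 1x1 conv + BN + ReLU to align channel numbers."""

    def __init__(self, in_channel: int, out_channel: int, init_cfg=None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(feature)))


# ``BaseConnector.forward`` delegates to ``forward_train``.
out = TinyConvConnector(256, 1024)(torch.randn(2, 256, 14, 14))
```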