[feature] Support for specifying epochs to stop knowledge distillation #455

Merged
merged 11 commits
Mar 1, 2023
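
The change adds a DistillationLossDetachHook that flips a distill_loss_detach flag on the distillation algorithm at a chosen epoch; from that point on the distillation loss is skipped and only the student's own losses are optimized. As a rough usage sketch (not part of the diff), the hook could be enabled through MMEngine's standard custom_hooks mechanism; the epoch value below is illustrative:

# Illustrative config snippet: detach the distillation loss from epoch 100 on.
custom_hooks = [
    dict(type='DistillationLossDetachHook', detach_epoch=100)
]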
5 changes: 3 additions & 2 deletions mmrazor/engine/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import DumpSubnetHook, EstimateResourcesHook
from .hooks import (DistillationLossDetachHook, DumpSubnetHook,
EstimateResourcesHook)
from .optimizers import SeparateOptimWrapperConstructor
from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
@@ -12,5 +13,5 @@
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop',
'AutoSlimGreedySearchLoop', 'SubnetValLoop'
'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'DistillationLossDetachHook'
]
6 changes: 5 additions & 1 deletion mmrazor/engine/hooks/__init__.py
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .distillation_loss_detach_hook import DistillationLossDetachHook
from .dump_subnet_hook import DumpSubnetHook
from .estimate_resources_hook import EstimateResourcesHook
from .visualization_hook import RazorVisualizationHook

__all__ = ['DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook']
__all__ = [
'DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook',
'DistillationLossDetachHook'
]
25 changes: 25 additions & 0 deletions mmrazor/engine/hooks/distillation_loss_detach_hook.py
@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmrazor.registry import HOOKS


@HOOKS.register_module()
class DistillationLossDetachHook(Hook):
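    """Detach the distillation loss at a given epoch.

    From ``detach_epoch`` on, the algorithm's ``distill_loss_detach`` flag is
    set to True, so the distillation loss is no longer computed.

    Args:
        detach_epoch (int): Epoch at which knowledge distillation is stopped.
    """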

priority = 'LOW'

def __init__(self, detach_epoch) -> None:
self.detach_epoch = detach_epoch

def before_train_epoch(self, runner) -> None:
if runner.epoch == self.detach_epoch:
model = runner.model
# TODO: refactor after mmengine using model wrapper
if is_model_wrapper(model):
model = model.module
assert hasattr(model, 'distill_loss_detach')

            runner.logger.info('Distillation stops from this epoch.')
model.distill_loss_detach = True
@@ -49,9 +49,11 @@ def loss(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))

# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))
if not self.distill_loss_detach:
# Automatically compute distill losses based on
# `loss_forward_mappings`.
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))

return losses
@@ -89,6 +89,9 @@ def __init__(self,
self.distiller.prepare_from_student(self.student)
self.distiller.prepare_from_teacher(self.teacher)

        # May be set to True by DistillationLossDetachHook to stop distillation.
self.distill_loss_detach = False

@property
def student(self) -> nn.Module:
"""Alias for ``architecture``."""
@@ -135,10 +138,12 @@ def loss(
_ = self.student(
batch_inputs, data_samples, mode='loss')

# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))
if not self.distill_loss_detach:
# Automatically compute distill losses based on
# `loss_forward_mappings`.
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))

return losses

4 changes: 3 additions & 1 deletion mmrazor/models/architectures/connectors/__init__.py
@@ -5,11 +5,13 @@
from .factor_transfer_connectors import Paraphraser, Translator
from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector
from .mgd_connector import MGDConnector
from .norm_connector import NormConnector
from .ofd_connector import OFDTeacherConnector
from .torch_connector import TorchFunctionalConnector, TorchNNConnector

__all__ = [
'ConvModuleConnector', 'Translator', 'Paraphraser', 'BYOTConnector',
'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector',
'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector'
'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector',
'NormConnector'
]
19 changes: 19 additions & 0 deletions mmrazor/models/architectures/connectors/norm_connector.py
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

import torch
from mmcv.cnn import build_norm_layer

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


@MODELS.register_module()
class NormConnector(BaseConnector):
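    """Connector that applies a normalization layer (built from ``norm_cfg``
    with mmcv's ``build_norm_layer``) to the input feature.

    Args:
        in_channels (int): Number of channels of the input feature.
        norm_cfg (dict): Config of the normalization layer.
        init_cfg (dict, optional): Initialization config. Defaults to None.
    """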

def __init__(self, in_channels, norm_cfg, init_cfg: Optional[Dict] = None):
super(NormConnector, self).__init__(init_cfg)
_, self.norm = build_norm_layer(norm_cfg, in_channels)

def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
return self.norm(feature)
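
A quick usage sketch for the new connector (channel size and norm settings are illustrative; the unit test added further down exercises the same API):

import torch
from mmrazor.models import NormConnector

# Normalize a feature map before it is fed to a distillation loss.
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)
connector = NormConnector(in_channels=16, norm_cfg=norm_cfg)
feat = torch.randn(2, 16, 8, 8)
out = connector.forward_train(feat)  # same shape as the input: (2, 16, 8, 8)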
17 changes: 13 additions & 4 deletions mmrazor/models/distillers/configurable_distiller.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from inspect import signature
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from mmengine.model import BaseModel
from torch import nn
@@ -139,15 +139,24 @@ def prepare_from_teacher(self, model: nn.Module) -> None:

def build_connectors(
self,
connectors: Optional[Dict[str, Dict]] = None,
connectors: Optional[Union[Dict[str, List], Dict[str, Dict]]] = None,
) -> nn.ModuleDict:
"""Initialize connectors."""

distill_connecotrs = nn.ModuleDict()
if connectors:
for connector_name, connector_cfg in connectors.items():
connector = MODELS.build(connector_cfg)
distill_connecotrs[connector_name] = connector
if isinstance(connector_cfg, dict):
connector = MODELS.build(connector_cfg)
distill_connecotrs[connector_name] = connector
else:
assert isinstance(connector_cfg, list)
module_list = []
for cfg in connector_cfg:
connector = MODELS.build(cfg)
module_list.append(connector)
distill_connecotrs[connector_name] = nn.Sequential(
*module_list)

return distill_connecotrs
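
With this change a connectors entry may also be a list of connector configs, which build_connectors chains into a single nn.Sequential. An illustrative config fragment (connector name and channel sizes are made up; the unit test at the bottom of this diff uses the same pattern):

connectors = dict(
    loss_s4_sfeat=[
        dict(
            type='ConvModuleConnector',
            in_channel=3,
            out_channel=4,
            act_cfg=None),
        dict(type='NormConnector', in_channels=4, norm_cfg=dict(type='BN'))
    ])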

6 changes: 1 addition & 5 deletions mmrazor/models/losses/cwd.py
@@ -17,11 +17,7 @@ class ChannelWiseDivergence(nn.Module):
loss_weight (float): Weight of loss. Defaults to 1.0.
"""

def __init__(
self,
tau=1.0,
loss_weight=1.0,
):
def __init__(self, tau=1.0, loss_weight=1.0):
super(ChannelWiseDivergence, self).__init__()
self.tau = tau
self.loss_weight = loss_weight
27 changes: 27 additions & 0 deletions tests/test_engine/test_hooks/test_distillation_loss_detach_hook.py
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock

from mmrazor.engine import DistillationLossDetachHook


class TestDistillationLossDetachHook(TestCase):

def setUp(self):
self.hook = DistillationLossDetachHook(detach_epoch=5)
runner = Mock()
runner.model = Mock()
runner.model.distill_loss_detach = False

runner.epoch = 0
# runner.max_epochs = 10
self.runner = runner

def test_before_train_epoch(self):
max_epochs = 10
target = [False] * 5 + [True] * 5
for epoch in range(max_epochs):
self.hook.before_train_epoch(self.runner)
            self.assertEqual(self.runner.model.distill_loss_detach,
                             target[epoch])
self.runner.epoch += 1
@@ -5,7 +5,7 @@

from mmrazor.models import (BYOTConnector, ConvModuleConnector, CRDConnector,
FBKDStudentConnector, FBKDTeacherConnector,
MGDConnector, Paraphraser,
MGDConnector, NormConnector, Paraphraser,
TorchFunctionalConnector, TorchNNConnector,
Translator)

@@ -143,3 +143,11 @@ def test_mgd_connector(self):

assert s_output1.shape == torch.Size([1, 16, 8, 8])
assert s_output2.shape == torch.Size([1, 32, 8, 8])

def test_norm_connector(self):
s_feat = torch.randn(2, 3, 2, 2)
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)
norm_connector = NormConnector(3, norm_cfg)
output = norm_connector.forward_train(s_feat)

assert output.shape == torch.Size([2, 3, 2, 2])
48 changes: 48 additions & 0 deletions tests/test_models/test_distillers/test_configurable_distill.py
@@ -2,13 +2,31 @@
import copy
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine import ConfigDict

from mmrazor.models import ConfigurableDistiller
from mmrazor.registry import MODELS


class ToyDistillLoss(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, arg1, arg2):
return arg1 + arg2


class TestConfigurableDistiller(TestCase):

def setUp(self):
MODELS.register_module(module=ToyDistillLoss, force=True)

def tearDown(self):
MODELS.module_dict.pop('ToyDistillLoss')

def test_init(self):

recorders_cfg = ConfigDict(
@@ -65,3 +83,33 @@ def test_init(self):
with self.assertRaisesRegex(TypeError,
'from_student should be a bool'):
_ = ConfigurableDistiller(**distiller_kwargs_)

def test_connector_list(self):
recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

distiller_kwargs = ConfigDict(
student_recorders=recorders_cfg,
teacher_recorders=recorders_cfg,
distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_toy=dict(
arg1=dict(
from_student=True,
recorder='conv',
connector='loss_1_sfeat'),
arg2=dict(from_student=False, recorder='conv'),
)),
connectors=dict(loss_1_sfeat=[
dict(
type='ConvModuleConnector',
in_channel=3,
out_channel=4,
act_cfg=None),
dict(type='NormConnector', norm_cfg=norm_cfg, in_channels=4)
]))

distiller = ConfigurableDistiller(**distiller_kwargs)
connectors = distiller.connectors
self.assertIsInstance(connectors['loss_1_sfeat'], nn.Sequential)