diff --git a/configs/distill/mmcls/fitnet/README.md b/configs/distill/mmcls/fitnet/README.md new file mode 100644 index 000000000..23cfe1d2d --- /dev/null +++ b/configs/distill/mmcls/fitnet/README.md @@ -0,0 +1,48 @@ +# FitNets + +> [FitNets: Hints for Thin Deep Nets](https://arxiv.org/abs/1412.6550) + + + +## Abstract + +While depth tends to improve network performances, it also makes gradient-based +training more difficult since deeper networks tend to be more non-linear. The recently +proposed knowledge distillation approach is aimed at obtaining small and fast-to-execute +models, and it has shown that a student network could imitate the soft output of a larger +teacher network or ensemble of networks. In this paper, we extend this idea to allow the +training of a student that is deeper and thinner than the teacher, using not only the outputs +but also the intermediate representations learned by the teacher as hints to improve the +training process and final performance of the student. Because the student intermediate hidden +layer will generally be smaller than the teacher's intermediate hidden layer, additional parameters +are introduced to map the student hidden layer to the prediction of the teacher hidden layer. This +allows one to train deeper students that can generalize better or run faster, a trade-off that is +controlled by the chosen student capacity. For example, on CIFAR-10, a deep student network with +almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher network. + +![pipeline](/docs/en/imgs/model_zoo/fitnet/pipeline.png) + +## Results and models + +### Classification + +| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | +| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- | +| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \| [model](<>) \| [log](<>) | + +## Citation + +```latex +@inproceedings{DBLP:journals/corr/RomeroBKCGB14, + author = {Adriana Romero and Nicolas Ballas and Samira Ebrahimi Kahou and Antoine Chassang and Carlo Gatta and Yoshua Bengio}, + editor = {Yoshua Bengio and Yann LeCun}, + title = {FitNets: Hints for Thin Deep Nets}, + booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015, + San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings}, + year = {2015}, + url = {http://arxiv.org/abs/1412.6550}, + timestamp = {Thu, 25 Jul 2019 14:25:38 +0200}, + biburl = {https://dblp.org/rec/journals/corr/RomeroBKCGB14.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` diff --git a/configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py
b/configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py new file mode 100644 index 000000000..b46300e73 --- /dev/null +++ b/configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py @@ -0,0 +1,71 @@ +_base_ = [ + 'mmcls::_base_/datasets/imagenet_bs32.py', + 'mmcls::_base_/schedules/imagenet_bs256.py', + 'mmcls::_base_/default_runtime.py' +] + +model = dict( + _scope_='mmrazor', + type='SingleTeacherDistill', + data_preprocessor=dict( + type='ImgDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + bgr_to_rgb=True), + architecture=dict( + cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), + teacher=dict( + cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True), + teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth', + distiller=dict( + type='ConfigurableDistiller', + student_recorders=dict( + bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'), + bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu'), + fc=dict(type='ModuleOutputs', source='head.fc')), + teacher_recorders=dict( + bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.2.relu'), + bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.5.relu'), + fc=dict(type='ModuleOutputs', source='head.fc')), + distill_losses=dict( + loss_s4=dict(type='L2Loss', loss_weight=10), + loss_s3=dict(type='L2Loss', loss_weight=10), + loss_kl=dict( + type='KLDivergence', tau=6, loss_weight=10, reduction='mean')), + connectors=dict( + loss_s4_sfeat=dict( + type='ConvBNReLUConnector', + in_channel=512, + out_channel=2048, + norm_cfg=dict(type='BN')), + loss_s3_sfeat=dict( + type='ConvBNReLUConnector', + in_channel=256, + out_channel=1024, + norm_cfg=dict(type='BN'))), + loss_forward_mappings=dict( + loss_s4=dict( + s_feature=dict( + from_student=True, + recorder='bb_s4', + record_idx=1, + connector='loss_s4_sfeat'), + t_feature=dict( + from_student=False, recorder='bb_s4', record_idx=2)), + loss_s3=dict( + s_feature=dict( + from_student=True, + recorder='bb_s3', + record_idx=1, + connector='loss_s3_sfeat'), + t_feature=dict( + from_student=False, recorder='bb_s3', record_idx=2)), + loss_kl=dict( + preds_S=dict(from_student=True, recorder='fc'), + preds_T=dict(from_student=False, recorder='fc'))))) + +find_unused_parameters = True + +val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') diff --git a/docs/en/imgs/model_zoo/fitnet/pipeline.png b/docs/en/imgs/model_zoo/fitnet/pipeline.png new file mode 100644 index 000000000..19662f882 Binary files /dev/null and b/docs/en/imgs/model_zoo/fitnet/pipeline.png differ diff --git a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py index ef3c44246..931cbc169 100644 --- a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py +++ b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py @@ -5,6 +5,7 @@ from mmcv.runner import load_checkpoint from mmengine import BaseDataElement from mmengine.model import BaseModel +from torch import nn from torch.nn.modules.batchnorm import _BatchNorm from mmrazor.models.utils import add_prefix @@ -18,6 +19,7 @@ class SingleTeacherDistill(BaseAlgorithm): only use one teacher. Args: + distiller (dict): The config dict for built distiller. 
teacher (dict | BaseModel): The config dict for teacher model or built teacher model. teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None. @@ -26,6 +28,10 @@ class SingleTeacherDistill(BaseAlgorithm): teacher_norm_eval (bool): Whether to set teacher's norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Defaults to True. + student_trainable (bool): Whether the student is trainable. Defaults + to True. + calculate_student_loss (bool): Whether to calculate student loss + (original task loss) to update student model. Defaults to True. """ def __init__(self, @@ -34,7 +40,9 @@ def __init__(self, teacher_ckpt: Optional[str] = None, teacher_trainable: bool = False, teacher_norm_eval: bool = True, - **kwargs): + student_trainable: bool = True, + calculate_student_loss: bool = True, + **kwargs) -> None: super().__init__(**kwargs) self.distiller = MODELS.build(distiller) @@ -55,13 +63,21 @@ def __init__(self, self.teacher_trainable = teacher_trainable self.teacher_norm_eval = teacher_norm_eval + # The student model will not calculate gradients and update parameters + # in some pretraining process. + self.student_trainable = student_trainable + + # The student loss will not be updated into ``losses`` in some + # pretraining process. + self.calculate_student_loss = calculate_student_loss + # In ``ConfigurableDistller``, the recorder manager is just # constructed, but not really initialized yet. self.distiller.prepare_from_student(self.student) self.distiller.prepare_from_teacher(self.teacher) @property - def student(self): + def student(self) -> nn.Module: """Alias for ``architecture``.""" return self.architecture @@ -86,16 +102,25 @@ def loss( else: with self.distiller.teacher_recorders, self.distiller.deliveries: with torch.no_grad(): - _ = self.teacher(batch_inputs, data_samples, mode='loss') # If the `override_data` of a delivery is True, the delivery will # override the origin data with the recorded data. self.distiller.set_deliveries_override(True) - with self.distiller.student_recorders, self.distiller.deliveries: - student_losses = self.student( - batch_inputs, data_samples, mode='loss') - losses.update(add_prefix(student_losses, 'student')) + # Original task loss will not be used during some pretraining process. + if self.calculate_student_loss: + with self.distiller.student_recorders, self.distiller.deliveries: + student_losses = self.student( + batch_inputs, data_samples, mode='loss') + losses.update(add_prefix(student_losses, 'student')) + else: + with self.distiller.student_recorders, self.distiller.deliveries: + if self.student_trainable: + _ = self.student(batch_inputs, data_samples, mode='loss') + else: + with torch.no_grad(): + _ = self.student( + batch_inputs, data_samples, mode='loss') # Automatically compute distill losses based on `loss_forward_mappings` # The required data already exists in the recorders. @@ -104,7 +129,7 @@ def loss( return losses - def train(self, mode=True): + def train(self, mode: bool = True) -> None: """Set distiller's forward mode.""" super().train(mode) if mode and self.teacher_norm_eval: diff --git a/mmrazor/models/architectures/__init__.py b/mmrazor/models/architectures/__init__.py index f267930f3..317e1fde7 100644 --- a/mmrazor/models/architectures/__init__.py +++ b/mmrazor/models/architectures/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .backbones import * # noqa: F401,F403 +from .connectors import * # noqa: F401,F403 from .dynamic_op import * # noqa: F401,F403 from .heads import * # noqa: F401,F403 diff --git a/mmrazor/models/architectures/connectors/__init__.py b/mmrazor/models/architectures/connectors/__init__.py new file mode 100644 index 000000000..28673e8ee --- /dev/null +++ b/mmrazor/models/architectures/connectors/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .general_connector import (ConvBNConnector, ConvBNReLUConnector, + SingleConvConnector) + +__all__ = ['ConvBNConnector', 'ConvBNReLUConnector', 'SingleConvConnector'] diff --git a/mmrazor/models/architectures/connectors/base_connector.py b/mmrazor/models/architectures/connectors/base_connector.py new file mode 100644 index 000000000..4322efa86 --- /dev/null +++ b/mmrazor/models/architectures/connectors/base_connector.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Dict, Optional + +import torch +from mmcv.runner import BaseModule + + +class BaseConnector(BaseModule, metaclass=ABCMeta): + """Base class of connectors. + + Connectors are mainly used for distillation. A connector usually converts the + channel number of the input feature to align the student's features with the teacher's. + + All subclasses should implement the following APIs: + + - ``forward_train()`` + + Args: + init_cfg (dict, optional): The config to control the initialization. + """ + + def __init__(self, init_cfg: Optional[Dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + def forward(self, feature: torch.Tensor) -> torch.Tensor: + """Forward computation. + + Args: + feature (torch.Tensor): Input feature. + """ + return self.forward_train(feature) + + @abstractmethod + def forward_train(self, feature: torch.Tensor) -> torch.Tensor: + """Abstract train computation. + + Args: + feature (torch.Tensor): Input feature. + """ + pass diff --git a/mmrazor/models/architectures/connectors/general_connector.py b/mmrazor/models/architectures/connectors/general_connector.py new file mode 100644 index 000000000..156468cb8 --- /dev/null +++ b/mmrazor/models/architectures/connectors/general_connector.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmrazor.registry import MODELS +from .base_connector import BaseConnector + + +@MODELS.register_module() +class SingleConvConnector(BaseConnector): + """General connector which only contains a conv layer. + + Args: + in_channel (int): The input channel of the connector. + out_channel (int): The output channel of the connector. + conv_cfg (dict, optional): The config to control the convolution. + init_cfg (dict, optional): The config to control the initialization. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + conv_cfg: Optional[Dict] = None, + init_cfg: Optional[Dict] = None, + ) -> None: + super().__init__(init_cfg) + self.conv = build_conv_layer( + conv_cfg, in_channel, out_channel, kernel_size=1, stride=1) + + def forward_train(self, feature: torch.Tensor) -> torch.Tensor: + """Forward computation. + + Args: + feature (torch.Tensor): Input feature. + """ + return self.conv(feature) + + def init_weights(self) -> None: + """Init parameters. + + In the subclass of ``BaseModule``, `init_weights` will be called + automatically.
+ """ + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + device = m.weight.device + in_channels, _, k1, k2 = m.weight.shape + m.weight[:] = torch.randn( + m.weight.shape, device=device) / np.sqrt( + k1 * k2 * in_channels) * 1e-4 + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + else: + continue + + +@MODELS.register_module() +class ConvBNConnector(BaseConnector): + """General connector which contains a conv layer with BN. + + Args: + in_channel (int): The input channels of the connector. + out_channel (int): The output channels of the connector. + norm_cfg (dict): The config to control the normalization. + conv_cfg (dict, optional): The config to control the convolution. + init_cfg (dict, optional): The config to control the initialization. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + norm_cfg: Dict, + conv_cfg: Optional[Dict] = None, + init_cfg: Optional[Dict] = None, + ) -> None: + super().__init__(init_cfg) + self.conv = build_conv_layer( + conv_cfg, + in_channel, + out_channel, + kernel_size=1, + stride=1, + padding=0, + bias=False) + _, self.bn = build_norm_layer(norm_cfg, out_channel) + + def forward_train(self, feature: torch.Tensor) -> torch.Tensor: + """Forward computation. + + Args: + feature (torch.Tensor): Input feature. + """ + return self.bn(self.conv(feature)) + + +@MODELS.register_module() +class ConvBNReLUConnector(BaseConnector): + """General connector which contains a conv layer with BN and ReLU. + + Args: + in_channel (int): The input channels of the connector. + out_channel (int): The output channels of the connector. + norm_cfg (dict): The config to control the normalization. + conv_cfg (dict, optional): The config to control the convolution. + init_cfg (dict, optional): The config to control the initialization. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + norm_cfg: Dict, + conv_cfg: Optional[Dict] = None, + init_cfg: Optional[Dict] = None, + ) -> None: + super().__init__(init_cfg) + self.conv = build_conv_layer( + conv_cfg, in_channel, out_channel, kernel_size=1) + _, self.bn = build_norm_layer(norm_cfg, out_channel) + self.relu = nn.ReLU(inplace=True) + + def forward_train(self, feature: torch.Tensor) -> torch.Tensor: + """Forward computation. + + Args: + feature (torch.Tensor): Input feature. + """ + return self.relu(self.bn(self.conv(feature))) diff --git a/mmrazor/models/distillers/base_distiller.py b/mmrazor/models/distillers/base_distiller.py index 317d033f8..4cf575e90 100644 --- a/mmrazor/models/distillers/base_distiller.py +++ b/mmrazor/models/distillers/base_distiller.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +from typing import Dict, Optional from mmengine.model import BaseModule @@ -7,9 +8,13 @@ class BaseDistiller(BaseModule, ABC): - """Base class for distiller.""" + """Base class for distiller. - def __init__(self, init_cfg=None): + Args: + init_cfg (dict, optional): Config for distiller. Default to None. 
+ """ + + def __init__(self, init_cfg: Optional[Dict] = None) -> None: super().__init__(init_cfg) @abstractmethod diff --git a/mmrazor/models/distillers/configurable_distiller.py b/mmrazor/models/distillers/configurable_distiller.py index a794417ed..621dfa624 100644 --- a/mmrazor/models/distillers/configurable_distiller.py +++ b/mmrazor/models/distillers/configurable_distiller.py @@ -38,6 +38,9 @@ class ConfigurableDistiller(BaseDistiller): distill_deliveries (dict, optional): Config for multiple deliveries. A distill algorithm may have more than one delivery. Defaults to None. + connectors (dict, optional): Config for multiple connectors. A + distillation model may have more than one connector. Defaults to + None. distill_losses: (Dict[str, Dict], optional): Config for multiple distill losses. A distill algorithm may have more than one distill loss. Defaults to None. @@ -64,33 +67,45 @@ class ConfigurableDistiller(BaseDistiller): `student_recorders``; otherwise, it means the recorder is in ``teacher_recorders``. + A connector can be called according to its `connector_name`, so that a + input can use a different connector in different loss. + Examples: >>> distill_losses = dict( - ... loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)) + ... loss_neck=dict(type='L2Loss', loss_weight=5)) >>> student_recorders = dict( - ... fc = dict(type='ModuleOutputs', sources=['head.fc'])) + ... feat = dict(type='ModuleOutputs', sources=['neck.gap'])) >>> teacher_recorders = dict( - ... fc = dict(type='ModuleOutputs', sources=['head.fc'])) + ... feat = dict(type='ModuleOutputs', sources=['neck.gap'])) + + >>> connectors = dict( + ... loss_neck_sfeat = dict( + ... type='SingleConvConnector', in_channel=32, out_channel=64), + ... loss_neck_tfeat = dict( + ... type='SingleConvConnector', in_channel=32, out_channel=64)) >>> loss_forward_mappings = dict( - ... loss_kl=dict( - ... preds_S=dict(from_recorder='fc', from_student=True), - ... preds_T=dict(from_recorder='fc', from_student=False))) + ... loss_neck=dict( + ... s_feature=dict(from_recorder='feat', from_student=True, + ... connector='loss_neck_sfeat'), + ... t_feature=dict(from_recorder='feat', from_student=False, + ... connector='loss_neck_tfeat'))) """ def __init__(self, student_recorders: Optional[Dict[str, Dict]] = None, teacher_recorders: Optional[Dict[str, Dict]] = None, distill_deliveries: Optional[Dict[str, Dict]] = None, + connectors: Optional[Dict[str, Dict]] = None, distill_losses: Optional[Dict[str, Dict]] = None, loss_forward_mappings: Optional[Dict[str, Dict]] = None, **kwargs): super().__init__(**kwargs) # The recorder manager is just constructed, but not really initialized # yet. Recorder manager initialization needs to input the corresponding - # model. + # model. self.student_recorders = RecorderManager(student_recorders) self.teacher_recorders = RecorderManager(teacher_recorders) @@ -98,8 +113,10 @@ def __init__(self, self.distill_losses = self.build_distill_losses(distill_losses) + self.connectors = self.build_connectors(connectors) + if loss_forward_mappings: - # Check if loss_forward_mappings is in the correct format + # Check if loss_forward_mappings is in the correct format. 
self._check_loss_forward_mappings(self.distill_losses, loss_forward_mappings, self.student_recorders, @@ -108,7 +125,7 @@ def __init__(self, else: self.loss_forward_mappings = dict() - def set_deliveries_override(self, override: bool): + def set_deliveries_override(self, override: bool) -> None: """Set the `override_data` of all deliveries.""" self.deliveries.override_data = override @@ -120,6 +137,23 @@ def prepare_from_teacher(self, model: nn.Module) -> None: """Initialize teacher recorders.""" self.teacher_recorders.initialize(model) + def build_connectors( + self, + connectors: Optional[Dict[str, Dict]] = None, + ) -> nn.ModuleDict: + """Build connectors.""" + + distill_connectors = nn.ModuleDict() + if connectors: + for connector_name, connector_cfg in connectors.items(): + assert connector_name not in distill_connectors, \ + f'{connector_name} is already in "distill_connectors".' + + connector = MODELS.build(connector_cfg) + distill_connectors[connector_name] = connector + + return distill_connectors + def build_distill_losses( + self, + losses: Optional[Dict[str, Dict]] = None, @@ -148,7 +182,8 @@ def get_record(self, recorder: str, from_student: bool, record_idx: int = 0, - data_idx: Optional[int] = None) -> List: + data_idx: Optional[int] = None, + connector: Optional[str] = None) -> List: """According to each item in ``record_infos``, get the corresponding record in ``recorder_manager``.""" @@ -156,8 +191,12 @@ def get_record(self, recorder_ = self.student_recorders.get_recorder(recorder) else: recorder_ = self.teacher_recorders.get_recorder(recorder) + record_data = recorder_.get_record_data(record_idx, data_idx) - return recorder_.get_record_data(record_idx, data_idx) + if connector: + record_data = self.connectors[connector](record_data) + + return record_data def compute_distill_losses(self) -> LossResults: """Compute distill losses automatically.""" @@ -165,8 +204,8 @@ def compute_distill_losses(self) -> LossResults: losses = dict() for loss_name, forward_mappings in self.loss_forward_mappings.items(): forward_kwargs = dict() - for forward_key, record_info in forward_mappings.items(): - forward_var = self.get_record(**record_info) + for forward_key, record in forward_mappings.items(): + forward_var = self.get_record(**record) forward_kwargs[forward_key] = forward_var loss_module = self.distill_losses[loss_name] @@ -233,3 +272,8 @@ def _check_loss_forward_mappings( assert recorder in teacher_recorders.recorders, \ f'For {forward_key}, "{recorder}" must be in \ `teacher_recorders`.' + + if 'connector' in record_info: + connector: str = record_info['connector'] + assert connector in self.connectors, \ + f'{connector} must be in "connectors".' diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 3d3e97e52..db0bd7af6 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cwd import ChannelWiseDivergence from .kl_divergence import KLDivergence +from .l2_loss import L2Loss from .relational_kd import AngleWiseRKD, DistanceWiseRKD from .weighted_soft_label_distillation import WSLD __all__ = [ 'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD', - 'WSLD' + 'WSLD', 'L2Loss' ] diff --git a/mmrazor/models/losses/l2_loss.py b/mmrazor/models/losses/l2_loss.py new file mode 100644 index 000000000..8b373ed38 --- /dev/null +++ b/mmrazor/models/losses/l2_loss.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved.
+import torch +import torch.nn as nn + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class L2Loss(nn.Module): + """Calculate the two-norm loss between the two features. + + Args: + loss_weight (float): Weight of loss. Defaults to 1.0. + normalize (bool): Whether to normalize the feature. Defaults to True. + mult (float): Multiplier for feature normalization. Defaults to 1.0. + div_element (bool): Whether to divide the loss by the number of + elements. Defaults to False. + """ + + def __init__( + self, + loss_weight: float = 1.0, + normalize: bool = True, + mult: float = 1.0, + div_element: bool = False, + ) -> None: + super().__init__() + self.loss_weight = loss_weight + self.normalize = normalize + self.mult = mult + self.div_element = div_element + + def forward( + self, + s_feature: torch.Tensor, + t_feature: torch.Tensor, + ) -> torch.Tensor: + """Forward computation. + + Args: + s_feature (torch.Tensor): The student model feature with + shape (N, C, H, W) or shape (N, C). + t_feature (torch.Tensor): The teacher model feature with + shape (N, C, H, W) or shape (N, C). + """ + if self.normalize: + s_feature = self.normalize_feature(s_feature) + t_feature = self.normalize_feature(t_feature) + + loss = torch.sum(torch.pow(torch.sub(s_feature, t_feature), 2)) + + if self.div_element: + loss = loss / s_feature.numel() + else: + loss = loss / s_feature.size(0) + + return self.loss_weight * loss + + def normalize_feature(self, feature: torch.Tensor) -> torch.Tensor: + """Normalize the input feature. + + Args: + feature (torch.Tensor): The input feature with + shape (N, C, H, W) or shape (N, C). + """ + feature = feature.view(feature.size(0), -1) + return feature / feature.norm(2, dim=1, keepdim=True) * self.mult diff --git a/tests/test_models/test_connectors/test_connectors.py b/tests/test_models/test_connectors/test_connectors.py new file mode 100644 index 000000000..d04e14854 --- /dev/null +++ b/tests/test_models/test_connectors/test_connectors.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmrazor.models import ConvBNConnector, ConvBNReLUConnector, SingleConvConnector + + +class TestConnector(TestCase): + + @classmethod + def setUpClass(cls): + cls.s_feat = torch.randn(1, 1, 5, 5) + cls.t_feat = torch.randn(1, 3, 5, 5) + + def test_singleconv_connector(self): + singleconv_connector_cfg = dict(in_channel=1, out_channel=3) + singleconv_connector = SingleConvConnector(**singleconv_connector_cfg) + + output = singleconv_connector.forward_train(self.s_feat) + assert output.size() == self.t_feat.size() + + def test_bn_connector(self): + bn_connector_cfg = dict(in_channel=1, out_channel=3, norm_cfg=dict(type='BN')) + bn_connector = ConvBNConnector(**bn_connector_cfg) + + output = bn_connector.forward_train(self.s_feat) + assert output.size() == self.t_feat.size() + + def test_relu_connector(self): + relu_connector_cfg = dict(in_channel=1, out_channel=3, norm_cfg=dict(type='BN')) + relu_connector = ConvBNReLUConnector(**relu_connector_cfg) + + output = relu_connector.forward_train(self.s_feat) + assert output.size() == self.t_feat.size() diff --git a/tests/test_models/test_losses/test_general_losses.py b/tests/test_models/test_losses/test_general_losses.py new file mode 100644 index 000000000..70b4c2c75 --- /dev/null +++ b/tests/test_models/test_losses/test_general_losses.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase + +import torch + +from mmrazor.models import L2Loss + + +class TestLosses(TestCase): + + @classmethod + def setUpClass(cls): + cls.feats_1d = torch.randn(5, 6) + cls.feats_3d = torch.randn(5, 2, 3, 3) + + def normal_test_1d(self, loss_instance): + loss_1d = loss_instance.forward(self.feats_1d, self.feats_1d) + self.assertTrue(loss_1d.numel() == 1) + + def normal_test_3d(self, loss_instance): + loss_3d = loss_instance.forward(self.feats_3d, self.feats_3d) + self.assertTrue(loss_3d.numel() == 1) + + def test_l2_loss(self): + l2_loss_cfg = dict(loss_weight=10) + l2_loss = L2Loss(**l2_loss_cfg) + self.normal_test_1d(l2_loss) + self.normal_test_3d(l2_loss)
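Below is a minimal usage sketch (not part of the patch) showing how the pieces added here compose: a connector lifts a student feature to the teacher's channel count, and `L2Loss` compares the aligned features. It assumes an environment where mmrazor with this patch is importable; the tensor shapes are illustrative, loosely mirroring the ResNet18/ResNet50 stage-4 pairing in the FitNets config above.

```python
# Minimal sketch, assuming this patch is applied and mmrazor is importable.
import torch

from mmrazor.models import L2Loss, SingleConvConnector

student_feat = torch.randn(2, 512, 7, 7)   # e.g. ResNet18 backbone.layer4 output
teacher_feat = torch.randn(2, 2048, 7, 7)  # e.g. ResNet50 backbone.layer4 output

# The connector maps the student feature to the teacher's channel dimension.
connector = SingleConvConnector(in_channel=512, out_channel=2048)
aligned = connector(student_feat)  # shape: (2, 2048, 7, 7)

# L2Loss normalizes both features and returns a scalar distillation loss,
# scaled by loss_weight (10 matches the weight used in the config above).
loss_fn = L2Loss(loss_weight=10)
loss = loss_fn(aligned, teacher_feat)
print(loss)  # 0-dim tensor
```

In the full pipeline this wiring is done declaratively by `ConfigurableDistiller` through the `connectors` and `loss_forward_mappings` fields, as in the FitNets config.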