diff --git a/configs/_base_/models/densenet/densenet121.py b/configs/_base_/models/densenet/densenet121.py new file mode 100644 index 00000000000..0a14d302584 --- /dev/null +++ b/configs/_base_/models/densenet/densenet121.py @@ -0,0 +1,11 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='DenseNet', arch='121'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1024, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + )) diff --git a/configs/_base_/models/densenet/densenet161.py b/configs/_base_/models/densenet/densenet161.py new file mode 100644 index 00000000000..61a0d838806 --- /dev/null +++ b/configs/_base_/models/densenet/densenet161.py @@ -0,0 +1,11 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='DenseNet', arch='161'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=2208, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + )) diff --git a/configs/_base_/models/densenet/densenet169.py b/configs/_base_/models/densenet/densenet169.py new file mode 100644 index 00000000000..779ea170925 --- /dev/null +++ b/configs/_base_/models/densenet/densenet169.py @@ -0,0 +1,11 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='DenseNet', arch='169'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1664, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + )) diff --git a/configs/_base_/models/densenet/densenet201.py b/configs/_base_/models/densenet/densenet201.py new file mode 100644 index 00000000000..2909af0d36c --- /dev/null +++ b/configs/_base_/models/densenet/densenet201.py @@ -0,0 +1,11 @@ +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='DenseNet', arch='201'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1920, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + )) diff --git a/configs/densenet/README.md b/configs/densenet/README.md new file mode 100644 index 00000000000..77dfa2987d9 --- /dev/null +++ b/configs/densenet/README.md @@ -0,0 +1,41 @@ +# DenseNet + +> [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993) + + +## Abstract + +Recent work has shown that convolutional networks can be substantially deeper, more accurate, and efficient to train if they contain shorter connections between layers close to the input and those close to the output. In this paper, we embrace this observation and introduce the Dense Convolutional Network (DenseNet), which connects each layer to every other layer in a feed-forward fashion. Whereas traditional convolutional networks with L layers have L connections - one between each layer and its subsequent layer - our network has L(L+1)/2 direct connections. For each layer, the feature-maps of all preceding layers are used as inputs, and its own feature-maps are used as inputs into all subsequent layers. DenseNets have several compelling advantages: they alleviate the vanishing-gradient problem, strengthen feature propagation, encourage feature reuse, and substantially reduce the number of parameters. We evaluate our proposed architecture on four highly competitive object recognition benchmark tasks (CIFAR-10, CIFAR-100, SVHN, and ImageNet). DenseNets obtain significant improvements over the state-of-the-art on most of them, whilst requiring less computation to achieve high performance. + +
+ +
+ +## Results and models + +### ImageNet-1k + +| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | +|:---------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:| +| DenseNet121\* | 7.98 | 2.88 | 74.96 | 92.21 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/densenet/densenet121_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/densenet/densenet121_4xb256_in1k_20220426-07450f99.pth) | +| DenseNet169\* | 14.15 | 3.42 | 76.08 | 93.11 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/densenet/densenet169_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/densenet/densenet169_4xb256_in1k_20220426-a2889902.pth) | +| DenseNet201\* | 20.01 | 4.37 | 77.32 | 93.64 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/densenet/densenet201_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/densenet/densenet201_4xb256_in1k_20220426-05cae4ef.pth) | +| DenseNet161\* | 28.68 | 7.82 | 77.61 | 93.83 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/densenet/densenet161_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/densenet/densenet161_4xb256_in1k_20220426-ee6a80a9.pth) | + +*Models with \* are converted from [pytorch](https://pytorch.org/vision/stable/models.html), guided by [original repo](https://github.com/liuzhuang13/DenseNet). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* + + +## Citation + +```bibtex +@misc{https://doi.org/10.48550/arxiv.1608.06993, + doi = {10.48550/ARXIV.1608.06993}, + url = {https://arxiv.org/abs/1608.06993}, + author = {Huang, Gao and Liu, Zhuang and van der Maaten, Laurens and Weinberger, Kilian Q.}, + keywords = {Computer Vision and Pattern Recognition (cs.CV), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Densely Connected Convolutional Networks}, + publisher = {arXiv}, + year = {2016}, + copyright = {arXiv.org perpetual, non-exclusive license} +} +``` diff --git a/configs/densenet/densenet121_4xb256_in1k.py b/configs/densenet/densenet121_4xb256_in1k.py new file mode 100644 index 00000000000..08d65ae24ac --- /dev/null +++ b/configs/densenet/densenet121_4xb256_in1k.py @@ -0,0 +1,10 @@ +_base_ = [ + '../_base_/models/densenet/densenet121.py', + '../_base_/datasets/imagenet_bs64.py', + '../_base_/schedules/imagenet_bs256.py', + '../_base_/default_runtime.py', +] + +data = dict(samples_per_gpu=256) + +runner = dict(type='EpochBasedRunner', max_epochs=90) diff --git a/configs/densenet/densenet161_4xb256_in1k.py b/configs/densenet/densenet161_4xb256_in1k.py new file mode 100644 index 00000000000..4581d1dec69 --- /dev/null +++ b/configs/densenet/densenet161_4xb256_in1k.py @@ -0,0 +1,10 @@ +_base_ = [ + '../_base_/models/densenet/densenet161.py', + '../_base_/datasets/imagenet_bs64.py', + '../_base_/schedules/imagenet_bs256.py', + '../_base_/default_runtime.py', +] + +data = dict(samples_per_gpu=256) + +runner = dict(type='EpochBasedRunner', max_epochs=90) diff --git a/configs/densenet/densenet169_4xb256_in1k.py b/configs/densenet/densenet169_4xb256_in1k.py new file mode 100644 index 00000000000..6179293beba --- /dev/null +++ b/configs/densenet/densenet169_4xb256_in1k.py @@ -0,0 +1,10 @@ +_base_ = [ + '../_base_/models/densenet/densenet169.py', + '../_base_/datasets/imagenet_bs64.py', + '../_base_/schedules/imagenet_bs256.py', + '../_base_/default_runtime.py', +] + +data = dict(samples_per_gpu=256) + +runner = dict(type='EpochBasedRunner', max_epochs=90) diff --git a/configs/densenet/densenet201_4xb256_in1k.py b/configs/densenet/densenet201_4xb256_in1k.py new file mode 100644 index 00000000000..897a141dba1 --- /dev/null +++ b/configs/densenet/densenet201_4xb256_in1k.py @@ -0,0 +1,10 @@ +_base_ = [ + '../_base_/models/densenet/densenet201.py', + '../_base_/datasets/imagenet_bs64.py', + '../_base_/schedules/imagenet_bs256.py', + '../_base_/default_runtime.py', +] + +data = dict(samples_per_gpu=256) + +runner = dict(type='EpochBasedRunner', max_epochs=90) diff --git a/configs/densenet/metafile.yml b/configs/densenet/metafile.yml new file mode 100644 index 00000000000..84366b23a35 --- /dev/null +++ b/configs/densenet/metafile.yml @@ -0,0 +1,76 @@ +Collections: + - Name: DenseNet + Metadata: + Training Data: ImageNet-1k + Architecture: + - DenseBlock + Paper: + URL: https://arxiv.org/abs/1608.06993 + Title: Densely Connected Convolutional Networks + README: configs/densenet/README.md + +Models: + - Name: densenet121_4xb256_in1k + Metadata: + FLOPs: 2881695488 + Parameters: 7978856 + In Collections: DenseNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 74.96 + Top 5 Accuracy: 92.21 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/densenet/densenet121_4xb256_in1k_20220426-07450f99.pth + Config: configs/densenet/densenet121_4xb256_in1k.py + Converted From: + Weights: https://download.pytorch.org/models/densenet121-a639ec97.pth + Code: https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py + - Name: densenet169_4xb256_in1k + Metadata: + FLOPs: 3416860160 + Parameters: 14149480 + In Collections: DenseNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 76.08 + Top 5 Accuracy: 93.11 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/densenet/densenet169_4xb256_in1k_20220426-a2889902.pth + Config: configs/densenet/densenet169_4xb256_in1k.py + Converted From: + Weights: https://download.pytorch.org/models/densenet169-b2777c0a.pth + Code: https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py + - Name: densenet201_4xb256_in1k + Metadata: + FLOPs: 4365236736 + Parameters: 20013928 + In Collections: DenseNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 77.32 + Top 5 Accuracy: 93.64 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/densenet/densenet201_4xb256_in1k_20220426-05cae4ef.pth + Config: configs/densenet/densenet201_4xb256_in1k.py + Converted From: + Weights: https://download.pytorch.org/models/densenet201-c1103571.pth + Code: https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py + - Name: densenet161_4xb256_in1k + Metadata: + FLOPs: 7816363968 + Parameters: 28681000 + In Collections: DenseNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 77.61 + Top 5 Accuracy: 93.83 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/densenet/densenet161_4xb256_in1k_20220426-ee6a80a9.pth + Config: configs/densenet/densenet161_4xb256_in1k.py + Converted From: + Weights: https://download.pytorch.org/models/densenet161-8d451a50.pth + Code: https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst index 687b8009340..78701a41fe1 100644 --- a/docs/en/api/models.rst +++ b/docs/en/api/models.rst @@ -55,6 +55,7 @@ Backbones Conformer ConvMixer ConvNeXt + DenseNet DistilledVisionTransformer EfficientNet HRNet diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index 6451b70add0..d2690b7aac3 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -133,6 +133,10 @@ The ResNet family models below are trained by standard data augmentations, i.e., | CSPDarkNet50\* | 27.64 | 5.04 | 80.05 | 95.07 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspdarknet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspdarknet50_3rdparty_8xb32_in1k_20220329-bd275287.pth) | | CSPResNet50\* | 21.62 | 3.48 | 79.55 | 94.68 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspresnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspresnet50_3rdparty_8xb32_in1k_20220329-dd6dddfb.pth) | | CSPResNeXt50\* | 20.57 | 3.11 | 79.96 | 94.96 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspresnext50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspresnext50_3rdparty_8xb32_in1k_20220329-2cc84d21.pth) | +| DenseNet121\* | 7.98 | 2.88 | 74.96 | 92.21 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/densenet/densenet121_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/densenet/densenet121_4xb256_in1k_20220426-07450f99.pth) | +| DenseNet169\* | 14.15 | 3.42 | 76.08 | 93.11 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/densenet/densenet169_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/densenet/densenet169_4xb256_in1k_20220426-a2889902.pth) | +| DenseNet201\* | 20.01 | 4.37 | 77.32 | 93.64 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/densenet/densenet201_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/densenet/densenet201_4xb256_in1k_20220426-05cae4ef.pth) | +| DenseNet161\* | 28.68 | 7.82 | 77.61 | 93.83 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/densenet/densenet161_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/densenet/densenet161_4xb256_in1k_20220426-ee6a80a9.pth) | | VAN-T\* | 4.11 | 0.88 | 75.41 | 93.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth) | | VAN-S\* | 13.86 | 2.52 | 81.01 | 95.63 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth) | | VAN-B\* | 26.58 | 5.03 | 82.80 | 96.21 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-base_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth) | diff --git a/docs/zh_CN/api/models.rst b/docs/zh_CN/api/models.rst index 7b9530246e4..50cb006682a 100644 --- a/docs/zh_CN/api/models.rst +++ b/docs/zh_CN/api/models.rst @@ -55,6 +55,7 @@ Backbones Conformer ConvMixer ConvNeXt + DenseNet DistilledVisionTransformer EfficientNet HRNet diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py index 13f565f4132..41e72f9d0d6 100644 --- a/mmcls/models/backbones/__init__.py +++ b/mmcls/models/backbones/__init__.py @@ -5,6 +5,7 @@ from .convnext import ConvNeXt from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt from .deit import DistilledVisionTransformer +from .densenet import DenseNet from .efficientnet import EfficientNet from .hrnet import HRNet from .lenet import LeNet5 @@ -41,5 +42,5 @@ 'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT', 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer', 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet', 'RepMLPNet', - 'PoolFormer', 'VAN' + 'PoolFormer', 'DenseNet', 'VAN' ] diff --git a/mmcls/models/backbones/densenet.py b/mmcls/models/backbones/densenet.py new file mode 100644 index 00000000000..cb2002e4bc2 --- /dev/null +++ b/mmcls/models/backbones/densenet.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import build_activation_layer, build_norm_layer +from torch.jit.annotations import List + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +class DenseLayer(BaseBackbone): + """DenseBlock layers.""" + + def __init__(self, + in_channels, + growth_rate, + bn_size, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_rate=0., + memory_efficient=False): + super(DenseLayer, self).__init__() + + self.norm1 = build_norm_layer(norm_cfg, in_channels)[1] + self.conv1 = nn.Conv2d( + in_channels, + bn_size * growth_rate, + kernel_size=1, + stride=1, + bias=False) + self.act = build_activation_layer(act_cfg) + self.norm2 = build_norm_layer(norm_cfg, bn_size * growth_rate)[1] + self.conv2 = nn.Conv2d( + bn_size * growth_rate, + growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bottleneck_fn(self, xs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(xs, 1) + bottleneck_output = self.conv1( + self.act(self.norm1(concated_features))) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, x): + # type: (List[torch.Tensor]) -> bool + for tensor in x: + if tensor.requires_grad: + return True + return False + + # This decorator indicates to the compiler that a function or method + # should be ignored and replaced with the raising of an exception. + # Here this function is incompatible with torchscript. + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, x): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*xs): + return self.bottleneck_fn(xs) + + # Here use torch.utils.checkpoint to rerun a forward-pass during + # backward in bottleneck to save memories. + return cp.checkpoint(closure, *x) + + def forward(self, x): # noqa: F811 + # type: (List[torch.Tensor]) -> torch.Tensor + # assert input features is a list of Tensor + assert isinstance(x, list) + + if self.memory_efficient and self.any_requires_grad(x): + if torch.jit.is_scripting(): + raise Exception('Memory Efficient not supported in JIT') + bottleneck_output = self.call_checkpoint_bottleneck(x) + else: + bottleneck_output = self.bottleneck_fn(x) + + new_features = self.conv2(self.act(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = F.dropout( + new_features, p=self.drop_rate, training=self.training) + return new_features + + +class DenseBlock(nn.Module): + """DenseNet Blocks.""" + + def __init__(self, + num_layers, + in_channels, + bn_size, + growth_rate, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_rate=0., + memory_efficient=False): + super(DenseBlock, self).__init__() + self.block = nn.ModuleList([ + DenseLayer( + in_channels + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + memory_efficient=memory_efficient) for i in range(num_layers) + ]) + + def forward(self, init_features): + features = [init_features] + for layer in self.block: + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + + +class DenseTransition(nn.Sequential): + """DenseNet Transition Layers.""" + + def __init__(self, + in_channels, + out_channels, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU')): + super(DenseTransition, self).__init__() + self.add_module('norm', build_norm_layer(norm_cfg, in_channels)[1]) + self.add_module('act', build_activation_layer(act_cfg)) + self.add_module( + 'conv', + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, + bias=False)) + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +@BACKBONES.register_module() +class DenseNet(BaseBackbone): + """DenseNet. + + A PyTorch implementation of : `Densely Connected Convolutional Networks + `_ + + Modified from the `official repo + `_ + and `pytorch + `_. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``DenseNet.arch_settings``. And if dict, it + should include the following two keys: + + - growth_rate (int): Each layer of DenseBlock produce `k` feature + maps. Here refers `k` as the growth rate of the network. + - depths (list[int]): Number of repeated layers in each DenseBlock. + - init_channels (int): The output channels of stem layers. + + Defaults to '121'. + in_channels (int): Number of input image channels. Defaults to 3. + bn_size (int): Refers to channel expansion parameter of 1x1 + convolution layer. Defaults to 4. + drop_rate (float): Drop rate of Dropout Layer. Defaults to 0. + compression_factor (float): The reduction rate of transition layers. + Defaults to 0.5. + memory_efficient (bool): If True, uses checkpointing. Much more memory + efficient, but slower. Defaults to False. + See `"paper" `_. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation after each convolution. + Defaults to ``dict(type='ReLU')``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict. + """ + arch_settings = { + '121': { + 'growth_rate': 32, + 'depths': [6, 12, 24, 16], + 'init_channels': 64, + }, + '169': { + 'growth_rate': 32, + 'depths': [6, 12, 32, 32], + 'init_channels': 64, + }, + '201': { + 'growth_rate': 32, + 'depths': [6, 12, 48, 32], + 'init_channels': 64, + }, + '161': { + 'growth_rate': 48, + 'depths': [6, 12, 36, 24], + 'init_channels': 96, + }, + } + + def __init__(self, + arch='121', + in_channels=3, + bn_size=4, + drop_rate=0, + compression_factor=0.5, + memory_efficient=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + out_indices=-1, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = {'growth_rate', 'depths', 'init_channels'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + + self.growth_rate = arch['growth_rate'] + self.depths = arch['depths'] + self.init_channels = arch['init_channels'] + self.act = build_activation_layer(act_cfg) + + self.num_stages = len(self.depths) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # Set stem layers + self.stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.init_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False), + build_norm_layer(norm_cfg, self.init_channels)[1], self.act, + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + # Repetitions of DenseNet Blocks + self.stages = nn.ModuleList() + self.transitions = nn.ModuleList() + + channels = self.init_channels + for i in range(self.num_stages): + depth = self.depths[i] + + stage = DenseBlock( + num_layers=depth, + in_channels=channels, + bn_size=bn_size, + growth_rate=self.growth_rate, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + memory_efficient=memory_efficient) + self.stages.append(stage) + channels += depth * self.growth_rate + + if i != self.num_stages - 1: + transition = DenseTransition( + in_channels=channels, + out_channels=math.floor(channels * compression_factor), + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + channels = math.floor(channels * compression_factor) + else: + # Final layers after dense block is just bn with act. + # Unlike the paper, the original repo also put this in + # transition layer, whereas torchvision take this out. + # We reckon this as transition layer here. + transition = nn.Sequential( + build_norm_layer(norm_cfg, channels)[1], + self.act, + ) + self.transitions.append(transition) + + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i in range(self.num_stages): + x = self.stages[i](x) + x = self.transitions[i](x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.transitions[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(DenseNet, self).train(mode) + self._freeze_stages() diff --git a/model-index.yml b/model-index.yml index 81932fd6ac5..f0e0d75ccf3 100644 --- a/model-index.yml +++ b/model-index.yml @@ -25,4 +25,5 @@ Import: - configs/van/metafile.yml - configs/cspnet/metafile.yml - configs/convmixer/metafile.yml + - configs/densenet/metafile.yml - configs/poolformer/metafile.yml diff --git a/tests/test_models/test_backbones/test_densenet.py b/tests/test_models/test_backbones/test_densenet.py new file mode 100644 index 00000000000..5e4c73bc0b4 --- /dev/null +++ b/tests/test_models/test_backbones/test_densenet.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmcls.models.backbones import DenseNet + + +def test_assertion(): + with pytest.raises(AssertionError): + DenseNet(arch='unknown') + + with pytest.raises(AssertionError): + # DenseNet arch dict should include essential_keys, + DenseNet(arch=dict(channels=[2, 3, 4, 5])) + + with pytest.raises(AssertionError): + # DenseNet out_indices should be valid depth. + DenseNet(out_indices=-100) + + +def test_DenseNet(): + + # Test forward + model = DenseNet(arch='121') + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 1 + assert feat[0].shape == torch.Size([1, 1024, 7, 7]) + + # Test memory efficient option + model = DenseNet(arch='121', memory_efficient=True) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 1 + assert feat[0].shape == torch.Size([1, 1024, 7, 7]) + + # Test drop rate + model = DenseNet(arch='121', drop_rate=0.05) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 1 + assert feat[0].shape == torch.Size([1, 1024, 7, 7]) + + # Test forward with multiple outputs + model = DenseNet(arch='121', out_indices=(0, 1, 2, 3)) + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 128, 28, 28]) + assert feat[1].shape == torch.Size([1, 256, 14, 14]) + assert feat[2].shape == torch.Size([1, 512, 7, 7]) + assert feat[3].shape == torch.Size([1, 1024, 7, 7]) + + # Test with custom arch + model = DenseNet( + arch={ + 'growth_rate': 20, + 'depths': [4, 8, 12, 16, 20], + 'init_channels': 40, + }, + out_indices=(0, 1, 2, 3, 4)) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 5 + assert feat[0].shape == torch.Size([1, 60, 28, 28]) + assert feat[1].shape == torch.Size([1, 110, 14, 14]) + assert feat[2].shape == torch.Size([1, 175, 7, 7]) + assert feat[3].shape == torch.Size([1, 247, 3, 3]) + assert feat[4].shape == torch.Size([1, 647, 3, 3]) + + # Test frozen_stages + model = DenseNet(arch='121', out_indices=(0, 1, 2, 3), frozen_stages=2) + model.init_weights() + model.train() + + for i in range(2): + assert not model.stages[i].training + assert not model.transitions[i].training + + for i in range(2, 4): + assert model.stages[i].training + assert model.transitions[i].training