
# KernelWarehouse: Rethinking the Design of Dynamic Convolution

By Chao Li and Anbang Yao.

This repository is an official PyTorch implementation of "KernelWarehouse: Rethinking the Design of Dynamic Convolution" (KW for short), which has been accepted to ICML 2024.

Dynamic convolution learns a linear mixture of n static kernels weighted by input-dependent attentions, demonstrating superior performance to normal convolution. However, it increases the number of convolutional parameters by n times, and thus is not parameter efficient. This has prevented research from exploring the setting n>100 (an order of magnitude larger than the typical setting n<10), which could push forward the performance boundary of dynamic convolution while enjoying parameter efficiency. To fill this gap, we propose KernelWarehouse, a more general form of dynamic convolution, which redefines the basic concepts of "kernels", "assembling kernels" and "attention function" through the lens of exploiting convolutional parameter dependencies within the same layer and across neighboring layers of a ConvNet. We validate the effectiveness of KernelWarehouse on the ImageNet and MS-COCO datasets using various ConvNet architectures. Intriguingly, KernelWarehouse is also applicable to Vision Transformers, and it can even reduce the model size of a backbone while improving model accuracy. For instance, KernelWarehouse (n=4) achieves 5.61%|3.90%|4.38% absolute top-1 accuracy gains on the ResNet18|MobileNetV2|DeiT-Tiny backbones, and KernelWarehouse (n=1/4), with a 65.10% model size reduction, still achieves a 2.29% gain on the ResNet18 backbone.

Schematic illustration of KernelWarehouse. Briefly speaking, KernelWarehouse first divides the static kernel $\mathbf{W}$ at any regular convolutional layer of a ConvNet into $m$ disjoint kernel cells $\mathbf{w}_1, \dots, \mathbf{w}_m$ with the same dimensions. It then computes each kernel cell $\mathbf{w}_i$ as a linear mixture $\mathbf{w}_i=\alpha_{i1}\mathbf{e}_1+\dots+\alpha_{in}\mathbf{e}_n$ based on a predefined "warehouse" consisting of $n$ same-dimensioned kernel cells $\mathbf{e}_1,\dots,\mathbf{e}_n$ (e.g., $n=108$), which is shared across all same-stage convolutional layers. Finally, it replaces the static kernel $\mathbf{W}$ by assembling its corresponding $m$ mixtures in order, yielding a high degree of freedom to fit a desired convolutional parameter budget. The input-dependent scalar attentions $\alpha_{i1},\dots,\alpha_{in}$ are computed with a novel contrasting-driven attention function (CAF).
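The assembly step described above can be sketched in a few lines of PyTorch. The snippet below is a minimal illustration of the idea, not the authors' implementation: the class name `ToyKWConv2d` is made up, a plain softmax stands in for the paper's contrasting-driven attention function (CAF), and the warehouse here is per-layer rather than shared across same-stage layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyKWConv2d(nn.Module):
    """Toy KernelWarehouse-style conv: the static kernel W is assembled from
    m linear mixtures of n warehouse cells e_1, ..., e_n."""

    def __init__(self, in_ch, out_ch, k, n_cells=8, m_mixtures=4):
        super().__init__()
        assert (out_ch * in_ch * k * k) % m_mixtures == 0
        self.in_ch, self.out_ch, self.k = in_ch, out_ch, k
        self.m, self.n = m_mixtures, n_cells
        cell_numel = out_ch * in_ch * k * k // m_mixtures
        # Warehouse: n kernel cells of identical dimensions.
        self.cells = nn.Parameter(0.02 * torch.randn(n_cells, cell_numel))
        # Attention branch: global average pooling -> linear -> m*n scalars.
        self.attn = nn.Linear(in_ch, m_mixtures * n_cells)

    def forward(self, x):
        b, _, h, w = x.shape
        # Input-dependent scalar attentions alpha_ij; softmax is a stand-in
        # for the paper's contrasting-driven attention function (CAF).
        a = self.attn(x.mean(dim=(2, 3))).view(b, self.m, self.n)
        a = F.softmax(a, dim=-1)
        # Each kernel cell w_i = alpha_i1 * e_1 + ... + alpha_in * e_n.
        mix = torch.einsum('bmn,nd->bmd', a, self.cells)
        # Reassemble the m mixtures in order into one kernel W per sample.
        weight = mix.reshape(b * self.out_ch, self.in_ch, self.k, self.k)
        # Apply each sample's assembled kernel via a grouped convolution.
        out = F.conv2d(x.reshape(1, b * self.in_ch, h, w), weight,
                       padding=self.k // 2, groups=b)
        return out.reshape(b, self.out_ch, *out.shape[2:])
```

For example, with `n_cells=2` and `m_mixtures=4` the warehouse holds n/m = 1/2 of the static kernel's parameters, loosely corresponding to a KW (1/2×) setting (ignoring the small attention branch); the ratio n/m is what gives KernelWarehouse its freedom to fit a desired parameter budget.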

## Dataset

Following this repository, download the ImageNet-1K dataset and organize the training and validation images into labeled subfolders.
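The expected layout is presumably the standard `torchvision.datasets.ImageFolder` structure used by most ImageNet training scripts (an assumption here; the class and file names below are placeholders):

```
/path/to/imagenet/
  train/
    class1/
      img1.JPEG
      ...
    class2/
      ...
  val/
    class1/
      img3.JPEG
      ...
```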

## Requirements

- python >= 3.7.0
- torch >= 1.8.1, torchvision >= 0.9.1
- timm == 0.3.2, tensorboardX, six
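For reference, one way to set up a matching environment (a sketch based on the list above; the exact torch/torchvision builds depend on your CUDA version):

```
pip install "torch>=1.8.1" "torchvision>=0.9.1" timm==0.3.2 tensorboardX six
```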

## Results and Models

Results comparison on the ImageNet validation set with the ResNet18, ResNet50 and ConvNeXt-Tiny backbones trained for 300 epochs.

| Models | Params | Top-1 Acc (%) | Top-5 Acc (%) | Google Drive | Baidu Drive |
| --- | --- | --- | --- | --- | --- |
| ResNet18 | 11.69M | 70.44 | 89.72 | model | model |
| + KW (1/4×) | 4.08M | 72.73 | 90.83 | model | model |
| + KW (1/2×) | 7.43M | 73.33 | 91.42 | model | model |
| + KW (1×) | 11.93M | 74.77 | 92.13 | model | model |
| + KW (2×) | 23.24M | 75.19 | 92.18 | model | model |
| + KW (4×) | 45.86M | 76.05 | 92.68 | model | model |
| ResNet50 | 25.56M | 78.44 | 94.24 | model | model |
| + KW (1/2×) | 17.64M | 79.30 | 94.71 | model | model |
| + KW (1×) | 28.05M | 80.38 | 95.19 | model | model |
| + KW (4×) | 102.02M | 81.05 | 95.21 | model | model |
| ConvNeXt-Tiny | 28.59M | 82.07 | 95.86 | model | model |
| + KW (1×) | 39.37M | 82.51 | 96.07 | model | model |

Results comparison on the ImageNet validation set with the MobileNetV2 (1.0×, 0.5×) backbones trained for 150 epochs.

| Models | Params | Top-1 Acc (%) | Top-5 Acc (%) | Google Drive | Baidu Drive |
| --- | --- | --- | --- | --- | --- |
| MobileNetV2 (1.0×) | 3.50M | 72.04 | 90.42 | model | model |
| + KW (1/2×) | 2.65M | 72.59 | 90.71 | model | model |
| + KW (1×) | 5.17M | 74.68 | 91.90 | model | model |
| + KW (4×) | 11.38M | 75.92 | 92.22 | model | model |
| MobileNetV2 (0.5×) | 1.97M | 64.32 | 85.22 | model | model |
| + KW (1/2×) | 1.47M | 65.20 | 85.98 | model | model |
| + KW (1×) | 2.85M | 68.29 | 87.93 | model | model |
| + KW (4×) | 4.65M | 70.26 | 89.19 | model | model |

## Training

To train a model with KernelWarehouse:

```
python -m torch.distributed.launch --nproc_per_node={number of gpus} main.py --kw_config {path to config json} \
--batch_size {batch size per gpu} --update_freq {number of gradient accumulation steps} --data_path {path to dataset} \
--output_dir {path to output folder}
```

The effective total batch size is {number of gpus} × {batch size per gpu} × {number of gradient accumulation steps}.

For example, to train ResNet18 + KW (1×) on 8 GPUs with a total batch size of 4096 (8 GPUs × 128 per GPU × 4 accumulation steps):

```
python -m torch.distributed.launch --nproc_per_node=8 main.py --kw_config configs/resnet18/kw1x_resnet18.json \
--batch_size 128 --update_freq 4 --data_path {path to dataset} --output_dir {path to output folder}
```

For example, to train MobileNetV2 + KW (4×) on 8 GPUs with a total batch size of 256 (8 GPUs × 32 per GPU × 1 accumulation step):

```
python -m torch.distributed.launch --nproc_per_node=8 main.py --kw_config configs/mobilenetv2_100/kw4x_mobilenetv2_100.json \
--batch_size 32 --update_freq 1 --data_path {path to dataset} --output_dir {path to output folder}
```

You can add `--use_amp true` to enable Automatic Mixed Precision, which reduces memory usage and speeds up training.

More config files for other models can be found in the `configs` folder.

## Evaluation

To evaluate a pre-trained model:

```
python -m torch.distributed.launch --nproc_per_node={number of gpus} main.py --kw_config {path to config json} \
--eval true --data_path {path to dataset} --resume {path to model}
```
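For example, a single-GPU evaluation of ResNet18 + KW (1×) might look like this (the dataset and checkpoint paths are placeholders):

```
python -m torch.distributed.launch --nproc_per_node=1 main.py --kw_config configs/resnet18/kw1x_resnet18.json \
--eval true --data_path {path to dataset} --resume {path to model}
```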

## Training and evaluation on object detection and instance segmentation

Please refer to the README.md in the detection folder for details.

## Citation

If you find our work useful in your research, please consider citing:

```bibtex
@inproceedings{li2024kernelwarehouse,
  title={KernelWarehouse: Rethinking the Design of Dynamic Convolution},
  author={Chao Li and Anbang Yao},
  booktitle={International Conference on Machine Learning},
  year={2024}
}
```

## License

KernelWarehouse is released under the Apache license. We encourage use for both research and commercial purposes, as long as proper attribution is given.

## Acknowledgment

This repository is built upon the ConvNeXt, mmdetection, Dynamic-convolution-Pytorch, and Swin-Transformer-Object-Detection repositories. We thank the authors for releasing their amazing code.