diff --git a/configs/distill/cwd/README.md b/configs/distill/cwd/README.md index 9033221ce..9328790cd 100644 --- a/configs/distill/cwd/README.md +++ b/configs/distill/cwd/README.md @@ -15,12 +15,12 @@ Knowledge distillation (KD) has been proven to be a simple and effective tool fo ### Segmentation |Location|Dataset|Teacher|Student|mIoU|mIoU(T)|mIou(S)|Config | Download | :--------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:------:|:---------| -| logits |cityscapes|[pspnet_r101](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py)|[pspnet_r18](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r18-d8_512x1024_80k_cityscapes.py)| 75.54 | 79.76 | 74.87 |[config]()|[teacher](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes/pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth) |[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k_mIoU-75.54_20211222-3a26ee1c.pth?versionId=CAEQHxiBgMCPxIKJ7xciIGU1N2JhYzgzYWE0YTRhYmRiZjVmMTA3MTA3NDk1ZWNl) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k_20211212_205711.log.json?versionId=CAEQHxiBgMDZ_oOJ7xciIDJjYzIxYTYyODYzMzQzNDk5Mjg1NTIwMWFkODliMGFk)| +| logits |cityscapes|[pspnet_r101](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py)|[pspnet_r18](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r18-d8_512x1024_80k_cityscapes.py)| 75.54 | 79.76 | 74.87 |[config]()|[teacher](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes/pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth) |[model](https://download.openmmlab.com/mmrazor/v0.1/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k_mIoU-75.54_20211222-3a26ee1c.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k_20211212_205711.log.json?)| ### Detection |Location|Dataset|Teacher|Student|mAP|mAP(T)|mAP(S)|Config | Download | :--------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:------:|:---------| -| cls head |COCO|[gfl_r101_2x](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl/gfl_r101_fpn_mstrain_2x_coco.py)|[gfl_r50_1x](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl/gfl_r50_fpn_1x_coco.py)| 41.9 | 44.7 | 40.2 |[config]()|[teacher](https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth) |[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco_20211222-655dff39.pth?versionId=CAEQHxiBgMD7.uuI7xciIDY1MDRjYzlkN2ExOTRiY2NhNmU4NGJlMmExNjA2YzMy) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco_20211212_205444.log.json?versionId=CAEQHxiBgID.o_WI7xciIDgyZjRjYTU4Y2ZjNjRjOGU5MTBlMTQ3ZjEyMTE4OTJl)| +| cls head |COCO|[gfl_r101_2x](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl/gfl_r101_fpn_mstrain_2x_coco.py)|[gfl_r50_1x](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl/gfl_r50_fpn_1x_coco.py)| 41.9 | 44.7 | 40.2 |[config]()|[teacher](https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth) |[model](https://download.openmmlab.com/mmrazor/v0.1/distill/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco_20211222-655dff39.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/distill/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco_20211212_205444.log.json)| ## Citation diff --git a/configs/distill/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py b/configs/distill/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py index 272fd3c91..e55b733cf 100644 --- a/configs/distill/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py +++ b/configs/distill/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py @@ -57,8 +57,11 @@ nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) +checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth' # noqa: E501 + teacher = dict( type='mmdet.GFL', + init_cfg=dict(type='Pretrained', checkpoint=checkpoint), backbone=dict( type='ResNet', depth=101, diff --git a/configs/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py b/configs/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py index 0ff436c83..c3fd00b12 100644 --- a/configs/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py +++ b/configs/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py @@ -50,9 +50,12 @@ train_cfg=dict(), test_cfg=dict(mode='whole')) +checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes/pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth' # noqa: E501 + # pspnet r101 teacher = dict( type='mmseg.EncoderDecoder', + init_cfg=dict(type='Pretrained', checkpoint=checkpoint), backbone=dict( type='ResNetV1c', depth=101, diff --git a/configs/distill/wsld/README.md b/configs/distill/wsld/README.md index 10f259b3b..8c9e78100 100644 --- a/configs/distill/wsld/README.md +++ b/configs/distill/wsld/README.md @@ -27,7 +27,7 @@ effectiveness of our method. ### Classification |Location|Dataset|Teacher|Student|Acc|Acc(T)|Acc(S)|Config | Download | :--------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:------:|:---------| -| cls head |ImageNet|[resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py)|[resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py)| 71.54 | 73.62 | 69.90 |[config](./wsld_cls_head_resnet34_resnet18_8xb32_in1k.py)|[teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) |[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_acc-71.54_20211222-91f28cf6.pth?versionId=CAEQHxiBgMC6memK7xciIGMzMDFlYTA4YzhlYTRiMTNiZWU0YTVhY2I5NjVkMjY2) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_20211221_181516.log.json?versionId=CAEQHxiBgIDLmemK7xciIGNkM2FiN2Y4N2E5YjRhNDE4NDVlNmExNDczZDIxN2E5)| +| cls head |ImageNet|[resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py)|[resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py)| 71.54 | 73.62 | 69.90 |[config](./wsld_cls_head_resnet34_resnet18_8xb32_in1k.py)|[teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) |[model](https://download.openmmlab.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_acc-71.54_20211222-91f28cf6.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_20211221_181516.log.json)| diff --git a/configs/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k.py b/configs/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k.py index be4ec7246..06ac8c328 100644 --- a/configs/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k.py +++ b/configs/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k.py @@ -22,9 +22,12 @@ topk=(1, 5), )) +checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 + # teacher settings teacher = dict( type='mmcls.ImageClassifier', + init_cfg=dict(type='Pretrained', checkpoint=checkpoint), backbone=dict( type='ResNet', depth=34, diff --git a/configs/nas/darts/README.md b/configs/nas/darts/README.md index 3eb4d38d3..bf28d670f 100644 --- a/configs/nas/darts/README.md +++ b/configs/nas/darts/README.md @@ -22,8 +22,8 @@ Dataset|Unroll|Config|Download| Dataset|Params(M)|Flops(G)|Top-1 Acc|Top-5 Acc|Subnet|Config|Download|Remarks| |:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:------:|:------:|:------:| -|Cifar10|3.42 | 0.48 | 97.32 |99.94|[mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_mutable_cfg.yaml?versionId=CAEQHxiBgMDn0ICL7xciIDAwNzUzZTU3ZjE4OTQ0MDg5YmZiMmYzYzExZTQ3YTRm)|[config](./darts_subnetnet_1xb96_cifar10.py)| [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921.pth?versionId=CAEQHxiBgID20ICL7xciIDllOWZmNTliMzkwNzQ5YzdhODk2MzY1MWEyOTQ1Yjlk) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_20211222-e5727921.log.json?versionId=CAEQHxiBgMDz0ICL7xciIGRhMjk0NDU0OTVhZjQwMDg4N2ZkMDAzZDM1ZWU4N2Ri)|MMRazor searched -|Cifar10|3.83 | 0.55 | 97.27 |99.98|[mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.27_20211222-17e42600_mutable_cfg.yaml?versionId=CAEQHxiBgICrnpmL7xciIGFmYzUxYjdmYWM1YzQ3N2I5NGU1MDE2ZjIxYmJhY2E0)|[config](./darts_subnetnet_1xb96_cifar10.py)| [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.27_20211222-17e42600.pth?versionId=CAEQHxiBgIDQnpmL7xciIGQzOTRkMTViMDgzNzQ2MWI5MmUyNzIxZDk4OTUzZDgz) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_20211222-17e42600.log.json?versionId=CAEQHxiBgMDPnpmL7xciIDViYTVlYTIyYmQ2OTQ1ZDZhNTNhMjVkODA2NDRlMTI1)|official +|Cifar10|3.42 | 0.48 | 97.32 |99.94|[mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_mutable_cfg.yaml)|[config](./darts_subnetnet_1xb96_cifar10.py)| [model](https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_20211222-e5727921.log.json)|MMRazor searched +|Cifar10|3.83 | 0.55 | 97.27 |99.98|[mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.27_20211222-17e42600_mutable_cfg.yaml)|[config](./darts_subnetnet_1xb96_cifar10.py)| [model](https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.27_20211222-17e42600.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_20211222-17e42600.log.json)|official ## Citation diff --git a/configs/nas/darts/darts_subnetnet_1xb96_cifar10.py b/configs/nas/darts/darts_subnet_1xb96_cifar10.py similarity index 90% rename from configs/nas/darts/darts_subnetnet_1xb96_cifar10.py rename to configs/nas/darts/darts_subnet_1xb96_cifar10.py index 8134ad59f..1e979578f 100644 --- a/configs/nas/darts/darts_subnetnet_1xb96_cifar10.py +++ b/configs/nas/darts/darts_subnet_1xb96_cifar10.py @@ -28,6 +28,9 @@ cal_acc=True), ) +# FIXME: you may replace this with the mutable_cfg searched by yourself +mutable_cfg = 'https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_mutable_cfg.yaml' # noqa: E501 + algorithm = dict( type='Darts', architecture=dict(type='MMClsArchitecture', model=model), @@ -69,7 +72,8 @@ )), ), retraining=True, - unroll=False) + unroll=False, + mutable_cfg=mutable_cfg) data = dict(workers_per_gpu=8) diff --git a/configs/nas/detnas/README.md b/configs/nas/detnas/README.md index 60fb57a91..4a66866c2 100644 --- a/configs/nas/detnas/README.md +++ b/configs/nas/detnas/README.md @@ -38,7 +38,7 @@ python ./tools/mmdet/search_mmdet.py \ python ./tools/mmcls/train_mmcls.py \ configs/nas/detnas/detnas_subnet_shufflenetv2_8xb128_in1k.py \ --work-dir $WORK_DIR \ - --cfg-options algorithm.mutable_cfg=$STEP3_SUBNET_YAML + --cfg-options algorithm.mutable_cfg=$STEP3_SUBNET_YAML # or modify the config directly ``` ### Step 5: Subnet fine-tuning on COCO @@ -46,13 +46,13 @@ python ./tools/mmcls/train_mmcls.py \ python ./tools/mmdet/train_mmdet.py \ configs/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py \ --work-dir $WORK_DIR \ - --cfg-options algorithm.mutable_cfg=$STEP3_SUBNET_YAML load_from=$STEP4_CKPT + --cfg-options algorithm.mutable_cfg=$STEP3_SUBNET_YAML load_from=$STEP4_CKPT # or modify the config directly ``` ## Results and models |Dataset| Supernet | Subnet |Params(M)| Flops(G) | mAP | Config | Download | Remarks| |:---------------:|:---------------:|:-----------:|:-----------:|:-----------:|:--------------:|:------:|:--------:|:--------:| -|COCO| FRCNN-ShuffleNetV2| [mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml?versionId=CAEQHxiBgMDU3taI7xciIDUzMmM4MTg4YTgwZDRhYjY4NjA3M2NkZDA0NWExNmY1) | 3.35(backbone)|0.34(backbone) | 37.5 |[config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py)|[pretrain](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth?versionId=CAEQHxiBgICBxuuL7xciIGEyNzZkZmRmZmM5NzRjNDViOTNjOWZkNjk0OWYyYTdm) |[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth?versionId=CAEQHxiBgIDd3taI7xciIDIxYmUzMDE4ZmZmMjQ4ZGNiNzI1YjcxOGM4OGM5NDZl) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.log.json?versionId=CAEQHxiBgMCSq9mM7xciIDViODRmMDE1Yjk1MDQwMTViMDBmYzZlMjg0OTJjYTlh)|MMRazor searched +|COCO| FRCNN-ShuffleNetV2| [mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml) | 3.35(backbone)|0.34(backbone) | 37.5 |[config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py)|[pretrain](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth) |[model](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.log.json)|MMRazor searched **Note**: diff --git a/configs/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py b/configs/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py index be40b5535..dc929cc88 100644 --- a/configs/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py +++ b/configs/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py @@ -1,3 +1,6 @@ _base_ = ['./detnas_supernet_frcnn_shufflenetv2_fpn_1x_coco.py'] -algorithm = dict(retraining=True) +# FIXME: you may replace this with the mutable_cfg searched by yourself +mutable_cfg = 'https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml' # noqa: E501 + +algorithm = dict(retraining=True, mutable_cfg=mutable_cfg) diff --git a/configs/nas/detnas/detnas_subnet_shufflenetv2_8xb128_in1k.py b/configs/nas/detnas/detnas_subnet_shufflenetv2_8xb128_in1k.py index 3e373ec08..9486cba6f 100644 --- a/configs/nas/detnas/detnas_subnet_shufflenetv2_8xb128_in1k.py +++ b/configs/nas/detnas/detnas_subnet_shufflenetv2_8xb128_in1k.py @@ -1,3 +1,8 @@ _base_ = [ '../spos/spos_subnet_shufflenetv2_8xb128_in1k.py', ] + +# FIXME: you may replace this with the mutable_cfg searched by yourself +mutable_cfg = 'https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml' # noqa: E501 + +algorithm = dict(mutable_cfg=mutable_cfg) diff --git a/configs/nas/spos/README.md b/configs/nas/spos/README.md index 68075cf88..a21272fdb 100644 --- a/configs/nas/spos/README.md +++ b/configs/nas/spos/README.md @@ -36,14 +36,14 @@ python ./tools/mmcls/search_mmcls.py \ python ./tools/mmcls/train_mmcls.py \ configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py \ --work-dir $WORK_DIR \ - --cfg-options algorithm.mutable_cfg=$STEP2_SUBNET_YAML + --cfg-options algorithm.mutable_cfg=$STEP2_SUBNET_YAML # or modify the config directly ``` ## Results and models | Dataset | Supernet | Subnet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Remarks | | :------: |:----------------------:| :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------: | :------: | :-------: | :-------: | :----------------------------------------------: |:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------:| -| ImageNet | ShuffleNetV2 | [mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml?versionId=CAEQHxiBgICw5b6I7xciIGY5MjVmNWFhY2U5MjQzN2M4NDViYzI2YWRmYWE1YzQx) | 3.35 | 0.33 | 73.87 | 91.6 | [config](./spos_subnet_shufflenetv2_8xb128_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.pth?versionId=CAEQHxiBgIDK5b6I7xciIDM1YjIwZjQxN2UyMDRjYjA5YTM5NTBlMGNhMTdkNjI2) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.log.json?versionId=CAEQHxiBgIDr9cuL7xciIDBmOTZiZGUyYjRiMDQ5NzhhZjY0NWUxYmUzNDlmNTg5) | MMRazor searched | +| ImageNet | ShuffleNetV2 | [mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml) | 3.35 | 0.33 | 73.87 | 91.6 | [config](./spos_subnet_shufflenetv2_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.log.json) | MMRazor searched | | ImageNet | MobileNet-ProxylessGPU | [mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/spos/spos_mobilenet_subnet/spos_angelnas_flops_0.49G_acc_75.98_20220307-54f4698f_mutable_cfg.yaml) | 5.94 | 0.49* | 75.98 | 92.77 | [config](./spos_mobilenet_for_check_ckpt_from_anglenas.py) | | [AngleNAS](https://github.com/megvii-model/AngleNAS) searched | **Note**: diff --git a/configs/nas/spos/spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k.py b/configs/nas/spos/spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k.py index 198a9c053..54d5ff5fb 100644 --- a/configs/nas/spos/spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k.py +++ b/configs/nas/spos/spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k.py @@ -2,7 +2,10 @@ './spos_supernet_mobilenet_proxyless_gpu_8xb128_in1k.py', ] -algorithm = dict(retraining=True) +# FIXME: you may replace this with the mutable_cfg searched by yourself +mutable_cfg = 'https://download.openmmlab.com/mmrazor/v0.1/nas/spos/spos_mobilenet_subnet/spos_angelnas_flops_0.49G_acc_75.98_20220307-54f4698f_mutable_cfg.yaml' # noqa: E501 + +algorithm = dict(retraining=True, mutable_cfg=mutable_cfg) evaluation = dict(interval=10000, metric='accuracy') checkpoint_config = dict(interval=30000) diff --git a/configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py b/configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py index e849579fe..110ee047b 100644 --- a/configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py +++ b/configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py @@ -2,7 +2,10 @@ './spos_supernet_shufflenetv2_8xb128_in1k.py', ] -algorithm = dict(retraining=True) +# FIXME: you may replace this with the mutable_cfg searched by yourself +mutable_cfg = 'https://download.openmmlab.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml' # noqa: E501 + +algorithm = dict(retraining=True, mutable_cfg=mutable_cfg) runner = dict(max_iters=300000) find_unused_parameters = False diff --git a/configs/pruning/autoslim/README.md b/configs/pruning/autoslim/README.md index 84d3b7e5f..fdd15d073 100644 --- a/configs/pruning/autoslim/README.md +++ b/configs/pruning/autoslim/README.md @@ -48,15 +48,15 @@ python ./tools/model_converters/split_checkpoint.py \ python ./tools/mmcls/test_mmcls.py \ configs/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k.py \ your_splitted_checkpoint_path --metrics accuracy \ - --cfg-options algorithm.channel_cfg=configs/pruning/autoslim/AUTOSLIM_MBV2_530M_OFFICIAL.yaml + --cfg-options algorithm.channel_cfg=configs/pruning/autoslim/AUTOSLIM_MBV2_530M_OFFICIAL.yaml # or modify the config directly ## Results and models ### Subnet retrain | Supernet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Subnet | Remark | | :----------------- | :-------: | -------: | :-------: | :-------: | :----: | :------: | :-------------: | :----: | -| MobileNet v2(x1.5) | 6.5 | 0.53 | 74.23 | 91.74 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.53M_acc-74.23_20211222-e5208bbd.pth?versionId=CAEQHxiBgICYsIaI7xciIDE1MGIxM2Q5NDk1NjRlOTFiMjgwOTRmYzJlMDBmZDY0) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json?versionId=CAEQHxiBgMCjj9SL7xciIDFmYmM4NTExZmIzNjRmNmQ4MmMyZWI4YzJmMmM2MDdl) | [channel](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.53M_acc-74.23_20211222-e5208bbd_channel_cfg.yaml?versionId=CAEQHxiBgMDwr4aI7xciIDQ2MmRhMDFhNGMyODQyYmU5ZTIyOTcxMmRlN2RmYjg2) | official channel cfg | -| MobileNet v2(x1.5) | 5.77 | 0.32 | 72.73 | 90.83 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.32M_acc-72.73_20211222-b5b0b33c.pth?versionId=CAEQHxiBgMCasIaI7xciIDEzN2FkZjZkNWMwYjRiOTg5NTY0MzY0ODk5ODE2N2Yz) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json?versionId=CAEQHxiBgMCjj9SL7xciIDFmYmM4NTExZmIzNjRmNmQ4MmMyZWI4YzJmMmM2MDdl) | [channel](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.32M_acc-72.73_20211222-b5b0b33c_channel_cfg.yaml?versionId=CAEQHxiCgMDwr4aI7xciIDhjMmUzZjlmZTJjODQzMDRhMmQxMzkyM2MwOTZhNjE3) | official channel cfg | -| MobileNet v2(x1.5) | 4.13 |0.22 | 71.39 | 90.08 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.22M_acc-71.39_20211222-43117c7b.pth?versionId=CAEQHxiBgICRsIaI7xciIDVlY2MxMTkwZjg0ODQ3M2I5NTJmYjFiNDk1MDEwNjAy) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json?versionId=CAEQHxiBgMCjj9SL7xciIDFmYmM4NTExZmIzNjRmNmQ4MmMyZWI4YzJmMmM2MDdl) | [channel](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.22M_acc-71.39_20211222-43117c7b_channel_cfg.yaml.?versionId=CAEQHxiBgIDzr4aI7xciIDViNGY0ZDA1ODkxZTRkMGFhNTg2M2FlZmQyZTFiMDgx) | official channel cfg | +| MobileNet v2(x1.5) | 6.5 | 0.53 | 74.23 | 91.74 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.53M_acc-74.23_20211222-e5208bbd.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.53M_acc-74.23_20211222-e5208bbd_channel_cfg.yaml) | official channel cfg | +| MobileNet v2(x1.5) | 5.77 | 0.32 | 72.73 | 90.83 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.32M_acc-72.73_20211222-b5b0b33c.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.32M_acc-72.73_20211222-b5b0b33c_channel_cfg.yaml) | official channel cfg | +| MobileNet v2(x1.5) | 4.13 |0.22 | 71.39 | 90.08 | [config](./autoslim_mbv2_subnet_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.22M_acc-71.39_20211222-43117c7b.pth) | [log](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1kautoslim_mbv2_subnet_8xb256_in1k_paper_channel_cfg.log.json) | [channel](https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.22M_acc-71.39_20211222-43117c7b_channel_cfg.yaml) | official channel cfg | Note that we ran the official code and the Top-1 Acc of the models with official channel cfg are 73.8%, 72.5% and 71.1%. And there are 3 differences between our diff --git a/configs/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k.py b/configs/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k.py index afc974eb3..00d2f841a 100644 --- a/configs/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k.py +++ b/configs/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k.py @@ -10,12 +10,19 @@ label_smooth_val=0.1, loss_weight=1.0))) +# FIXME: you may replace this with the channel_cfg searched by yourself +channel_cfg = [ + 'https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.53M_acc-74.23_20211222-e5208bbd_channel_cfg.yaml', # noqa: E501 + 'https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.32M_acc-72.73_20211222-b5b0b33c_channel_cfg.yaml', # noqa: E501 + 'https://download.openmmlab.com/mmrazor/v0.1/pruning/autoslim/autoslim_mbv2_subnet_8xb256_in1k/autoslim_mbv2_subnet_8xb256_in1k_flops-0.22M_acc-71.39_20211222-43117c7b_channel_cfg.yaml' # noqa: E501 +] + algorithm = dict( architecture=dict(type='MMClsArchitecture', model=model), distiller=None, retraining=True, bn_training_mode=False, -) + channel_cfg=channel_cfg) runner = dict(type='EpochBasedRunner', max_epochs=300) diff --git a/docs/en/conf.py b/docs/en/conf.py index 8b4db443a..8e69fcc18 100644 --- a/docs/en/conf.py +++ b/docs/en/conf.py @@ -41,12 +41,8 @@ def get_version(): # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'recommonmark', - 'sphinx_markdown_tables', - 'sphinx_copybutton', + 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', + 'sphinx_markdown_tables', 'sphinx_copybutton', 'myst_parser' ] autodoc_mock_imports = [ diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index 9c947ddca..e2ad66278 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -4,24 +4,24 @@ ### CWD -Please refer to [CWD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/cwd/README.md) for details. +Please refer to [CWD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/cwd) for details. ### WSLD -Please refer to [WSLD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/wsld/README.md) for details. +Please refer to [WSLD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/wsld) for details. ### DARTS -Please refer to [DARTS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/darts/README.md) for details. +Please refer to [DARTS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/darts) for details. ### DETNAS -Please refer to [DETNAS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/detnas/README.md) for details. +Please refer to [DETNAS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/detnas) for details. ### SPOS -Please refer to [SPOS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md) for details. +Please refer to [SPOS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos) for details. ### AUTOSLIM -Please refer to [AUTOSLIM](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md) for details. +Please refer to [AUTOSLIM](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim) for details. diff --git a/docs/en/train.md b/docs/en/train.md index b186e916c..ad93b8892 100644 --- a/docs/en/train.md +++ b/docs/en/train.md @@ -59,6 +59,13 @@ python ./tools/mmcls/train_mmcls.py \ --cfg-options algorithm.mutable_cfg=configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml +We note that instead of using ``--cfg-options``, you can also directly modify ``configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py`` like this: + +
+mutable_cfg = 'configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml'
+algorithm = dict(..., mutable_cfg=mutable_cfg)
+
+ ## Pruning Pruning has three steps, including **supernet pre-training**, **search for subnet on the trained supernet** and **subnet retraining**. The commands of the first two steps are similar to NAS, except that we need to use `CONFIG_FILE` of Pruning here. The commands of the **subnet retraining** are as follows. @@ -95,7 +102,7 @@ python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.dis For example,
-python ./tools/mmdet/train_mmdet.py \
+python ./tools/mmseg/train_mmseg.py \
   configs/distill/cwd/cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py \
   --work-dir your_work_dir \
   --cfg-options algorithm.distiller.teacher.init_cfg.type=Pretrained algorithm.distiller.teacher.init_cfg.checkpoint=https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes/pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth
diff --git a/mmrazor/apis/mmcls/train.py b/mmrazor/apis/mmcls/train.py
index 91aafc4bf..89e2a5c50 100644
--- a/mmrazor/apis/mmcls/train.py
+++ b/mmrazor/apis/mmcls/train.py
@@ -65,7 +65,24 @@ def train_mmcls_model(model,
         train_dataset = dataset[0]
         dataset[0] = split_dataset(train_dataset)
 
-    sampler_cfg = cfg.data.get('sampler', None)
+    loader_cfg = dict(
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
+        dist=distributed,
+        round_up=True,
+        seed=cfg.get('seed'),
+        sampler_cfg=cfg.get('sampler', None),
+    )
+    # The overall dataloader settings
+    loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
+    # The specific dataloader settings
+    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
 
     # Difference from mmclassification.
     # Build multi dataloaders according the splited datasets.
@@ -73,28 +90,11 @@ def train_mmcls_model(model,
     for dset in dataset:
         if isinstance(dset, list):
             data_loader = [
-                build_dataloader(
-                    item_ds,
-                    cfg.data.samples_per_gpu,
-                    cfg.data.workers_per_gpu,
-                    # cfg.gpus will be ignored if distributed
-                    num_gpus=len(cfg.gpu_ids),
-                    dist=distributed,
-                    round_up=True,
-                    seed=cfg.seed,
-                    sampler_cfg=sampler_cfg) for item_ds in dset
+                build_dataloader(item_ds, **train_loader_cfg)
+                for item_ds in dset
             ]
         else:
-            data_loader = build_dataloader(
-                dset,
-                cfg.data.samples_per_gpu,
-                cfg.data.workers_per_gpu,
-                # cfg.gpus will be ignored if distributed
-                num_gpus=len(cfg.gpu_ids),
-                dist=distributed,
-                round_up=True,
-                seed=cfg.seed,
-                sampler_cfg=sampler_cfg)
+            data_loader = build_dataloader(dset, **train_loader_cfg)
 
         data_loaders.append(data_loader)
 
@@ -188,13 +188,13 @@ def train_mmcls_model(model,
     # register eval hooks
     if validate:
         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
-        val_dataloader = build_dataloader(
-            val_dataset,
-            samples_per_gpu=cfg.data.samples_per_gpu,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            dist=distributed,
-            shuffle=False,
-            round_up=True)
+        val_loader_cfg = {
+            **loader_cfg,
+            'shuffle': False,  # Not shuffle by default
+            'sampler_cfg': None,  # Not use sampler by default
+            **cfg.data.get('val_dataloader', {}),
+        }
+        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
         eval_cfg = cfg.get('evaluation', {})
 
         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
diff --git a/mmrazor/apis/mmseg/train.py b/mmrazor/apis/mmseg/train.py
index 16de892f9..ed09df78e 100644
--- a/mmrazor/apis/mmseg/train.py
+++ b/mmrazor/apis/mmseg/train.py
@@ -57,17 +57,25 @@ def train_mmseg_model(model,
 
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.samples_per_gpu,
-            cfg.data.workers_per_gpu,
-            # cfg.gpus will be ignored if distributed
-            len(cfg.gpu_ids),
-            dist=distributed,
-            seed=cfg.seed,
-            drop_last=True) for ds in dataset
-    ]
+    # The default loader config
+    loader_cfg = dict(
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
+        dist=distributed,
+        seed=cfg.seed,
+        drop_last=True)
+    # The overall dataloader settings
+    loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
+
+    # The specific dataloader settings
+    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
+    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
 
     # put model on gpus
     if distributed:
@@ -130,12 +138,14 @@ def train_mmseg_model(model,
     # register eval hooks
     if validate:
         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
-        val_dataloader = build_dataloader(
-            val_dataset,
-            samples_per_gpu=1,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            dist=distributed,
-            shuffle=False)
+        # The specific dataloader settings
+        val_loader_cfg = {
+            **loader_cfg,
+            'samples_per_gpu': 1,
+            'shuffle': False,  # Not shuffle by default
+            **cfg.data.get('val_dataloader', {}),
+        }
+        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
         eval_cfg = cfg.get('evaluation', {})
         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
         eval_hook = DistEvalHook if distributed else EvalHook
diff --git a/mmrazor/core/searcher/evolution_search.py b/mmrazor/core/searcher/evolution_search.py
index 43cfe558a..a4fdee41c 100644
--- a/mmrazor/core/searcher/evolution_search.py
+++ b/mmrazor/core/searcher/evolution_search.py
@@ -135,14 +135,11 @@ def search(self):
 
                     if self.check_constraints():
                         self.candidate_pool.append(candidate)
-
-                broadcast_candidate_pool = self.candidate_pool
             else:
-                broadcast_candidate_pool = [None] * self.candidate_pool_size
-            broadcast_candidate_pool = broadcast_object_list(
-                broadcast_candidate_pool)
+                self.candidate_pool = [None] * self.candidate_pool_size
+            broadcast_object_list(self.candidate_pool)
 
-            for i, candidate in enumerate(broadcast_candidate_pool):
+            for i, candidate in enumerate(self.candidate_pool):
                 self.algorithm.mutator.set_subnet(candidate)
                 outputs = self.test_fn(self.algorithm_for_test,
                                        self.dataloader)
@@ -213,7 +210,7 @@ def search(self):
                 self.logger.info(
                     f'Epoch:[{epoch + 1}/{self.max_epoch}], top1_score: '
                     f'{list(self.top_k_candidates_with_score.keys())[0]}')
-            self.candidate_pool = broadcast_object_list(self.candidate_pool)
+            broadcast_object_list(self.candidate_pool)
 
         if rank == 0:
             final_subnet_dict = list(
diff --git a/mmrazor/core/searcher/greedy_search.py b/mmrazor/core/searcher/greedy_search.py
index 2d1410389..215d0fffe 100644
--- a/mmrazor/core/searcher/greedy_search.py
+++ b/mmrazor/core/searcher/greedy_search.py
@@ -146,7 +146,7 @@ def search(self):
 
                     # Broadcasts scores in broadcast_scores to the whole
                     # group.
-                    broadcast_scores = broadcast_object_list(broadcast_scores)
+                    broadcast_object_list(broadcast_scores)
                     score = broadcast_scores[0]
                     self.logger.info(
                         f'Slimming group {name}, {self.score_key}: {score}')
diff --git a/mmrazor/core/utils/__init__.py b/mmrazor/core/utils/__init__.py
index 415b5418b..9267327cc 100644
--- a/mmrazor/core/utils/__init__.py
+++ b/mmrazor/core/utils/__init__.py
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .broadcast import broadcast_object_list
 from .lr import set_lr
+from .utils import get_backend, get_default_group, get_rank, get_world_size
 
-__all__ = ['broadcast_object_list', 'set_lr']
+__all__ = [
+    'broadcast_object_list', 'set_lr', 'get_world_size', 'get_rank',
+    'get_backend', 'get_default_group'
+]
diff --git a/mmrazor/core/utils/broadcast.py b/mmrazor/core/utils/broadcast.py
index 6326311c0..41c2998b4 100644
--- a/mmrazor/core/utils/broadcast.py
+++ b/mmrazor/core/utils/broadcast.py
@@ -1,49 +1,155 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-import shutil
-import tempfile
+import pickle
+import warnings
+from typing import Any, List, Optional, Tuple
 
-import mmcv.fileio
 import torch
-import torch.distributed as dist
-from mmcv.runner import get_dist_info
+from mmcv.utils import TORCH_VERSION, digit_version
+from torch import Tensor
+from torch import distributed as dist
 
+from .utils import get_backend, get_default_group, get_rank, get_world_size
 
-def broadcast_object_list(object_list, src=0):
-    """Broadcasts picklable objects in ``object_list`` to the whole group.
 
-    Note that all objects in ``object_list`` must be picklable in order to be
-    broadcasted.
+def _object_to_tensor(obj: Any) -> Tuple[Tensor, Tensor]:
+    """Serialize picklable python object to tensor."""
+    byte_storage = torch.ByteStorage.from_buffer(pickle.dumps(obj))
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor
+    # and specifying dtype. Otherwise, it will cause 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage)
+    local_size = torch.LongTensor([byte_tensor.numel()])
+    return byte_tensor, local_size
 
-    Args:
-        object_list (List[Any]): List of input objects to broadcast.
-            Each object must be picklable. Only objects on the src rank will be
-            broadcast, but each rank must provide lists of equal sizes.
-        src (int): Source rank from which to broadcast ``object_list``.
+
+def _tensor_to_object(tensor: Tensor, tensor_size: int) -> Any:
+    """Deserialize tensor to picklable python object."""
+    buf = tensor.cpu().numpy().tobytes()[:tensor_size]
+    return pickle.loads(buf)
+
+
+def _broadcast_object_list(object_list: List[Any],
+                           src: int = 0,
+                           group: Optional[dist.ProcessGroup] = None) -> None:
+    """Broadcast picklable objects in ``object_list`` to the whole group.
+
+    Similar to :func:`broadcast`, but Python objects can be passed in. Note
+    that all objects in ``object_list`` must be picklable in order to be
+    broadcasted.
     """
-    my_rank, _ = get_dist_info()
+    if dist.distributed_c10d._rank_not_in_group(group):
+        return
 
-    MAX_LEN = 512
-    # 32 is whitespace
-    dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8, device='cuda')
-    object_list_return = list()
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
-        mmcv.mkdir_or_exist('.dist_broadcast')
-        tmpdir = tempfile.mkdtemp(dir='.dist_broadcast')
-        mmcv.dump(object_list, osp.join(tmpdir, 'object_list.pkl'))
-        tmpdir = torch.tensor(
-            bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
-        dir_tensor[:len(tmpdir)] = tmpdir
+        tensor_list, size_list = zip(
+            *[_object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
 
-    dist.broadcast(dir_tensor, src)
-    tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` is ``None`` by default.
+    # in which case we run current logic of device selection, i.e.
+    # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In
+    # the case it is not ``None`` we move the size and object tensors to be
+    # broadcasted to this device.
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == dist.Backend.NCCL
+    current_device = torch.device('cpu')
+    if is_nccl_backend:
+        # See note about using torch.cuda.current_device() here in
+        # docstring. We cannot simply use my_rank since rank == device is
+        # not necessarily true.
+        current_device = torch.device('cuda', torch.cuda.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
 
-    if my_rank != src:
-        object_list_return = mmcv.load(osp.join(tmpdir, 'object_list.pkl'))
+    # Broadcast object sizes
+    dist.broadcast(object_sizes_tensor, src=src, group=group)
 
-    dist.barrier()
+    # Concatenate and broadcast serialized object tensors
     if my_rank == src:
-        shutil.rmtree(tmpdir)
-        object_list_return = object_list
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.empty(
+            torch.sum(object_sizes_tensor).int().item(),
+            dtype=torch.uint8,
+        )
+
+    if is_nccl_backend:
+        object_tensor = object_tensor.to(current_device)
+    dist.broadcast(object_tensor, src=src, group=group)
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset:offset + obj_size]
+            obj_view = obj_view.type(torch.uint8)
+            if obj_view.device != torch.device('cpu'):
+                obj_view = obj_view.cpu()
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
+
+
+def broadcast_object_list(data: List[Any],
+                          src: int = 0,
+                          group: Optional[dist.ProcessGroup] = None) -> None:
+    """Broadcasts picklable objects in ``object_list`` to the whole group.
+    Similar to :func:`broadcast`, but Python objects can be passed in. Note
+    that all objects in ``object_list`` must be picklable in order to be
+    broadcasted.
+    Note:
+        Calling ``broadcast_object_list`` in non-distributed environment does
+        nothing.
+    Args:
+        data (List[Any]): List of input objects to broadcast.
+            Each object must be picklable. Only objects on the ``src`` rank
+            will be broadcast, but each rank must provide lists of equal sizes.
+        src (int): Source rank from which to broadcast ``object_list``.
+        group: (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+        device (``torch.device``, optional): If not None, the objects are
+            serialized and converted to tensors which are moved to the
+            ``device`` before broadcasting. Default is ``None``.
+    Note:
+        For NCCL-based process groups, internal tensor representations of
+        objects must be moved to the GPU device before communication starts.
+        In this case, the used device is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is correctly set so that each rank has an individual
+        GPU, via ``torch.cuda.set_device()``.
+    Examples:
+        >>> import torch
+        >>> import mmrazor.core.utils as dist
+        >>> # non-distributed environment
+        >>> data = ['foo', 12, {1: 2}]
+        >>> dist.broadcast_object_list(data)
+        >>> data
+        ['foo', 12, {1: 2}]
+        >>> # distributed environment
+        >>> # We have 2 process groups, 2 ranks.
+        >>> if dist.get_rank() == 0:
+        >>>     # Assumes world_size of 3.
+        >>>     data = ["foo", 12, {1: 2}]  # any picklable object
+        >>> else:
+        >>>     data = [None, None, None]
+        >>> dist.broadcast_object_list(data)
+        >>> data
+        ["foo", 12, {1: 2}]  # Rank 0
+        ["foo", 12, {1: 2}]  # Rank 1
+    """
+    warnings.warn(
+        '`broadcast_object_list` is now without return value, '
+        'and it\'s input parameters are: `data`,`src` and '
+        '`group`, but its function is similar to the old\'s', UserWarning)
+    assert isinstance(data, list)
+
+    if get_world_size(group) > 1:
+        if group is None:
+            group = get_default_group()
 
-    return object_list_return
+        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+            dist.broadcast_object_list(data, src, group)
+        else:
+            _broadcast_object_list(data, src, group)
diff --git a/mmrazor/core/utils/utils.py b/mmrazor/core/utils/utils.py
new file mode 100644
index 000000000..58b87d2e5
--- /dev/null
+++ b/mmrazor/core/utils/utils.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from torch import distributed as dist
+
+
+def is_distributed() -> bool:
+    """Return True if distributed environment has been initialized."""
+    return dist.is_available() and dist.is_initialized()
+
+
+def get_default_group() -> Optional[dist.ProcessGroup]:
+    """Return default process group."""
+
+    return dist.distributed_c10d._get_default_group()
+
+
+def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
+    """Return the rank of the given process group.
+
+    Rank is a unique identifier assigned to each process within a distributed
+    process group. They are always consecutive integers ranging from 0 to
+    ``world_size``.
+    Note:
+        Calling ``get_rank`` in non-distributed environment will return 0.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+    Returns:
+        int: Return the rank of the process group if in distributed
+        environment, otherwise 0.
+    """
+
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which does not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_rank(group)
+    else:
+        return 0
+
+
+def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]:
+    """Return the backend of the given process group.
+
+    Note:
+        Calling ``get_backend`` in non-distributed environment will return
+        None.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. The
+            default is the general main process group. If another specific
+            group is specified, the calling process must be part of
+            :attr:`group`. Defaults to None.
+    Returns:
+        str or None: Return the backend of the given process group as a lower
+        case string if in distributed environment, otherwise None.
+    """
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which does not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_backend(group)
+    else:
+        return None
+
+
+def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
+    """Return the number of the given process group.
+
+    Note:
+        Calling ``get_world_size`` in non-distributed environment will return
+        1.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+    Returns:
+        int: Return the number of processes of the given process group if in
+        distributed environment, otherwise 1.
+    """
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which does not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_world_size(group)
+    else:
+        return 1
diff --git a/mmrazor/models/pruners/ratio_pruning.py b/mmrazor/models/pruners/ratio_pruning.py
index 218b142d3..151677d99 100644
--- a/mmrazor/models/pruners/ratio_pruning.py
+++ b/mmrazor/models/pruners/ratio_pruning.py
@@ -2,6 +2,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn.modules import GroupNorm
 
 from mmrazor.models.builder import PRUNERS
 from .structure_pruning import StructurePruner
@@ -31,6 +32,22 @@ def __init__(self, ratios, **kwargs):
         self.ratios = ratios
         self.min_ratio = ratios[0]
 
+    def _check_pruner(self, supernet):
+        for module in supernet.model.modules():
+            if isinstance(module, GroupNorm):
+                num_channels = module.num_channels
+                num_groups = module.num_groups
+                for ratio in self.ratios:
+                    new_channels = int(round(num_channels * ratio))
+                    assert (num_channels * ratio) % num_groups == 0, \
+                        f'Expected number of channels in input of GroupNorm ' \
+                        f'to be divisible by num_groups, but number of ' \
+                        f'channels may be {new_channels} according to ' \
+                        f'ratio {ratio} and num_groups={num_groups}'
+
+    def prepare_from_supernet(self, supernet):
+        super(RatioPruner, self).prepare_from_supernet(supernet)
+
     def get_channel_mask(self, out_mask):
         """Randomly choose a width ratio of a layer from ``ratios``"""
         out_channels = out_mask.size(1)
diff --git a/mmrazor/models/pruners/structure_pruning.py b/mmrazor/models/pruners/structure_pruning.py
index af1eae2b0..b755c2e5d 100644
--- a/mmrazor/models/pruners/structure_pruning.py
+++ b/mmrazor/models/pruners/structure_pruning.py
@@ -6,9 +6,12 @@
 
 import torch
 import torch.nn as nn
+from mmcv import digit_version
 from mmcv.runner import BaseModule
 from ordered_set import OrderedSet
+from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.instancenorm import _InstanceNorm
 
 from mmrazor.models.builder import PRUNERS
 from .utils import SwitchableBatchNorm2d
@@ -19,14 +22,13 @@
 FC = ('ThAddmmBackward', 'AddmmBackward', 'MmBackward')
 BN = ('ThnnBatchNormBackward', 'CudnnBatchNormBackward',
       'NativeBatchNormBackward')
+GN = ('NativeGroupNormBackward', )
 CONCAT = ('CatBackward', )
 # the modules which contains NON_PASS grad_fn need to change the parameter size
 # according to channels after pruning
 NON_PASS = CONV + FC
-NON_PASS_MODULE = (nn.Conv2d, nn.Linear)
-
-PASS = BN
-PASS_MODULE = (_BatchNorm)
+PASS = BN + GN
+NORM = BN + GN
 
 BACKWARD_PARSER_DICT = dict()
 MAKE_GROUP_PARSER_DICT = dict()
@@ -122,6 +124,12 @@ def prepare_from_supernet(self, supernet):
         tmp_shared_module_hook_handles = list()
 
         for name, module in supernet.model.named_modules():
+            if isinstance(module, nn.GroupNorm):
+                min_required_version = '1.6.0'
+                assert digit_version(torch.__version__) >= digit_version(
+                    min_required_version
+                ), f'Requires pytorch>={min_required_version} to auto-trace' \
+                   f'GroupNorm correctly.'
             if hasattr(module, 'weight'):
                 # trace shared modules
                 module.cnt = 0
@@ -172,10 +180,10 @@ def prepare_from_supernet(self, supernet):
         self.trace_non_pass_path(pseudo_loss.grad_fn, module2name, var2module,
                                  cur_non_pass_path, non_pass_paths, visited)
 
-        bn_conv_links = dict()
-        self.trace_bn_conv_links(pseudo_loss.grad_fn, module2name, var2module,
-                                 bn_conv_links, visited)
-        self.bn_conv_links = bn_conv_links
+        norm_conv_links = dict()
+        self.trace_norm_conv_links(pseudo_loss.grad_fn, module2name,
+                                   var2module, norm_conv_links, visited)
+        self.norm_conv_links = norm_conv_links
 
         # a node can be the name of a conv module or a str like 'concat_{id}'
         node2parents = self.find_node_parents(non_pass_paths)
@@ -268,12 +276,12 @@ def set_subnet(self, subnet_dict):
             module = self.name2module[module_name]
             module.out_mask = subnet_dict[space_id].to(module.out_mask.device)
 
-        for bn, conv in self.bn_conv_links.items():
-            module = self.name2module[bn]
+        for norm, conv in self.norm_conv_links.items():
+            module = self.name2module[norm]
             conv_space_id = self.get_space_id(conv)
             # conv_space_id is None means the conv layer in front of
-            # this bn module can not be pruned. So we should not set
-            # the out_mask of this bn layer
+            # this normalization module can not be pruned. So we should not set
+            # the out_mask of this normalization layer
             if conv_space_id is not None:
                 module.out_mask = subnet_dict[conv_space_id].to(
                     module.out_mask.device)
@@ -458,7 +466,9 @@ def add_pruning_attrs(self, module):
             module.register_buffer(
                 'out_mask', module.weight.new_ones((1, module.out_features), ))
             module.forward = self.modify_fc_forward(module)
-        if isinstance(module, nn.modules.batchnorm._BatchNorm):
+        if (isinstance(module, _BatchNorm)
+                or isinstance(module, _InstanceNorm)
+                or isinstance(module, GroupNorm)):
             module.register_buffer(
                 'out_mask',
                 module.weight.new_ones((1, len(module.weight), 1, 1), ))
@@ -625,15 +635,18 @@ def trace_non_pass_path(self, grad_fn, module2name, var2module, cur_path,
         else:
             result_paths.append(copy.deepcopy(cur_path))
 
-    def trace_bn_conv_links(self, grad_fn, module2name, var2module,
-                            bn_conv_links, visited):
-        """Get the convolutional layer placed before a bn layer in the model.
+    def trace_norm_conv_links(self, grad_fn, module2name, var2module,
+                              norm_conv_links, visited):
+        """Get the convolutional layer placed before a normalization layer in
+        the model.
 
         Example:
             >>> conv = nn.Conv2d(3, 3, 3)
-            >>> bn = nn.BatchNorm2d(3)
+            >>> norm = nn.BatchNorm2d(3)
             >>> pseudo_img = torch.rand(1, 3, 224, 224)
-            >>> out = bn(conv(pseudo_img))
+            >>> out = norm(conv(pseudo_img))
+            >>> print(out.grad_fn)
+            
             >>> print(out.grad_fn.next_functions)
             ((, 0),
             (, 0),
@@ -641,23 +654,60 @@ def trace_bn_conv_links(self, grad_fn, module2name, var2module,
             >>> # op.next_functions[0][0] is ThnnConv2DBackward means
             >>> # the parent of this NativeBatchNormBackward op is
             >>> # ThnnConv2DBackward
-            >>> # op.next_functions[1][0].variable is the weight of this bn
-            >>> # module
-            >>> # op.next_functions[2][0].variable is the bias of this bn
-            >>> # module
+            >>> # op.next_functions[1][0].variable is the weight of this
+            >>> # normalization module
+            >>> # op.next_functions[2][0].variable is the bias of this
+            >>> # normalization module
+
+            >>> # Things are different in InstanceNorm
+            >>> conv = nn.Conv2d(3, 3, 3)
+            >>> norm = nn.InstanceNorm2d(3, affine=True)
+            >>> out = norm(conv(pseudo_img))
+            >>> print(out.grad_fn)
+            
+            >>> print(out.grad_fn.next_functions)
+            ((, 0),)
+            >>> print(out.grad_fn.next_functions[0][0].next_functions)
+            ((, 0),
+            (, 0),
+            (, 0))
+            >>> # Hence, a dfs is necessary.
         """
-        grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
-        if grad_fn is not None:
-            is_bn_grad_fn = False
-            for fn_name in BN:
+
+        def is_norm_grad_fn(grad_fn):
+            for fn_name in NORM:
                 if type(grad_fn).__name__.startswith(fn_name):
-                    is_bn_grad_fn = True
-                    break
+                    return True
+            return False
+
+        def is_conv_grad_fn(grad_fn):
+            for fn_name in CONV:
+                if type(grad_fn).__name__.startswith(fn_name):
+                    return True
+            return False
 
-            if is_bn_grad_fn:
+        def is_leaf_grad_fn(grad_fn):
+            if type(grad_fn).__name__ == 'AccumulateGrad':
+                return True
+            return False
+
+        grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
+        if grad_fn is not None:
+            if is_norm_grad_fn(grad_fn):
                 conv_grad_fn = grad_fn.next_functions[0][0]
-                conv_var = conv_grad_fn.next_functions[1][0].variable
-                bn_var = grad_fn.next_functions[1][0].variable
+                while not is_conv_grad_fn(conv_grad_fn):
+                    conv_grad_fn = conv_grad_fn.next_functions[0][0]
+
+                leaf_grad_fn = conv_grad_fn.next_functions[1][0]
+                while not is_leaf_grad_fn(leaf_grad_fn):
+                    leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
+                conv_var = leaf_grad_fn.variable
+
+                leaf_grad_fn = grad_fn.next_functions[1][0]
+                while not is_leaf_grad_fn(leaf_grad_fn):
+                    leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
+                bn_var = leaf_grad_fn.variable
+
                 conv_module = var2module[id(conv_var)]
                 bn_module = var2module[id(bn_var)]
                 conv_name = module2name[conv_module]
@@ -666,20 +716,20 @@ def trace_bn_conv_links(self, grad_fn, module2name, var2module,
                     pass
                 else:
                     visited[bn_name] = True
-                    bn_conv_links[bn_name] = conv_name
+                    norm_conv_links[bn_name] = conv_name
 
-                    self.trace_bn_conv_links(conv_grad_fn, module2name,
-                                             var2module, bn_conv_links,
-                                             visited)
+                    self.trace_norm_conv_links(conv_grad_fn, module2name,
+                                               var2module, norm_conv_links,
+                                               visited)
 
             else:
                 # If the op is AccumulateGrad, parents is (),
                 parents = grad_fn.next_functions
                 if parents is not None:
                     for parent in parents:
-                        self.trace_bn_conv_links(parent, module2name,
-                                                 var2module, bn_conv_links,
-                                                 visited)
+                        self.trace_norm_conv_links(parent, module2name,
+                                                   var2module, norm_conv_links,
+                                                   visited)
 
     def find_backward_parser(self, grad_fn):
         for name, parser in BACKWARD_PARSER_DICT.items():
diff --git a/mmrazor/version.py b/mmrazor/version.py
index 59c76465c..fd746a6c1 100644
--- a/mmrazor/version.py
+++ b/mmrazor/version.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved
 
-__version__ = '0.3.0'
+__version__ = '0.3.1'
 
 
 def parse_version_info(version_str):
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 55ae36bc5..6934d41bd 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -2,7 +2,6 @@ docutils==0.16.0
 m2r
 myst-parser
 git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-recommonmark
 sphinx==4.0.2
 sphinx-copybutton
 sphinx_markdown_tables
diff --git a/tests/test_models/test_pruner.py b/tests/test_models/test_pruner.py
index 835c4894f..12323022a 100644
--- a/tests/test_models/test_pruner.py
+++ b/tests/test_models/test_pruner.py
@@ -3,7 +3,7 @@
 
 import pytest
 import torch
-from mmcv import ConfigDict
+from mmcv import ConfigDict, digit_version
 
 from mmrazor.models.builder import ARCHITECTURES, PRUNERS
 
@@ -86,7 +86,7 @@ def test_ratio_pruner():
     losses = architecture(imgs, return_loss=True, gt_label=label)
     assert losses['loss'].item() > 0
 
-    # test making groups logic when there are shared modules in the model
+    # test models with shared module
     model_cfg = ConfigDict(
         type='mmdet.RetinaNet',
         backbone=dict(
@@ -159,13 +159,127 @@ def test_ratio_pruner():
     pruner = PRUNERS.build(pruner_cfg)
     pruner.prepare_from_supernet(architecture)
     subnet_dict = pruner.sample_subnet()
-    assert isinstance(subnet_dict, dict)
     pruner.set_subnet(subnet_dict)
     subnet_dict = pruner.export_subnet()
-    assert isinstance(subnet_dict, dict)
     pruner.deploy_subnet(architecture, subnet_dict)
     architecture.forward_dummy(imgs)
 
+    # test models with concat operations
+    model_cfg = ConfigDict(
+        type='mmdet.YOLOX',
+        input_size=(640, 640),
+        random_size_range=(15, 25),
+        random_size_interval=10,
+        backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
+        neck=dict(
+            type='YOLOXPAFPN',
+            in_channels=[128, 256, 512],
+            out_channels=128,
+            num_csp_blocks=1),
+        bbox_head=dict(
+            type='YOLOXHead',
+            num_classes=80,
+            in_channels=128,
+            feat_channels=128),
+        train_cfg=dict(
+            assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+        # In order to align the source code, the threshold of the val phase is
+        # 0.01, and the threshold of the test phase is 0.001.
+        test_cfg=dict(
+            score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+
+    architecture_cfg = dict(
+        type='MMDetArchitecture',
+        model=model_cfg,
+    )
+
+    architecture = ARCHITECTURES.build(architecture_cfg)
+    pruner.prepare_from_supernet(architecture)
+    subnet_dict = pruner.sample_subnet()
+    pruner.set_subnet(subnet_dict)
+    subnet_dict = pruner.export_subnet()
+    pruner.deploy_subnet(architecture, subnet_dict)
+    architecture.forward_dummy(imgs)
+
+    # test models with groupnorm
+    model_cfg = ConfigDict(
+        type='mmdet.ATSS',
+        backbone=dict(
+            type='ResNet',
+            depth=50,
+            num_stages=4,
+            out_indices=(0, 1, 2, 3),
+            frozen_stages=1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            norm_eval=True,
+            style='pytorch',
+            init_cfg=dict(
+                type='Pretrained', checkpoint='torchvision://resnet50')),
+        neck=dict(
+            type='FPN',
+            in_channels=[256, 512, 1024, 2048],
+            out_channels=256,
+            start_level=1,
+            add_extra_convs='on_output',
+            num_outs=5),
+        bbox_head=dict(
+            type='ATSSHead',
+            num_classes=80,
+            in_channels=256,
+            stacked_convs=4,
+            feat_channels=256,
+            anchor_generator=dict(
+                type='AnchorGenerator',
+                ratios=[1.0],
+                octave_base_scale=8,
+                scales_per_octave=1,
+                strides=[8, 16, 32, 64, 128]),
+            bbox_coder=dict(
+                type='DeltaXYWHBBoxCoder',
+                target_means=[.0, .0, .0, .0],
+                target_stds=[0.1, 0.1, 0.2, 0.2]),
+            loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0),
+            loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+            loss_centerness=dict(
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
+        # training and testing settings
+        train_cfg=dict(
+            assigner=dict(type='ATSSAssigner', topk=9),
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False),
+        test_cfg=dict(
+            nms_pre=1000,
+            min_bbox_size=0,
+            score_thr=0.05,
+            nms=dict(type='nms', iou_threshold=0.6),
+            max_per_img=100))
+
+    architecture_cfg = dict(
+        type='MMDetArchitecture',
+        model=model_cfg,
+    )
+
+    architecture = ARCHITECTURES.build(architecture_cfg)
+    # ``StructurePruner`` requires pytorch>=1.6.0 to auto-trace GroupNorm
+    # correctly
+    min_required_version = '1.6.0'
+    if digit_version(torch.__version__) < digit_version(min_required_version):
+        with pytest.raises(AssertionError):
+            pruner.prepare_from_supernet(architecture)
+    else:
+        pruner.prepare_from_supernet(architecture)
+        subnet_dict = pruner.sample_subnet()
+        pruner.set_subnet(subnet_dict)
+        subnet_dict = pruner.export_subnet()
+        pruner.deploy_subnet(architecture, subnet_dict)
+        architecture.forward_dummy(imgs)
+
 
 def _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, should_fail):
     import os
diff --git a/tools/mmcls/test_mmcls.py b/tools/mmcls/test_mmcls.py
index 8ee6351f4..26b55077f 100644
--- a/tools/mmcls/test_mmcls.py
+++ b/tools/mmcls/test_mmcls.py
@@ -120,15 +120,32 @@ def main():
         init_dist(args.launcher, **cfg.dist_params)
 
     # build the dataloader
-    dataset = build_dataset(cfg.data.test)
-    # the extra round_up data will be removed during gpu/cpu collect
-    data_loader = build_dataloader(
-        dataset,
-        samples_per_gpu=cfg.data.samples_per_gpu,
-        workers_per_gpu=cfg.data.workers_per_gpu,
+    dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
+
+    # build the dataloader
+    # The default loader config
+    loader_cfg = dict(
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
         dist=distributed,
-        shuffle=False,
-        round_up=True)
+        round_up=True,
+    )
+    # The overall dataloader settings
+    loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
+    test_loader_cfg = {
+        **loader_cfg,
+        'shuffle': False,  # Not shuffle by default
+        'sampler_cfg': None,  # Not use sampler by default
+        **cfg.data.get('test_dataloader', {}),
+    }
+    # the extra round_up data will be removed during gpu/cpu collect
+    data_loader = build_dataloader(dataset, **test_loader_cfg)
 
     # build the algorithm and load checkpoint
     algorithm = build_algorithm(cfg.algorithm)
diff --git a/tools/mmseg/test_mmseg.py b/tools/mmseg/test_mmseg.py
index 3528ff8d1..50a06eb4e 100644
--- a/tools/mmseg/test_mmseg.py
+++ b/tools/mmseg/test_mmseg.py
@@ -154,12 +154,28 @@ def main():
     # build the dataloader
     # TODO: support multiple images per gpu (only minor changes are needed)
     dataset = build_dataset(cfg.data.test)
-    data_loader = build_dataloader(
-        dataset,
-        samples_per_gpu=1,
-        workers_per_gpu=cfg.data.workers_per_gpu,
+    # The default loader config
+    loader_cfg = dict(
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
         dist=distributed,
         shuffle=False)
+    # The overall dataloader settings
+    loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
+    test_loader_cfg = {
+        **loader_cfg,
+        'samples_per_gpu': 1,
+        'shuffle': False,  # Not shuffle by default
+        **cfg.data.get('test_dataloader', {})
+    }
+    # build the dataloader
+    data_loader = build_dataloader(dataset, **test_loader_cfg)
 
     # build the algorithm and load checkpoint
     # Difference from mmsegmentation
diff --git a/tools/mmseg/train_mmseg.py b/tools/mmseg/train_mmseg.py
index d80978d92..b7c6e9ffd 100644
--- a/tools/mmseg/train_mmseg.py
+++ b/tools/mmseg/train_mmseg.py
@@ -177,7 +177,7 @@ def main():
     seed = seed + dist.get_rank() if args.diff_seed else seed
     logger.info(f'Set random seed to {args.seed}, deterministic: '
                 f'{args.deterministic}')
-    set_random_seed(args.seed, deterministic=args.deterministic)
+    set_random_seed(seed, deterministic=args.deterministic)
     cfg.seed = seed
     meta['seed'] = args.seed
     meta['exp_name'] = osp.basename(args.config)