From be174ab41dc61cc38f55430e2b1aa4e0733ed73c Mon Sep 17 00:00:00 2001 From: Andrea Panizza <8233615+AndreaPi@users.noreply.github.com> Date: Mon, 15 Nov 2021 12:21:02 +0100 Subject: [PATCH] Fix type error in 2_new_data_mode (#6469) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Always map location to cpu when load checkpoint (#6405) * configs: update groie README (#6401) Signed-off-by: Leonardo Rossi * [Fix] fix config path in docs (#6396) * [Enchance] Set a random seed when the user does not set a seed. (#6457) * fix random seed bug * add comment * enchance random seed * rename Co-authored-by: Haobo Yuan * [BugFixed] fix wrong trunc_normal_init use (#6432) * fix wrong trunc_normal_init use * fix wrong trunc_normal_init use * fix #6446 Co-authored-by: Uno Wu Co-authored-by: Leonardo Rossi Co-authored-by: BigDong Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com> Co-authored-by: Haobo Yuan Co-authored-by: Shusheng Yang --- configs/groie/README.md | 21 ++++++++----------- docs/1_exist_data_model.md | 2 +- docs/2_new_data_model.md | 2 +- docs_zh-CN/1_exist_data_model.md | 2 +- mmdet/apis/__init__.py | 5 +++-- mmdet/apis/inference.py | 3 +-- mmdet/apis/train.py | 36 +++++++++++++++++++++++++++++++- mmdet/models/backbones/pvt.py | 7 +++---- mmdet/models/backbones/swin.py | 9 ++++---- tools/train.py | 14 ++++++------- 10 files changed, 65 insertions(+), 36 deletions(-) diff --git a/configs/groie/README.md b/configs/groie/README.md index c38b70b64e6..42d4b9feb5e 100644 --- a/configs/groie/README.md +++ b/configs/groie/README.md @@ -25,9 +25,6 @@ performance. ## Results and models The results on COCO 2017 minival (5k images) are shown in the below table. -You can find -[here](https://drive.google.com/drive/folders/19ssstbq_h0Z1cgxHmJYFO8s1arf3QJbT) -the trained models. ### Application of GRoIE to different architectures @@ -42,24 +39,24 @@ the trained models. | R-50-FPN | GC-Net | 1x | 40.7 | 36.5 | [config](../gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco_20200202-50b90e5c.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco_20200202_085547.log.json) | | R-50-FPN | + GRoIE | 1x | 41.0 | 37.8 | [config](./mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py) |[model](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200604_211715-42eb79e1.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200604_211715-42eb79e1.pth) | | R-101-FPN | GC-Net | 1x | 42.2 | 37.8 | [config](../gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco_20200206-8407a3f0.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco_20200206_142508.log.json) | -| R-101-FPN | + GRoIE | 1x | | | [config](./mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py)| [model](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200607_224507-8daae01c.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200607_224507.log.json) | +| R-101-FPN | + GRoIE | 1x | 42.6 | 38.7 | [config](./mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py)| [model](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200607_224507-8daae01c.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200607_224507.log.json) | ## Citation If you use this work or benchmark in your research, please cite this project. ```latex -@misc{rossi2020novel, - title={A novel Region of Interest Extraction Layer for Instance Segmentation}, - author={Leonardo Rossi and Akbar Karimi and Andrea Prati}, - year={2020}, - eprint={2004.13665}, - archivePrefix={arXiv}, - primaryClass={cs.CV} +@inproceedings{rossi2021novel, + title={A novel region of interest extraction layer for instance segmentation}, + author={Rossi, Leonardo and Karimi, Akbar and Prati, Andrea}, + booktitle={2020 25th International Conference on Pattern Recognition (ICPR)}, + pages={2203--2209}, + year={2021}, + organization={IEEE} } ``` ## Contact -The implementation of GROI is currently maintained by +The implementation of GRoIE is currently maintained by [Leonardo Rossi](https://github.com/hachreak/). diff --git a/docs/1_exist_data_model.md b/docs/1_exist_data_model.md index fc7c286f472..4c3ad3cee4b 100644 --- a/docs/1_exist_data_model.md +++ b/docs/1_exist_data_model.md @@ -300,7 +300,7 @@ Assuming that you have already downloaded the checkpoints to the directory `chec ```shell python tools/test.py \ - configs/faster_rcnn/faster_rcnn_r50_fpn_1x.py \ + configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \ checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \ --show-dir faster_rcnn_r50_fpn_1x_results ``` diff --git a/docs/2_new_data_model.md b/docs/2_new_data_model.md index a9736e7ebed..455313a20ca 100644 --- a/docs/2_new_data_model.md +++ b/docs/2_new_data_model.md @@ -257,7 +257,7 @@ For more detailed usages, please refer to the [Case 1](1_exist_data_model.md). To test the trained model, you can simply run ```shell -python tools/test.py configs/balloon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py work_dirs/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py/latest.pth --eval bbox segm +python tools/test.py configs/balloon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py work_dirs/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon/latest.pth --eval bbox segm ``` For more detailed usages, please refer to the [Case 1](1_exist_data_model.md). diff --git a/docs_zh-CN/1_exist_data_model.md b/docs_zh-CN/1_exist_data_model.md index 487e9be3118..9a2f0d077aa 100644 --- a/docs_zh-CN/1_exist_data_model.md +++ b/docs_zh-CN/1_exist_data_model.md @@ -285,7 +285,7 @@ bash tools/dist_test.sh \ ```shell python tools/test.py \ - configs/faster_rcnn/faster_rcnn_r50_fpn_1x.py \ + configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \ checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \ --show-dir faster_rcnn_r50_fpn_1x_results ``` diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py index 4a8987d1da8..a865e942afd 100644 --- a/mmdet/apis/__init__.py +++ b/mmdet/apis/__init__.py @@ -2,10 +2,11 @@ from .inference import (async_inference_detector, inference_detector, init_detector, show_result_pyplot) from .test import multi_gpu_test, single_gpu_test -from .train import get_root_logger, set_random_seed, train_detector +from .train import (get_root_logger, init_random_seed, set_random_seed, + train_detector) __all__ = [ 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector', 'async_inference_detector', 'inference_detector', 'show_result_pyplot', - 'multi_gpu_test', 'single_gpu_test' + 'multi_gpu_test', 'single_gpu_test', 'init_random_seed' ] diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py index 6b4b0096e5f..70dc704168f 100644 --- a/mmdet/apis/inference.py +++ b/mmdet/apis/inference.py @@ -39,8 +39,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None): config.model.train_cfg = None model = build_detector(config.model, test_cfg=config.get('test_cfg')) if checkpoint is not None: - map_loc = 'cpu' if device == 'cpu' else None - checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc) + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') if 'CLASSES' in checkpoint.get('meta', {}): model.CLASSES = checkpoint['meta']['CLASSES'] else: diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index b8b6eefc8ab..a7ed2d267eb 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -4,10 +4,11 @@ import numpy as np import torch +import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, Fp16OptimizerHook, OptimizerHook, build_optimizer, - build_runner) + build_runner, get_dist_info) from mmcv.utils import build_from_cfg from mmdet.core import DistEvalHook, EvalHook @@ -16,6 +17,39 @@ from mmdet.utils import get_root_logger +def init_random_seed(seed=None, device='cuda'): + """Initialize random seed. + + If the seed is not set, the seed will be automatically randomized, + and then broadcast to all processes to prevent some potential bugs. + + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + + Returns: + int: Seed to be used. + """ + if seed is not None: + return seed + + # Make sure all ranks share the same random seed to prevent + # some potential bugs. Please refer to + # https://github.com/open-mmlab/mmdetection/issues/6339 + rank, world_size = get_dist_info() + seed = np.random.randint(2**31) + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() + + def set_random_seed(seed, deterministic=False): """Set random seed. diff --git a/mmdet/models/backbones/pvt.py b/mmdet/models/backbones/pvt.py index 11c1eb0e8ca..c5365c53c8e 100644 --- a/mmdet/models/backbones/pvt.py +++ b/mmdet/models/backbones/pvt.py @@ -9,6 +9,7 @@ constant_init, normal_init, trunc_normal_init) from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import MultiheadAttention +from mmcv.cnn.utils.weight_init import trunc_normal_ from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint, load_state_dict) from torch.nn.modules.utils import _pair as to_2tuple @@ -315,7 +316,7 @@ def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None): self.drop = nn.Dropout(p=drop_rate) def init_weights(self): - trunc_normal_init(self.pos_embed, std=0.02) + trunc_normal_(self.pos_embed, std=0.02) def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'): """Resize pos_embed weights. @@ -526,9 +527,7 @@ def init_weights(self): f'training start from scratch') for m in self.modules(): if isinstance(m, nn.Linear): - trunc_normal_init(m.weight, std=.02) - if m.bias is not None: - constant_init(m.bias, 0) + trunc_normal_init(m, std=.02, bias=0.) elif isinstance(m, nn.LayerNorm): constant_init(m.bias, 0) constant_init(m.weight, 1.0) diff --git a/mmdet/models/backbones/swin.py b/mmdet/models/backbones/swin.py index fce01664ad0..316aec9765f 100644 --- a/mmdet/models/backbones/swin.py +++ b/mmdet/models/backbones/swin.py @@ -8,6 +8,7 @@ import torch.utils.checkpoint as cp from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init from mmcv.cnn.bricks.transformer import FFN, build_dropout +from mmcv.cnn.utils.weight_init import trunc_normal_ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint from mmcv.utils import to_2tuple @@ -74,7 +75,7 @@ def __init__(self, self.softmax = nn.Softmax(dim=-1) def init_weights(self): - trunc_normal_init(self.relative_position_bias_table, std=0.02) + trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x, mask=None): """ @@ -672,12 +673,10 @@ def init_weights(self): f'{self.__class__.__name__}, ' f'training start from scratch') if self.use_abs_pos_embed: - trunc_normal_init(self.absolute_pos_embed, std=0.02) + trunc_normal_(self.absolute_pos_embed, std=0.02) for m in self.modules(): if isinstance(m, nn.Linear): - trunc_normal_init(m.weight, std=.02) - if m.bias is not None: - constant_init(m.bias, 0) + trunc_normal_init(m, std=.02, bias=0.) elif isinstance(m, nn.LayerNorm): constant_init(m.bias, 0) constant_init(m.weight, 1.0) diff --git a/tools/train.py b/tools/train.py index f59f195332d..45fa03bfad6 100644 --- a/tools/train.py +++ b/tools/train.py @@ -13,7 +13,7 @@ from mmcv.utils import get_git_hash from mmdet import __version__ -from mmdet.apis import set_random_seed, train_detector +from mmdet.apis import init_random_seed, set_random_seed, train_detector from mmdet.datasets import build_dataset from mmdet.models import build_detector from mmdet.utils import collect_env, get_root_logger @@ -148,12 +148,12 @@ def main(): logger.info(f'Config:\n{cfg.pretty_text}') # set random seeds - if args.seed is not None: - logger.info(f'Set random seed to {args.seed}, ' - f'deterministic: {args.deterministic}') - set_random_seed(args.seed, deterministic=args.deterministic) - cfg.seed = args.seed - meta['seed'] = args.seed + seed = init_random_seed(args.seed) + logger.info(f'Set random seed to {seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(seed, deterministic=args.deterministic) + cfg.seed = seed + meta['seed'] = seed meta['exp_name'] = osp.basename(args.config) model = build_detector(