Skip to content

Commit

Permalink
Fix type error in 2_new_data_mode (#6469)
Browse files Browse the repository at this point in the history
* Always map location to cpu when load checkpoint (#6405)

* configs: update groie README (#6401)

Signed-off-by: Leonardo Rossi <leonardo.rossi@unipr.it>

* [Fix] fix config path in docs (#6396)

* [Enchance] Set a random seed when the user does not set a seed. (#6457)

* fix random seed bug

* add comment

* enchance random seed

* rename

Co-authored-by: Haobo Yuan <yuanhaobo@whu.edu.cn>

* [BugFixed] fix wrong trunc_normal_init use (#6432)

* fix wrong trunc_normal_init use

* fix wrong trunc_normal_init use

* fix #6446

Co-authored-by: Uno Wu <st9007a@gmail.com>
Co-authored-by: Leonardo Rossi <leonardo.rossi@unipr.it>
Co-authored-by: BigDong <yudongwang@tju.edu.cn>
Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
Co-authored-by: Haobo Yuan <yuanhaobo@whu.edu.cn>
Co-authored-by: Shusheng Yang <shusheng.yang@qq.com>
  • Loading branch information
7 people committed Nov 15, 2021
1 parent 0f490a1 commit be174ab
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 36 deletions.
21 changes: 9 additions & 12 deletions configs/groie/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ performance.
## Results and models

The results on COCO 2017 minival (5k images) are shown in the below table.
You can find
[here](https://drive.google.com/drive/folders/19ssstbq_h0Z1cgxHmJYFO8s1arf3QJbT)
the trained models.

### Application of GRoIE to different architectures

Expand All @@ -42,24 +39,24 @@ the trained models.
| R-50-FPN | GC-Net | 1x | 40.7 | 36.5 | [config](../gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco_20200202-50b90e5c.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco_20200202_085547.log.json) |
| R-50-FPN | + GRoIE | 1x | 41.0 | 37.8 | [config](./mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py) |[model](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200604_211715-42eb79e1.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200604_211715-42eb79e1.pth) |
| R-101-FPN | GC-Net | 1x | 42.2 | 37.8 | [config](../gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco_20200206-8407a3f0.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco_20200206_142508.log.json) |
| R-101-FPN | + GRoIE | 1x | | | [config](./mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py)| [model](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200607_224507-8daae01c.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200607_224507.log.json) |
| R-101-FPN | + GRoIE | 1x | 42.6 | 38.7 | [config](./mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py)| [model](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200607_224507-8daae01c.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco_20200607_224507.log.json) |

## Citation

If you use this work or benchmark in your research, please cite this project.

```latex
@misc{rossi2020novel,
title={A novel Region of Interest Extraction Layer for Instance Segmentation},
author={Leonardo Rossi and Akbar Karimi and Andrea Prati},
year={2020},
eprint={2004.13665},
archivePrefix={arXiv},
primaryClass={cs.CV}
@inproceedings{rossi2021novel,
title={A novel region of interest extraction layer for instance segmentation},
author={Rossi, Leonardo and Karimi, Akbar and Prati, Andrea},
booktitle={2020 25th International Conference on Pattern Recognition (ICPR)},
pages={2203--2209},
year={2021},
organization={IEEE}
}
```

## Contact

The implementation of GROI is currently maintained by
The implementation of GRoIE is currently maintained by
[Leonardo Rossi](https://github.com/hachreak/).
2 changes: 1 addition & 1 deletion docs/1_exist_data_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ Assuming that you have already downloaded the checkpoints to the directory `chec

```shell
python tools/test.py \
configs/faster_rcnn/faster_rcnn_r50_fpn_1x.py \
configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
--show-dir faster_rcnn_r50_fpn_1x_results
```
Expand Down
2 changes: 1 addition & 1 deletion docs/2_new_data_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ For more detailed usages, please refer to the [Case 1](1_exist_data_model.md).
To test the trained model, you can simply run

```shell
python tools/test.py configs/balloon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py work_dirs/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py/latest.pth --eval bbox segm
python tools/test.py configs/balloon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py work_dirs/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon/latest.pth --eval bbox segm
```

For more detailed usages, please refer to the [Case 1](1_exist_data_model.md).
2 changes: 1 addition & 1 deletion docs_zh-CN/1_exist_data_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ bash tools/dist_test.sh \

```shell
python tools/test.py \
configs/faster_rcnn/faster_rcnn_r50_fpn_1x.py \
configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
--show-dir faster_rcnn_r50_fpn_1x_results
```
Expand Down
5 changes: 3 additions & 2 deletions mmdet/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from .inference import (async_inference_detector, inference_detector,
init_detector, show_result_pyplot)
from .test import multi_gpu_test, single_gpu_test
from .train import get_root_logger, set_random_seed, train_detector
from .train import (get_root_logger, init_random_seed, set_random_seed,
train_detector)

__all__ = [
'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
'async_inference_detector', 'inference_detector', 'show_result_pyplot',
'multi_gpu_test', 'single_gpu_test'
'multi_gpu_test', 'single_gpu_test', 'init_random_seed'
]
3 changes: 1 addition & 2 deletions mmdet/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
config.model.train_cfg = None
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
if checkpoint is not None:
map_loc = 'cpu' if device == 'cpu' else None
checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'CLASSES' in checkpoint.get('meta', {}):
model.CLASSES = checkpoint['meta']['CLASSES']
else:
Expand Down
36 changes: 35 additions & 1 deletion mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner)
build_runner, get_dist_info)
from mmcv.utils import build_from_cfg

from mmdet.core import DistEvalHook, EvalHook
Expand All @@ -16,6 +17,39 @@
from mmdet.utils import get_root_logger


def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed.

If the seed is not set, the seed will be automatically randomized,
and then broadcast to all processes to prevent some potential bugs.

Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.

Returns:
int: Seed to be used.
"""
if seed is not None:
return seed

# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2**31)
if world_size == 1:
return seed

if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()


def set_random_seed(seed, deterministic=False):
"""Set random seed.

Expand Down
7 changes: 3 additions & 4 deletions mmdet/models/backbones/pvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
constant_init, normal_init, trunc_normal_init)
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint,
load_state_dict)
from torch.nn.modules.utils import _pair as to_2tuple
Expand Down Expand Up @@ -315,7 +316,7 @@ def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
self.drop = nn.Dropout(p=drop_rate)

def init_weights(self):
trunc_normal_init(self.pos_embed, std=0.02)
trunc_normal_(self.pos_embed, std=0.02)

def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
"""Resize pos_embed weights.
Expand Down Expand Up @@ -526,9 +527,7 @@ def init_weights(self):
f'training start from scratch')
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
Expand Down
9 changes: 4 additions & 5 deletions mmdet/models/backbones/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from mmcv.utils import to_2tuple

Expand Down Expand Up @@ -74,7 +75,7 @@ def __init__(self,
self.softmax = nn.Softmax(dim=-1)

def init_weights(self):
trunc_normal_init(self.relative_position_bias_table, std=0.02)
trunc_normal_(self.relative_position_bias_table, std=0.02)

def forward(self, x, mask=None):
"""
Expand Down Expand Up @@ -672,12 +673,10 @@ def init_weights(self):
f'{self.__class__.__name__}, '
f'training start from scratch')
if self.use_abs_pos_embed:
trunc_normal_init(self.absolute_pos_embed, std=0.02)
trunc_normal_(self.absolute_pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
Expand Down
14 changes: 7 additions & 7 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mmcv.utils import get_git_hash

from mmdet import __version__
from mmdet.apis import set_random_seed, train_detector
from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
Expand Down Expand Up @@ -148,12 +148,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

model = build_detector(
Expand Down

0 comments on commit be174ab

Please sign in to comment.