Merge remote-tracking branch 'origin/dev-3.x' into detic_inference
Showing 97 changed files with 3,449 additions and 256 deletions.
@@ -0,0 +1,64 @@
# dataset settings
dataset_type = 'Objects365V1Dataset'
data_root = 'data/Objects365/Obj365_v1/'

# file_client_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/detection/',
#         'data/': 's3://openmmlab/datasets/detection/'
#     }))
file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    # If you don't have gt annotations, delete this LoadAnnotations step
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/objects365_train.json',
        data_prefix=dict(img='train/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/objects365_val.json',
        data_prefix=dict(img='val/'),
        test_mode=True,
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/objects365_val.json',
    metric='bbox',
    sort_categories=True,
    format_only=False)
test_evaluator = val_evaluator
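As a quick sanity check, a dataset config like the one above can be loaded and inspected with mmengine's `Config` API before training. A minimal sketch, assuming MMDetection 3.x with mmengine installed; the file path below is an assumption about where this config lives:

```python
# Minimal sketch: load and inspect an MMDetection 3.x dataset config.
# The config path is an assumed location, not confirmed by this diff.
from mmengine.config import Config

cfg = Config.fromfile('configs/_base_/datasets/objects365v1_detection.py')

# Verify annotation paths and the composed pipelines before launching a run.
print(cfg.train_dataloader.dataset.ann_file)
print(cfg.val_evaluator.ann_file)
print([step['type'] for step in cfg.train_dataloader.dataset.pipeline])
```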
@@ -0,0 +1,63 @@
# dataset settings
dataset_type = 'Objects365V2Dataset'
data_root = 'data/Objects365/Obj365_v2/'

# file_client_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/detection/',
#         'data/': 's3://openmmlab/datasets/detection/'
#     }))
file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    # If you don't have gt annotations, delete this LoadAnnotations step
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/zhiyuan_objv2_train.json',
        data_prefix=dict(img='train/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/zhiyuan_objv2_val.json',
        data_prefix=dict(img='val/'),
        test_mode=True,
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/zhiyuan_objv2_val.json',
    metric='bbox',
    format_only=False)
test_evaluator = val_evaluator
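Downstream configs would normally pull a dataset file like this in through `_base_` inheritance and override only the fields that differ. A small sketch under stated assumptions: the `_base_` path is a guess at the file's location in the repo, and the overridden values are purely illustrative:

```python
# Sketch of a derived MMDetection 3.x config that reuses the Objects365 v2
# dataset settings. The _base_ path and override values are assumptions.
_base_ = ['../_base_/datasets/objects365v2_detection.py']

# mmengine merges these dicts into the inherited ones; unset keys are kept.
train_dataloader = dict(batch_size=4, num_workers=4)  # illustrative values
val_dataloader = dict(batch_size=2)
```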
@@ -0,0 +1,31 @@
# BoxInst

> [BoxInst: High-Performance Instance Segmentation with Box Annotations](https://arxiv.org/pdf/2012.02310.pdf)

<!-- [ALGORITHM] -->

## Abstract

We present a high-performance method that can achieve mask-level instance segmentation with only bounding-box annotations for training. While this setting has been studied in the literature, here we show significantly stronger performance with a simple design (e.g., dramatically improving previous best reported mask AP of 21.1% to 31.6% on the COCO dataset). Our core idea is to redesign the loss of learning masks in instance segmentation, with no modification to the segmentation network itself. The new loss functions can supervise the mask training without relying on mask annotations. This is made possible with two loss terms, namely, 1) a surrogate term that minimizes the discrepancy between the projections of the ground-truth box and the predicted mask; 2) a pairwise loss that can exploit the prior that proximal pixels with similar colors are very likely to have the same category label. Experiments demonstrate that the redesigned mask loss can yield surprisingly high-quality instance masks with only box annotations. For example, without using any mask annotations, with a ResNet-101 backbone and 3× training schedule, we achieve 33.2% mask AP on COCO test-dev split (vs. 39.1% of the fully supervised counterpart). Our excellent experiment results on COCO and Pascal VOC indicate that our method dramatically narrows the performance gap between weakly and fully supervised instance segmentation.

<div align=center>
<img src="https://user-images.githubusercontent.com/57584090/209087723-756b76d7-5061-4000-a93c-df1194a439a0.png"/>
</div>
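To make the first loss term in the abstract concrete, the box-supervised "projection" idea can be sketched as comparing the x/y projections (max over rows and columns) of the predicted soft mask against those of a mask filled from the ground-truth box. The following is a minimal PyTorch illustration of that idea, not MMDetection's actual `BoxInstMaskHead` implementation:

```python
# Hedged sketch of BoxInst's projection term: dice loss between the x/y
# projections of the predicted mask and of a box-filled ground-truth mask.
import torch


def dice_coefficient(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Dice similarity between two 1-D soft masks."""
    inter = (a * b).sum()
    union = (a * a).sum() + (b * b).sum()
    return (2 * inter + eps) / (union + eps)


def projection_loss(pred_mask: torch.Tensor, gt_box_mask: torch.Tensor) -> torch.Tensor:
    """pred_mask, gt_box_mask: (H, W); gt_box_mask is 1 inside the gt box, 0 outside."""
    # Project onto the y axis (max over columns) and the x axis (max over rows).
    loss_y = 1 - dice_coefficient(pred_mask.max(dim=1).values, gt_box_mask.max(dim=1).values)
    loss_x = 1 - dice_coefficient(pred_mask.max(dim=0).values, gt_box_mask.max(dim=0).values)
    return loss_x + loss_y


# Toy usage: a 16x16 soft prediction supervised by a box covering rows/cols 4..11.
pred = torch.sigmoid(torch.randn(16, 16))
box = torch.zeros(16, 16)
box[4:12, 4:12] = 1.0
print(projection_loss(pred, box))
```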
## Results and Models

| Backbone | Style   | MS train | Lr schd | bbox AP | mask AP | Config                                     | Download                 |
| :------: | :-----: | :------: | :-----: | :-----: | :-----: | :----------------------------------------: | :----------------------: |
|   R-50   | pytorch |    Y     |   1x    |  39.4   |  30.8   | [config](./boxinst_r50_fpn_ms-90k_coco.py) | [model](<>) \| [log](<>) |

## Citation

```latex
@inproceedings{tian2020boxinst,
  title     = {{BoxInst}: High-Performance Instance Segmentation with Box Annotations},
  author    = {Tian, Zhi and Shen, Chunhua and Wang, Xinlong and Chen, Hao},
  booktitle = {Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR)},
  year      = {2021}
}
```
@@ -0,0 +1,93 @@
_base_ = '../common/ms-90k_coco.py'

# model settings
model = dict(
    type='BoxInst',
    data_preprocessor=dict(
        type='BoxInstDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32,
        mask_stride=4,
        pairwise_size=3,
        pairwise_dilation=2,
        pairwise_color_thresh=0.3,
        bottom_pixels_removed=10),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',  # use P5
        num_outs=5,
        relu_before_extra_convs=True),
    bbox_head=dict(
        type='BoxInstBboxHead',
        num_params=593,
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        strides=[8, 16, 32, 64, 128],
        norm_on_bbox=True,
        centerness_on_reg=True,
        dcn_on_last_conv=False,
        center_sampling=True,
        conv_bias=True,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.0),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
    mask_head=dict(
        type='BoxInstMaskHead',
        num_layers=3,
        feat_channels=16,
        size_of_interest=8,
        mask_out_stride=4,
        topk_masks_per_img=64,
        mask_feature_head=dict(
            in_channels=256,
            feat_channels=128,
            start_level=0,
            end_level=2,
            out_channels=16,
            mask_stride=8,
            num_stacked_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True)),
        loss_mask=dict(
            type='DiceLoss',
            use_sigmoid=True,
            activate=True,
            eps=5e-6,
            loss_weight=1.0)),
    # model training and testing settings
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100,
        mask_thr=0.5))

# optimizer
optim_wrapper = dict(optimizer=dict(lr=0.01))

# evaluator
val_evaluator = dict(metric=['bbox', 'segm'])
test_evaluator = val_evaluator
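Once trained (typically via MMDetection's standard `tools/train.py` entry point with this config), a quick inference check can use the high-level APIs. A hedged sketch: the config, checkpoint, and image paths below are assumptions, not values from this commit:

```python
# Hedged sketch: single-image inference with a trained BoxInst model using
# MMDetection 3.x high-level APIs. All paths are placeholders/assumptions.
from mmdet.apis import init_detector, inference_detector

config_file = 'configs/boxinst/boxinst_r50_fpn_ms-90k_coco.py'            # assumed path to this config
checkpoint_file = 'work_dirs/boxinst_r50_fpn_ms-90k_coco/iter_90000.pth'  # assumed checkpoint path

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')  # any test image

# The result is a DetDataSample; pred_instances carries boxes, scores,
# labels, and the instance masks produced by the BoxInst mask head.
print(result.pred_instances.bboxes.shape)
print(result.pred_instances.masks.shape)
```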