Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support PGD Head #964

Merged
merged 56 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
46d420f
[Refactor] Main code modification for coordinate system refactor (#677)
yezhen17 Jul 22, 2021
6877232
[Enhance] Add script for data update (#774)
yezhen17 Jul 29, 2021
febd2eb
fix import (#839)
yezhen17 Aug 5, 2021
3f64754
[Enhance] refactor iou_neg_piecewise_sampler.py (#842)
xiliu8006 Aug 9, 2021
b2abf1e
[Feature] Add roipooling cuda ops (#843)
xiliu8006 Aug 9, 2021
5a07dfe
[Refactor] Refactor code structure and docstrings (#803)
yezhen17 Aug 11, 2021
3d9268b
[Feature] PointXYZWHLRBBoxCoder (#856)
xiliu8006 Aug 16, 2021
3389237
[Enhance] Change Groupfree3D config (#855)
yezhen17 Aug 13, 2021
fc4bb0c
[Doc] Add tutorials/data_pipeline Chinese version (#827)
wHao-Wu Aug 18, 2021
0f81e49
[Doc] Add Chinese doc for `scannet_det.md` (#836)
yezhen17 Aug 18, 2021
cac2ef8
[Doc] Add Chinese doc for `waymo_det.md` (#859)
yezhen17 Aug 18, 2021
8d0b12a
Remove 2D annotations on Lyft (#867)
Tai-Wang Aug 18, 2021
2375d3c
Add header for files (#869)
DCNSW Aug 19, 2021
63fd399
[fix] fix typos (#872)
xieenze Aug 19, 2021
884b593
Fix 3 unworking configs (#882)
yezhen17 Aug 24, 2021
d688160
[Fix] Fix `index.rst` for Chinese docs (#873)
yezhen17 Aug 24, 2021
a000db5
[Fix] Centerpoint head nested list transpose (#879)
robin-karlsson0 Aug 25, 2021
4e9c992
[Enhance] Update PointFusion (#791)
filaPro Aug 25, 2021
08dae04
[Doc] Add nuscenes_det.md Chinese version (#854)
ZCMax Aug 26, 2021
93de7c2
[Fix] Fix RegNet pretrained weight loading (#889)
yezhen17 Aug 27, 2021
0eb7e71
Fix centerpoint tta (#892)
yezhen17 Aug 30, 2021
00c037a
[Enhance] Add benchmark regression script (#808)
yezhen17 Aug 30, 2021
1e6cdea
Initial commit
yezhen17 Sep 1, 2021
d4b1244
Merge pull request #899 from THU17cyz/coord_sys_tutorial_again
yezhen17 Sep 1, 2021
f095eb6
[Feature] Support DGCNN (v1.0.0.dev0) (#896)
DCNSW Sep 3, 2021
459c637
Change cam rot_3d_in_axis (#906)
yezhen17 Sep 6, 2021
2ae6b55
[Doc] Add coord sys tutorial pic and change links to dev branch (#912)
yezhen17 Sep 7, 2021
fce176f
[Feature] add kitti AP40 evaluation metric (v1.0.0.dev0) (#927)
ZCMax Sep 13, 2021
66f0c07
[Feature] add smoke backbone neck (#939)
ZCMax Sep 15, 2021
0b26a9a
[Refactor] Refactor the transformation from image to camera coordinat…
Tai-Wang Sep 15, 2021
911a333
[Feature] FCOS3D BBox Coder (#940)
Tai-Wang Sep 15, 2021
0899bad
Support PGD BBox Coder
Tai-Wang Sep 22, 2021
b217f7a
Refine docstring
Tai-Wang Sep 22, 2021
d28a8b5
Add uncertain l1 loss and its unit tests
Tai-Wang Sep 22, 2021
2282fca
Merge branch 'uncertain_loss' into pgd_head
Tai-Wang Sep 22, 2021
506f929
[Feature] PGD BBox Coder (#948)
Tai-Wang Sep 22, 2021
38f75f5
PGD Head initialized
Tai-Wang Sep 22, 2021
89be05c
Refactor init methods, fix legacy variable names
Tai-Wang Sep 22, 2021
5be3d11
[Feature] Support Uncertain L1 Loss (#950)
Tai-Wang Sep 22, 2021
4a804bf
[Fix] Fix visualization in KITTI dataset (#956)
ZCMax Sep 22, 2021
038c39d
Refine variable names and docstrings
Tai-Wang Sep 23, 2021
0061d89
Add unit tests and fix some minor bugs
Tai-Wang Sep 24, 2021
0e1f4ed
Refine assertion messages
Tai-Wang Sep 24, 2021
c0a5021
Merge branch 'v1.0.0.dev0' into pgd_head
Tai-Wang Sep 24, 2021
e3690c6
Merge branch 'v1.0.0.dev0' into pgd_head
Tai-Wang Sep 24, 2021
a62e993
Fix typo in the docs_zh-CN
Tai-Wang Sep 24, 2021
efef7e9
Use Pretrain init and remove unused init_cfg in FCOS3D
Tai-Wang Sep 27, 2021
016fc29
Fix the comments for the input_modality in the dataset config
Tai-Wang Sep 29, 2021
b47ad7e
Fix minor bugs in pgd_bbox_coder and incorrect setting for uncertain …
Tai-Wang Oct 26, 2021
414f561
Merge branch 'v1.0.0.dev0' into pgd_head
Tai-Wang Oct 26, 2021
3baa41a
Add explanations for code_weights
Tai-Wang Oct 26, 2021
9658a2b
Adjust the unit test for pgd bbox coder
Tai-Wang Oct 26, 2021
fa78b41
Remove unused codes
Tai-Wang Oct 26, 2021
4e086bc
Add mono3d metric into the gather_models and fix bugs
Tai-Wang Oct 27, 2021
3f77afc
Involve the value assignment of loss_dict into the computing procedure
Tai-Wang Oct 28, 2021
b8a46bc
Fix incorrect loss_depth
Tai-Wang Oct 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions .dev_scripts/gather_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

Usage:
python gather_models.py ${root_path} ${out_dir}

Example:
python gather_models.py \
work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d \
work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d

Note that before running the above command, rename the directory with the
config name if you did not use the default directory name, create
a corresponding directory 'pgd' under the above path and put the used config
into it.
"""

import argparse
Expand Down Expand Up @@ -36,16 +46,18 @@
RESULTS_LUT = {
'coco': ['bbox_mAP', 'segm_mAP'],
'nus': ['pts_bbox_NuScenes/NDS', 'NDS'],
'kitti-3d-3class': [
'KITTI/Overall_3D_moderate',
'Overall_3D_moderate',
],
'kitti-3d-3class': ['KITTI/Overall_3D_moderate', 'Overall_3D_moderate'],
'kitti-3d-car': ['KITTI/Car_3D_moderate_strict', 'Car_3D_moderate_strict'],
'lyft': ['score'],
'scannet_seg': ['miou'],
's3dis_seg': ['miou'],
'scannet': ['mAP_0.50'],
'sunrgbd': ['mAP_0.50']
'sunrgbd': ['mAP_0.50'],
'kitti-mono3d': [
'img_bbox/KITTI/Car_3D_AP40_moderate_strict',
'Car_3D_AP40_moderate_strict'
],
'nus-mono3d': ['img_bbox_NuScenes/NDS', 'NDS']
}


Expand Down Expand Up @@ -145,15 +157,13 @@ def main():
# and parse the best performance
model_infos = []
for used_config in used_configs:
exp_dir = osp.join(models_root, used_config)

# get logs
log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0]
log_txt_path = glob.glob(osp.join(exp_dir, '*.log'))[0]
log_json_path = glob.glob(osp.join(models_root, '*.log.json'))[0]
log_txt_path = glob.glob(osp.join(models_root, '*.log'))[0]
model_performance = get_best_results(log_json_path)
final_epoch = model_performance['epoch']
final_model = 'epoch_{}.pth'.format(final_epoch)
model_path = osp.join(exp_dir, final_model)
model_path = osp.join(models_root, final_model)

# skip if the model is still training
if not osp.exists(model_path):
Expand Down Expand Up @@ -182,7 +192,7 @@ def main():
model_name = model['config'].split('/')[-1].rstrip(
'.py') + '_' + model['model_time']
publish_model_path = osp.join(model_publish_dir, model_name)
trained_model_path = osp.join(models_root, model['config'],
trained_model_path = osp.join(models_root,
'epoch_{}.pth'.format(model['epochs']))

# convert model
Expand All @@ -191,11 +201,10 @@ def main():

# copy log
shutil.copy(
osp.join(models_root, model['config'], model['log_json_path']),
osp.join(models_root, model['log_json_path']),
osp.join(model_publish_dir, f'{model_name}.log.json'))
shutil.copy(
osp.join(models_root, model['config'],
model['log_json_path'].rstrip('.json')),
osp.join(models_root, model['log_json_path'].rstrip('.json')),
osp.join(model_publish_dir, f'{model_name}.log'))

# copy config to guarantee reproducibility
Expand Down
6 changes: 4 additions & 2 deletions configs/_base_/models/fcos3d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
model = dict(
type='FCOSMono3D',
pretrained='open-mmlab://detectron2/resnet101_caffe',
backbone=dict(
type='ResNet',
depth=101,
Expand All @@ -9,7 +8,10 @@
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe'),
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet101_caffe')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
Expand Down
55 changes: 55 additions & 0 deletions configs/_base_/models/pgd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
_base_ = './fcos3d.py'
# model settings
model = dict(
bbox_head=dict(
_delete_=True,
type='PGDHead',
num_classes=10,
in_channels=256,
stacked_convs=2,
feat_channels=256,
use_direction_classifier=True,
diff_rad_by_sin=True,
pred_attrs=True,
pred_velo=True,
pred_bbox2d=True,
pred_keypoints=False,
dir_offset=0.7854, # pi/4
strides=[8, 16, 32, 64, 128],
group_reg_dims=(2, 1, 3, 1, 2), # offset, depth, size, rot, velo
cls_branch=(256, ),
reg_branch=(
(256, ), # offset
(256, ), # depth
(256, ), # size
(256, ), # rot
() # velo
),
dir_branch=(256, ),
attr_branch=(256, ),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_attr=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
norm_on_bbox=True,
centerness_on_reg=True,
center_sampling=True,
conv_bias=True,
dcn_on_last_conv=True,
use_depth_classifier=True,
depth_branch=(256, ),
depth_range=(0, 50),
depth_unit=10,
division='uniform',
depth_bins=6,
bbox_coder=dict(type='PGDBBoxCoder', code_size=9)),
test_cfg=dict(nms_pre=1000, nms_thr=0.8, score_thr=0.01, max_per_img=200))
127 changes: 127 additions & 0 deletions configs/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
_base_ = [
'../_base_/datasets/kitti-mono3d.py', '../_base_/models/pgd.py',
'../_base_/schedules/mmdet_schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
backbone=dict(frozen_stages=0),
neck=dict(start_level=0, num_outs=4),
bbox_head=dict(
num_classes=3,
bbox_code_size=7,
pred_attrs=False,
pred_velo=False,
pred_bbox2d=True,
use_onlyreg_proj=True,
strides=(4, 8, 16, 32),
regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 1e8)),
group_reg_dims=(2, 1, 3, 1, 16,
4), # offset, depth, size, rot, kpts, bbox2d
reg_branch=(
(256, ), # offset
(256, ), # depth
(256, ), # size
(256, ), # rot
(256, ), # kpts
(256, ) # bbox2d
),
centerness_branch=(256, ),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
use_depth_classifier=True,
depth_branch=(256, ),
depth_range=(0, 70),
depth_unit=10,
division='uniform',
depth_bins=8,
pred_keypoints=True,
weight_dim=1,
loss_depth=dict(
type='UncertainSmoothL1Loss', alpha=1.0, beta=3.0,
loss_weight=1.0),
bbox_coder=dict(
type='PGDBBoxCoder',
base_depths=((28.01, 16.32), ),
base_dims=((0.8, 1.73, 0.6), (1.76, 1.73, 0.6), (3.9, 1.56, 1.6)),
code_size=7)),
# set weight 1.0 for base 7 dims (offset, depth, size, rot)
# 0.2 for 16-dim keypoint offsets and 1.0 for 4-dim 2D distance targets
train_cfg=dict(code_weight=[
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 1.0, 1.0, 1.0, 1.0
]),
test_cfg=dict(nms_pre=100, nms_thr=0.05, score_thr=0.001, max_per_img=20))

class_names = ['Pedestrian', 'Cyclist', 'Car']
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='LoadAnnotations3D',
with_bbox=True,
with_label=True,
with_attr_label=False,
with_bbox_3d=True,
with_label_3d=True,
with_bbox_depth=True),
dict(type='Resize', img_scale=(1242, 375), keep_ratio=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d', 'gt_labels_3d',
'centers2d', 'depths'
]),
]
test_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='MultiScaleFlipAug',
scale_factor=1.0,
flip=False,
transforms=[
dict(type='RandomFlip3D'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img']),
])
]
data = dict(
samples_per_gpu=3,
workers_per_gpu=3,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
# optimizer
optimizer = dict(
lr=0.001, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.))
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[32, 44])
total_epochs = 48
runner = dict(type='EpochBasedRunner', max_epochs=48)
evaluation = dict(interval=2)
checkpoint_config = dict(interval=8)
6 changes: 4 additions & 2 deletions mmdet3d/core/bbox/coders/pgd_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import torch
from torch.nn import functional as F

from mmdet.core.bbox.builder import BBOX_CODERS
Expand Down Expand Up @@ -45,8 +46,9 @@ def decode_2d(self,
scale_kpts = scale[3]
# 2 dimension of offsets x 8 corners of a 3D bbox
bbox[:, self.bbox_code_size:self.bbox_code_size + 16] = \
scale_kpts(clone_bbox[
:, self.bbox_code_size:self.bbox_code_size + 16]).float()
torch.tanh(scale_kpts(clone_bbox[
:, self.bbox_code_size:self.bbox_code_size + 16]).float())

if pred_bbox2d:
scale_bbox2d = scale[-1]
# The last four dimensions are offsets to four sides of a 2D bbox
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .free_anchor3d_head import FreeAnchor3DHead
from .groupfree3d_head import GroupFree3DHead
from .parta2_rpn_head import PartA2RPNHead
from .pgd_head import PGDHead
from .shape_aware_head import ShapeAwareHead
from .smoke_mono3d_head import SMOKEMono3DHead
from .ssd_3d_head import SSD3DHead
Expand All @@ -17,5 +18,5 @@
'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead',
'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead',
'GroupFree3DHead', 'SMOKEMono3DHead'
'GroupFree3DHead', 'SMOKEMono3DHead', 'PGDHead'
]
35 changes: 27 additions & 8 deletions mmdet3d/models/dense_heads/anchor_free_mono3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,6 @@ def __init__(
self.attr_branch = attr_branch

self._init_layers()
if init_cfg is None:
self.init_cfg = dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))

def _init_layers(self):
"""Initialize layers of the head."""
Expand Down Expand Up @@ -288,8 +281,34 @@ def _init_predictor(self):
self.conv_attr = nn.Conv2d(self.attr_branch[-1], self.num_attrs, 1)

def init_weights(self):
super().init_weights()
"""Initialize weights of the head.

We currently still use the customized defined init_weights because the
default init of DCN triggered by the init_cfg will init
conv_offset.weight, which mistakenly affects the training stability.
"""
for modules in [self.cls_convs, self.reg_convs, self.conv_cls_prev]:
for m in modules:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
for conv_reg_prev in self.conv_reg_prevs:
if conv_reg_prev is None:
continue
for m in conv_reg_prev:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
if self.use_direction_classifier:
for m in self.conv_dir_cls_prev:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
if self.pred_attrs:
for m in self.conv_attr_prev:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.conv_cls, std=0.01, bias=bias_cls)
for conv_reg in self.conv_regs:
normal_init(conv_reg, std=0.01)
if self.use_direction_classifier:
normal_init(self.conv_dir_cls, std=0.01, bias=bias_cls)
if self.pred_attrs:
Expand Down
Loading