support SECOND-IoU and baseline results (#517)
* support SECOND-IoU

* rename file

* modify readme and fix model name in __init__.py
jihanyang authored May 10, 2021
1 parent a59df77 commit e3bec15
Showing 6 changed files with 532 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -98,6 +98,7 @@ Selected supported methods are shown in the below table. The results are the 3D detection performance of moderate difficulty on the *val* set of KITTI dataset.
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:---------:|
| [PointPillar](tools/cfgs/kitti_models/pointpillar.yaml) |~1.2 hours| 77.28 | 52.29 | 62.68 | [model-18M](https://drive.google.com/file/d/1wMxWTpU1qUoY3DsCH31WJmvJxcjFXKlm/view?usp=sharing) |
| [SECOND](tools/cfgs/kitti_models/second.yaml) | ~1.7 hours | 78.62 | 52.98 | 67.15 | [model-20M](https://drive.google.com/file/d/1-01zsPOsqanZQqIIyy7FpNXStL3y4jdR/view?usp=sharing) |
| [SECOND-IoU](tools/cfgs/kitti_models/second_iou.yaml) | - | 79.09 | 55.74 | 71.31 | [model](https://drive.google.com/file/d/1AQkeNs4bxhvhDQ-5sEo_yvQUlfo73lsW/view?usp=sharing) |
| [PointRCNN](tools/cfgs/kitti_models/pointrcnn.yaml) | ~3 hours | 78.70 | 54.41 | 72.11 | [model-16M](https://drive.google.com/file/d/1BCX9wMn-GYAfSOPpyxf6Iv6fc0qKLSiU/view?usp=sharing)|
| [PointRCNN-IoU](tools/cfgs/kitti_models/pointrcnn_iou.yaml) | ~3 hours | 78.75 | 58.32 | 71.34 | [model-16M](https://drive.google.com/file/d/1V0vNZ3lAHpEEt0MlT80eL2f41K2tHm_D/view?usp=sharing)|
| [Part-A^2-Free](tools/cfgs/kitti_models/PartA2_free.yaml) | ~3.8 hours| 78.72 | 65.99 | 74.29 | [model-226M](https://drive.google.com/file/d/1lcUUxF8mJgZ_e-tZhP1XNQtTBuC-R0zr/view?usp=sharing) |
4 changes: 3 additions & 1 deletion pcdet/models/detectors/__init__.py
@@ -4,14 +4,16 @@
from .pointpillar import PointPillar
from .pv_rcnn import PVRCNN
from .second_net import SECONDNet
from .second_net_iou import SECONDNetIoU

__all__ = {
'Detector3DTemplate': Detector3DTemplate,
'SECONDNet': SECONDNet,
'PartA2Net': PartA2Net,
'PVRCNN': PVRCNN,
'PointPillar': PointPillar,
'PointRCNN': PointRCNN,
'SECONDNetIoU': SECONDNetIoU
}
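
This registry is keyed by the NAME field of the model config, so adding 'SECONDNetIoU' here is what lets MODEL.NAME: SECONDNetIoU in a YAML config resolve to the new detector. Below is a minimal sketch of how such a registry is typically consumed, following the build_detector pattern used in this same file; treat the exact signature as an assumption rather than the verbatim repository code.

def build_detector(model_cfg, num_class, dataset):
    # look up the detector class registered above by its config NAME
    # (e.g. 'SECONDNetIoU') and instantiate it
    model = __all__[model_cfg.NAME](
        model_cfg=model_cfg, num_class=num_class, dataset=dataset
    )
    return model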


177 changes: 177 additions & 0 deletions pcdet/models/detectors/second_net_iou.py
@@ -0,0 +1,177 @@
import torch
from .detector3d_template import Detector3DTemplate
from ..model_utils.model_nms_utils import class_agnostic_nms
from ...ops.roiaware_pool3d import roiaware_pool3d_utils


class SECONDNetIoU(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()

def forward(self, batch_dict):
batch_dict['dataset_cfg'] = self.dataset.dataset_cfg
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)

if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()

ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts

def get_training_loss(self):
disp_dict = {}

loss_rpn, tb_dict = self.dense_head.get_loss()
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)

loss = loss_rpn + loss_rcnn
return loss, tb_dict, disp_dict

@staticmethod
def cal_scores_by_npoints(cls_scores, iou_scores, num_points_in_gt, cls_thresh=10, iou_thresh=100):
"""
Args:
cls_scores: (N)
iou_scores: (N)
num_points_in_gt: (N)
cls_thresh: scalar
iou_thresh: scalar
"""
assert iou_thresh >= cls_thresh
alpha = torch.zeros(cls_scores.shape, dtype=torch.float32).cuda()
alpha[num_points_in_gt <= cls_thresh] = 0
alpha[num_points_in_gt >= iou_thresh] = 1

mask = ((num_points_in_gt > cls_thresh) & (num_points_in_gt < iou_thresh))
# linearly interpolate between the cls and iou scores inside the threshold band
alpha[mask] = (num_points_in_gt[mask] - cls_thresh) / (iou_thresh - cls_thresh)

scores = (1 - alpha) * cls_scores + alpha * iou_scores

return scores

def set_nms_score_by_class(self, iou_preds, cls_preds, label_preds, score_by_class):
n_classes = torch.unique(label_preds).shape[0]
nms_scores = torch.zeros(iou_preds.shape, dtype=torch.float32).cuda()
for i in range(n_classes):
mask = label_preds == (i + 1)
class_name = self.class_names[i]
score_type = score_by_class[class_name]
if score_type == 'iou':
nms_scores[mask] = iou_preds[mask]
elif score_type == 'cls':
nms_scores[mask] = cls_preds[mask]
else:
raise NotImplementedError

return nms_scores

def post_processing(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...)
roi_labels: (B, num_rois) 1 .. num_classes
Returns:
"""
post_process_cfg = self.model_cfg.POST_PROCESSING
batch_size = batch_dict['batch_size']
recall_dict = {}
pred_dicts = []
for index in range(batch_size):
if batch_dict.get('batch_index', None) is not None:
assert batch_dict['batch_cls_preds'].shape.__len__() == 2
batch_mask = (batch_dict['batch_index'] == index)
else:
assert batch_dict['batch_cls_preds'].shape.__len__() == 3
batch_mask = index

box_preds = batch_dict['batch_box_preds'][batch_mask]
iou_preds = batch_dict['batch_cls_preds'][batch_mask]
cls_preds = batch_dict['roi_scores'][batch_mask]

src_iou_preds = iou_preds
src_box_preds = box_preds
src_cls_preds = cls_preds
assert iou_preds.shape[1] in [1, self.num_class]

if not batch_dict['cls_preds_normalized']:
iou_preds = torch.sigmoid(iou_preds)
cls_preds = torch.sigmoid(cls_preds)

if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
raise NotImplementedError
else:
iou_preds, label_preds = torch.max(iou_preds, dim=-1)
label_preds = batch_dict['roi_labels'][index] if batch_dict.get('has_class_labels', False) else label_preds + 1

if post_process_cfg.NMS_CONFIG.get('SCORE_BY_CLASS', None) and \
post_process_cfg.NMS_CONFIG.SCORE_TYPE == 'score_by_class':
nms_scores = self.set_nms_score_by_class(
iou_preds, cls_preds, label_preds, post_process_cfg.NMS_CONFIG.SCORE_BY_CLASS
)
elif post_process_cfg.NMS_CONFIG.get('SCORE_TYPE', None) == 'iou' or \
post_process_cfg.NMS_CONFIG.get('SCORE_TYPE', None) is None:
nms_scores = iou_preds
elif post_process_cfg.NMS_CONFIG.SCORE_TYPE == 'cls':
nms_scores = cls_preds
elif post_process_cfg.NMS_CONFIG.SCORE_TYPE == 'weighted_iou_cls':
nms_scores = post_process_cfg.NMS_CONFIG.SCORE_WEIGHTS.iou * iou_preds + \
post_process_cfg.NMS_CONFIG.SCORE_WEIGHTS.cls * cls_preds
elif post_process_cfg.NMS_CONFIG.SCORE_TYPE == 'num_pts_iou_cls':
point_mask = (batch_dict['points'][:, 0] == batch_mask)
batch_points = batch_dict['points'][point_mask][:, 1:4]

num_pts_in_gt = roiaware_pool3d_utils.points_in_boxes_cpu(
batch_points.cpu(), box_preds[:, 0:7].cpu()
).sum(dim=1).float().cuda()

score_thresh_cfg = post_process_cfg.NMS_CONFIG.SCORE_THRESH
nms_scores = self.cal_scores_by_npoints(
cls_preds, iou_preds, num_pts_in_gt,
score_thresh_cfg.cls, score_thresh_cfg.iou
)
else:
raise NotImplementedError

selected, selected_scores = class_agnostic_nms(
box_scores=nms_scores, box_preds=box_preds,
nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.SCORE_THRESH
)

if post_process_cfg.OUTPUT_RAW_SCORE:
raise NotImplementedError

final_scores = selected_scores
final_labels = label_preds[selected]
final_boxes = box_preds[selected]

recall_dict = self.generate_recall_record(
box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds,
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
thresh_list=post_process_cfg.RECALL_THRESH_LIST
)

record_dict = {
'pred_boxes': final_boxes,
'pred_scores': final_scores,
'pred_labels': final_labels,
'pred_cls_scores': cls_preds[selected],
'pred_iou_scores': iou_preds[selected]
}

pred_dicts.append(record_dict)

return pred_dicts, recall_dict
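
For reference, the SCORE_TYPE branches handled in post_processing above map onto the NMS_CONFIG fields sketched below as a plain Python dict. The keys mirror what the code reads, but the numeric values and per-class choices are illustrative placeholders, not the settings shipped in tools/cfgs/kitti_models/second_iou.yaml, and only the score-related fields are shown.

# Illustrative only: keys follow SECONDNetIoU.post_processing; values are placeholders.
NMS_CONFIG = {
    'MULTI_CLASSES_NMS': False,
    'SCORE_TYPE': 'num_pts_iou_cls',  # or 'iou', 'cls', 'weighted_iou_cls', 'score_by_class'
    'SCORE_WEIGHTS': {'iou': 0.8, 'cls': 0.2},  # used by 'weighted_iou_cls'
    'SCORE_THRESH': {'cls': 10, 'iou': 100},    # used by 'num_pts_iou_cls'
    'SCORE_BY_CLASS': {'Car': 'iou', 'Pedestrian': 'cls', 'Cyclist': 'cls'},  # used by 'score_by_class'
}

With 'num_pts_iou_cls' and the default thresholds of cal_scores_by_npoints (cls_thresh=10, iou_thresh=100), a box containing 5 points keeps its classification score, a box with 200 points keeps its predicted IoU, and a box with 55 points gets the even blend 0.5 * cls + 0.5 * iou.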
2 changes: 2 additions & 0 deletions pcdet/models/roi_heads/__init__.py
@@ -1,11 +1,13 @@
from .partA2_head import PartA2FCHead
from .pointrcnn_head import PointRCNNHead
from .pvrcnn_head import PVRCNNHead
from .second_head import SECONDHead
from .roi_head_template import RoIHeadTemplate

__all__ = {
'RoIHeadTemplate': RoIHeadTemplate,
'PartA2FCHead': PartA2FCHead,
'PVRCNNHead': PVRCNNHead,
'SECONDHead': SECONDHead,
'PointRCNNHead': PointRCNNHead
}
178 changes: 178 additions & 0 deletions pcdet/models/roi_heads/second_head.py
@@ -0,0 +1,178 @@
import torch
import torch.nn as nn
from .roi_head_template import RoIHeadTemplate
from ...utils import common_utils, loss_utils


class SECONDHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg

GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
pre_channel = self.model_cfg.ROI_GRID_POOL.IN_CHANNEL * GRID_SIZE * GRID_SIZE

shared_fc_list = []
for k in range(0, self.model_cfg.SHARED_FC.__len__()):
shared_fc_list.extend([
nn.Conv1d(pre_channel, self.model_cfg.SHARED_FC[k], kernel_size=1, bias=False),
nn.BatchNorm1d(self.model_cfg.SHARED_FC[k]),
nn.ReLU()
])
pre_channel = self.model_cfg.SHARED_FC[k]

if k != self.model_cfg.SHARED_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0:
shared_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))

self.shared_fc_layer = nn.Sequential(*shared_fc_list)

self.iou_layers = self.make_fc_layers(
input_channels=pre_channel, output_channels=1, fc_list=self.model_cfg.IOU_FC
)
self.init_weights(weight_init='xavier')

def init_weights(self, weight_init='xavier'):
if weight_init == 'kaiming':
init_func = nn.init.kaiming_normal_
elif weight_init == 'xavier':
init_func = nn.init.xavier_normal_
elif weight_init == 'normal':
init_func = nn.init.normal_
else:
raise NotImplementedError

for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
if weight_init == 'normal':
init_func(m.weight, mean=0, std=0.001)
else:
init_func(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def roi_grid_pool(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
spatial_features_2d: (B, C, H, W)
Returns:
"""
batch_size = batch_dict['batch_size']
rois = batch_dict['rois'].detach()
spatial_features_2d = batch_dict['spatial_features_2d'].detach()
height, width = spatial_features_2d.size(2), spatial_features_2d.size(3)

dataset_cfg = batch_dict['dataset_cfg']
min_x = dataset_cfg.POINT_CLOUD_RANGE[0]
min_y = dataset_cfg.POINT_CLOUD_RANGE[1]
voxel_size_x = dataset_cfg.DATA_PROCESSOR[-1].VOXEL_SIZE[0]
voxel_size_y = dataset_cfg.DATA_PROCESSOR[-1].VOXEL_SIZE[1]
down_sample_ratio = self.model_cfg.ROI_GRID_POOL.DOWNSAMPLE_RATIO

pooled_features_list = []
torch.backends.cudnn.enabled = False
for b_id in range(batch_size):
# Map global boxes coordinates to feature map coordinates
x1 = (rois[b_id, :, 0] - rois[b_id, :, 3] / 2 - min_x) / (voxel_size_x * down_sample_ratio)
x2 = (rois[b_id, :, 0] + rois[b_id, :, 3] / 2 - min_x) / (voxel_size_x * down_sample_ratio)
y1 = (rois[b_id, :, 1] - rois[b_id, :, 4] / 2 - min_y) / (voxel_size_y * down_sample_ratio)
y2 = (rois[b_id, :, 1] + rois[b_id, :, 4] / 2 - min_y) / (voxel_size_y * down_sample_ratio)

angle, _ = common_utils.check_numpy_to_torch(rois[b_id, :, 6])

cosa = torch.cos(angle)
sina = torch.sin(angle)

theta = torch.stack((
(x2 - x1) / (width - 1) * cosa, (x2 - x1) / (width - 1) * (-sina), (x1 + x2 - width + 1) / (width - 1),
(y2 - y1) / (height - 1) * sina, (y2 - y1) / (height - 1) * cosa, (y1 + y2 - height + 1) / (height - 1)
), dim=1).view(-1, 2, 3).float()

grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
grid = nn.functional.affine_grid(
theta,
torch.Size((rois.size(1), spatial_features_2d.size(1), grid_size, grid_size))
)

pooled_features = nn.functional.grid_sample(
spatial_features_2d[b_id].unsqueeze(0).expand(rois.size(1), spatial_features_2d.size(1), height, width),
grid
)

pooled_features_list.append(pooled_features)

torch.backends.cudnn.enabled = True
pooled_features = torch.cat(pooled_features_list, dim=0)

return pooled_features

def forward(self, batch_dict):
"""
:param batch_dict: input data dict
:return: batch_dict
"""
targets_dict = self.proposal_layer(
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']

# RoI aware pooling
pooled_features = self.roi_grid_pool(batch_dict) # (BxN, C, 7, 7)
batch_size_rcnn = pooled_features.shape[0]

shared_features = self.shared_fc_layer(pooled_features.view(batch_size_rcnn, -1, 1))
rcnn_iou = self.iou_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B*N, 1)

if not self.training:
batch_dict['batch_cls_preds'] = rcnn_iou.view(batch_dict['batch_size'], -1, rcnn_iou.shape[-1])
batch_dict['batch_box_preds'] = batch_dict['rois']
batch_dict['cls_preds_normalized'] = False
else:
targets_dict['rcnn_iou'] = rcnn_iou

self.forward_ret_dict = targets_dict

return batch_dict

def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict
rcnn_loss = 0
rcnn_loss_cls, cls_tb_dict = self.get_box_iou_layer_loss(self.forward_ret_dict)
rcnn_loss += rcnn_loss_cls
tb_dict.update(cls_tb_dict)

tb_dict['rcnn_loss'] = rcnn_loss.item()
return rcnn_loss, tb_dict

def get_box_iou_layer_loss(self, forward_ret_dict):
loss_cfgs = self.model_cfg.LOSS_CONFIG
rcnn_iou = forward_ret_dict['rcnn_iou']
rcnn_iou_labels = forward_ret_dict['rcnn_cls_labels'].view(-1)
rcnn_iou_flat = rcnn_iou.view(-1)
if loss_cfgs.IOU_LOSS == 'BinaryCrossEntropy':
batch_loss_iou = nn.functional.binary_cross_entropy_with_logits(
rcnn_iou_flat,
rcnn_iou_labels.float(), reduction='none'
)
elif loss_cfgs.IOU_LOSS == 'L2':
batch_loss_iou = nn.functional.mse_loss(rcnn_iou_flat, rcnn_iou_labels, reduction='none')
elif loss_cfgs.IOU_LOSS == 'smoothL1':
diff = rcnn_iou_flat - rcnn_iou_labels
batch_loss_iou = loss_utils.WeightedSmoothL1Loss.smooth_l1_loss(diff, 1.0 / 9.0)
elif loss_cfgs.IOU_LOSS == 'focalbce':
batch_loss_iou = loss_utils.sigmoid_focal_cls_loss(rcnn_iou_flat, rcnn_iou_labels)
else:
raise NotImplementedError

iou_valid_mask = (rcnn_iou_labels >= 0).float()
rcnn_loss_iou = (batch_loss_iou * iou_valid_mask).sum() / torch.clamp(iou_valid_mask.sum(), min=1.0)

rcnn_loss_iou = rcnn_loss_iou * loss_cfgs.LOSS_WEIGHTS['rcnn_iou_weight']
tb_dict = {'rcnn_loss_iou': rcnn_loss_iou.item()}
return rcnn_loss_iou, tb_dict
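
roi_grid_pool above crops a rotated, box-aligned GRID_SIZE x GRID_SIZE patch from the BEV feature map by building a per-box 2x3 affine matrix and sampling with affine_grid / grid_sample. The standalone sketch below reproduces just that mechanic for a single made-up box; the shapes and box values are arbitrary, and the (W - 1) / (H - 1) normalization corresponds to align_corners=True semantics.

import math
import torch
import torch.nn as nn

feat = torch.randn(1, 64, 200, 176)          # (B, C, H, W) BEV feature map
H, W, grid_size = feat.size(2), feat.size(3), 7

# one box in feature-map pixel coordinates: x/y extents plus a yaw angle
x1, x2 = 40.0, 60.0
y1, y2 = 60.0, 90.0
yaw = 0.3
cosa, sina = math.cos(yaw), math.sin(yaw)

# 2x3 affine matrix mapping the output grid to normalized [-1, 1] input coords
theta = torch.tensor([[
    [(x2 - x1) / (W - 1) * cosa, (x2 - x1) / (W - 1) * (-sina), (x1 + x2 - W + 1) / (W - 1)],
    [(y2 - y1) / (H - 1) * sina, (y2 - y1) / (H - 1) * cosa,    (y1 + y2 - H + 1) / (H - 1)],
]])

grid = nn.functional.affine_grid(
    theta, torch.Size((1, feat.size(1), grid_size, grid_size)), align_corners=True
)
patch = nn.functional.grid_sample(feat, grid, align_corners=True)  # (1, 64, 7, 7)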

4 comments on commit e3bec15

@colian commented on e3bec15 May 19, 2021

The provided model for SECOND_IOU cannot be reproduced. Any other information about SCORE_TYPE?

@jihanyang (Collaborator, Author) commented:

I will check it later.

@jihanyang (Collaborator, Author) commented:

> The provided model for SECOND_IOU cannot be reproduced. Any other information about SCORE_TYPE?

I reproduced the result successfully. Can you reproduce the result of SECOND?

2021-05-25 16:48:09,630 INFO *************** EPOCH 80 EVALUATION *****************
eval: 100%|██████████| 118/118 [00:30<00:00, 3.81it/s, recall_0.3=(2136, 2136) / 2210]
2021-05-25 16:48:41,134 INFO *************** Performance of EPOCH 80 *****************
2021-05-25 16:48:41,137 INFO Generate label finished(sec_per_example: 0.0084 second).
2021-05-25 16:48:41,137 INFO recall_roi_0.3: 0.971216
2021-05-25 16:48:41,137 INFO recall_rcnn_0.3: 0.971216
2021-05-25 16:48:41,137 INFO recall_roi_0.5: 0.933273
2021-05-25 16:48:41,137 INFO recall_rcnn_0.5: 0.933273
2021-05-25 16:48:41,138 INFO recall_roi_0.7: 0.712612
2021-05-25 16:48:41,138 INFO recall_rcnn_0.7: 0.712612
2021-05-25 16:48:41,141 INFO Average predicted number of objects(3769 samples): 13.678
2021-05-25 16:49:09,050 INFO Car AP@0.70, 0.70, 0.70:
bbox AP:90.8298, 89.8464, 89.2890
bev AP:90.2490, 88.0956, 87.1298
3d AP:89.2254, 79.1751, 78.2560
aos AP:90.81, 89.73, 89.10
Car AP_R40@0.70, 0.70, 0.70:
bbox AP:95.9564, 94.4257, 92.0571
bev AP:92.8637, 89.9651, 88.1752
3d AP:91.1837, 82.4682, 79.8206
aos AP:95.92, 94.28, 91.84
Car AP@0.70, 0.50, 0.50:
bbox AP:90.8298, 89.8464, 89.2890
bev AP:90.8229, 89.9098, 89.4771
3d AP:90.8229, 89.8907, 89.4380
aos AP:90.81, 89.73, 89.10
Car AP_R40@0.70, 0.50, 0.50:
bbox AP:95.9564, 94.4257, 92.0571
bev AP:95.9862, 94.7134, 94.3089
3d AP:95.9769, 94.6662, 94.1957
aos AP:95.92, 94.28, 91.84
Pedestrian AP@0.50, 0.50, 0.50:
bbox AP:72.6088, 67.1919, 64.0639
bev AP:67.5701, 60.5777, 56.7529
3d AP:62.7718, 56.6545, 51.3505
aos AP:68.24, 62.80, 59.55
Pedestrian AP_R40@0.50, 0.50, 0.50:
bbox AP:73.3103, 67.7723, 64.1540
bev AP:67.7177, 60.3325, 55.6805
3d AP:63.2151, 55.7024, 50.3015
aos AP:68.40, 62.82, 59.10
Pedestrian AP@0.50, 0.25, 0.25:
bbox AP:72.6088, 67.1919, 64.0639
bev AP:79.2449, 74.9767, 71.1995
3d AP:78.8350, 74.7658, 70.7464
aos AP:68.24, 62.80, 59.55
Pedestrian AP_R40@0.50, 0.25, 0.25:
bbox AP:73.3103, 67.7723, 64.1540
bev AP:80.5269, 76.2809, 72.2288
3d AP:80.2924, 76.0266, 71.6252
aos AP:68.40, 62.82, 59.10
Cyclist AP@0.50, 0.50, 0.50:
bbox AP:92.6116, 82.0762, 78.2363
bev AP:91.3906, 75.8321, 71.1691
3d AP:86.9092, 70.9220, 66.8777
aos AP:92.29, 81.11, 77.37
Cyclist AP_R40@0.50, 0.50, 0.50:
bbox AP:95.1193, 84.1542, 79.6162
bev AP:93.6958, 77.0219, 72.1574
3d AP:90.3740, 71.6101, 67.0025
aos AP:94.78, 83.16, 78.63
Cyclist AP@0.50, 0.25, 0.25:
bbox AP:92.6116, 82.0762, 78.2363
bev AP:91.8523, 79.3298, 75.3813
3d AP:91.8523, 79.3298, 75.3813
aos AP:92.29, 81.11, 77.37
Cyclist AP_R40@0.50, 0.25, 0.25:
bbox AP:95.1193, 84.1542, 79.6162
bev AP:94.4396, 81.0085, 76.5129
3d AP:94.4396, 81.0085, 76.5129
aos AP:94.78, 83.16, 78.63

@MartinHahner (Contributor) commented:

I was also able to reproduce the numbers.
Shown here are the R40 results (easy, moderate, hard from left to right).
[image: screenshot of the R40 evaluation results]
