Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] upgrade PointPillars performace on dev branch #1166

Merged
merged 14 commits into from
Feb 17, 2022
9 changes: 5 additions & 4 deletions configs/_base_/models/hv_pointpillars_secfpn_kitti.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@
in_channels=384,
feat_channels=384,
use_direction_classifier=True,
assign_per_class=True,
anchor_generator=dict(
type='Anchor3DRangeGenerator',
type='AlignedAnchor3DRangeGenerator',
ranges=[
[0, -39.68, -0.6, 70.4, 39.68, -0.6],
[0, -39.68, -0.6, 70.4, 39.68, -0.6],
[0, -39.68, -1.78, 70.4, 39.68, -1.78],
[0, -39.68, -0.6, 69.12, 39.68, -0.6],
[0, -39.68, -0.6, 69.12, 39.68, -0.6],
[0, -39.68, -1.78, 69.12, 39.68, -1.78],
],
sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]],
rotations=[0, 1.57],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,15 @@
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)),
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10))
sample_groups=dict(Car=15, Pedestrian=15, Cyclist=15))

# PointPillars uses different augmentation hyper parameters
train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
type='ObjectNoise',
num_try=100,
translation_std=[0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0],
rot_range=[-0.15707963267, 0.15707963267]),
dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=False),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(
type='GlobalRotScaleTrans',
Expand Down
25 changes: 24 additions & 1 deletion mmdet3d/datasets/kitti_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,33 @@ def get_ann_info(self, index):
- gt_bboxes (np.ndarray): 2D ground truth bboxes.
- gt_labels (np.ndarray): Labels of ground truths.
- gt_names (list[str]): Class names of ground truths.
- difficulty (int): Difficulty defined by KITTI.
0, 1, 2 represent xxxxx respectively.
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
rect = info['calib']['R0_rect'].astype(np.float32)
Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)

if 'plane' in info:
# convert ground plane to velodyne coordinates
reverse = np.linalg.inv(rect @ Trv2c)

(plane_norm_cam,
plane_off_cam) = (info['plane'][:3],
-info['plane'][:3] * info['plane'][3])
plane_norm_lidar = \
(reverse[:3, :3] @ plane_norm_cam[:, None])[:, 0]
plane_off_lidar = (
reverse[:3, :3] @ plane_off_cam[:, None][:, 0] +
reverse[:3, 3])
plane_lidar = np.zeros_like(plane_norm_lidar, shape=(4, ))
plane_lidar[:3] = plane_norm_lidar
plane_lidar[3] = -plane_norm_lidar.T @ plane_off_lidar
else:
plane_l = None
ZCMax marked this conversation as resolved.
Show resolved Hide resolved

difficulty = info['annos']['difficulty']
annos = info['annos']
# we need other objects to avoid collision when sample
annos = self.remove_dontcare(annos)
Expand Down Expand Up @@ -191,7 +212,9 @@ def get_ann_info(self, index):
gt_labels_3d=gt_labels_3d,
bboxes=gt_bboxes,
labels=gt_labels,
gt_names=gt_names)
gt_names=gt_names,
plane=plane_l,
difficulty=difficulty)
return anns_results

def drop_arrays_by_name(self, gt_names, used_classes):
Expand Down
11 changes: 10 additions & 1 deletion mmdet3d/datasets/pipelines/dbsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def filter_by_min_points(db_infos, min_gt_points_dict):
db_infos[name] = filtered_infos
return db_infos

def sample_all(self, gt_bboxes, gt_labels, img=None):
def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None):
"""Sampling all categories of bboxes.

Args:
Expand Down Expand Up @@ -263,6 +263,15 @@ def sample_all(self, gt_bboxes, gt_labels, img=None):

gt_labels = np.array([self.cat2label[s['name']] for s in sampled],
dtype=np.long)

if ground_plane is not None:
xyz = sampled_gt_bboxes[:, :3]
dz = (ground_plane[:3][None, :] *
xyz).sum(-1) + ground_plane[3]
sampled_gt_bboxes[:, 2] -= dz
for i, s_points in enumerate(s_points_list):
s_points.tensor[:, 2].sub_(dz[i])

ret = {
'gt_labels_3d':
gt_labels,
Expand Down
15 changes: 13 additions & 2 deletions mmdet3d/datasets/pipelines/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,17 @@ class ObjectSample(object):
sample_2d (bool): Whether to also paste 2D image patch to the images
This should be true when applying multi-modality cut-and-paste.
Defaults to False.
use_ground_plane (bool): Whether to use gound plane to adjust the
3D labels.
"""

def __init__(self, db_sampler, sample_2d=False):
def __init__(self, db_sampler, sample_2d=False, use_ground_plane=False):
self.sampler_cfg = db_sampler
self.sample_2d = sample_2d
if 'type' not in db_sampler.keys():
db_sampler['type'] = 'DataBaseSampler'
self.db_sampler = build_from_cfg(db_sampler, OBJECTSAMPLERS)
self.use_ground_plane = use_ground_plane

@staticmethod
def remove_points_in_boxes(points, boxes):
Expand Down Expand Up @@ -301,6 +304,11 @@ def __call__(self, input_dict):
gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d']

if self.use_ground_plane and 'plane' in input_dict['ann_info']:
ground_plane = input_dict['ann_info']['plane']
input_dict['plane'] = ground_plane
else:
ground_plane = None
# change to float for blending operation
points = input_dict['points']
if self.sample_2d:
Expand All @@ -314,7 +322,10 @@ def __call__(self, input_dict):
img=img)
else:
sampled_dict = self.db_sampler.sample_all(
gt_bboxes_3d.tensor.numpy(), gt_labels_3d, img=None)
gt_bboxes_3d.tensor.numpy(),
gt_labels_3d,
img=None,
ground_plane=ground_plane)

if sampled_dict is not None:
sampled_gt_bboxes_3d = sampled_dict['gt_bboxes_3d']
Expand Down
24 changes: 17 additions & 7 deletions mmdet3d/models/voxel_encoders/pillar_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class PillarFeatureNet(nn.Module):

The network prepares the pillar features and performs forward pass
through PFNLayers.

Args:
in_channels (int, optional): Number of input features,
either x, y, z or x, y, z, r. Defaults to 4.
Expand Down Expand Up @@ -54,7 +53,7 @@ def __init__(self,
if with_cluster_center:
in_channels += 3
if with_voxel_center:
in_channels += 2
in_channels += 3
if with_distance:
in_channels += 1
self._with_distance = with_distance
Expand Down Expand Up @@ -84,8 +83,10 @@ def __init__(self,
# Need pillar (voxel) size and x/y offset in order to calculate offset
self.vx = voxel_size[0]
self.vy = voxel_size[1]
self.vz = voxel_size[2]
self.x_offset = self.vx / 2 + point_cloud_range[0]
self.y_offset = self.vy / 2 + point_cloud_range[1]
self.z_offset = self.vz / 2 + point_cloud_range[2]
self.point_cloud_range = point_cloud_range

@force_fp32(out_fp16=True)
Expand All @@ -97,7 +98,6 @@ def forward(self, features, num_points, coors):
(N, M, C).
num_points (torch.Tensor): Number of points in each pillar.
coors (torch.Tensor): Coordinates of each voxel.

Returns:
torch.Tensor: Features of pillars.
"""
Expand All @@ -114,21 +114,27 @@ def forward(self, features, num_points, coors):
dtype = features.dtype
if self._with_voxel_center:
if not self.legacy:
f_center = torch.zeros_like(features[:, :, :2])
f_center = torch.zeros_like(features[:, :, :3])
f_center[:, :, 0] = features[:, :, 0] - (
coors[:, 3].to(dtype).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = features[:, :, 1] - (
coors[:, 2].to(dtype).unsqueeze(1) * self.vy +
self.y_offset)
f_center[:, :, 2] = features[:, :, 2] - (
coors[:, 1].to(dtype).unsqueeze(1) * self.vz +
self.z_offset)
else:
f_center = features[:, :, :2]
f_center = features[:, :, :3]
f_center[:, :, 0] = f_center[:, :, 0] - (
coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = f_center[:, :, 1] - (
coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
self.y_offset)
f_center[:, :, 2] = f_center[:, :, 2] - (
coors[:, 1].type_as(features).unsqueeze(1) * self.vz +
self.z_offset)
features_ls.append(f_center)

if self._with_distance:
Expand Down Expand Up @@ -177,6 +183,8 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
mode (str, optional): The mode to gather point features. Options are
'max' or 'avg'. Defaults to 'max'.
legacy (bool, optional): Whether to use the new behavior or
the original behavior. Defaults to True.
"""

def __init__(self,
Expand All @@ -188,7 +196,8 @@ def __init__(self,
voxel_size=(0.2, 0.2, 4),
point_cloud_range=(0, -40, -3, 70.4, 40, 1),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
mode='max'):
mode='max',
legacy=True):
super(DynamicPillarFeatureNet, self).__init__(
in_channels,
feat_channels,
Expand All @@ -198,7 +207,8 @@ def __init__(self,
voxel_size=voxel_size,
point_cloud_range=point_cloud_range,
norm_cfg=norm_cfg,
mode=mode)
mode=mode,
legacy=legacy)
self.fp16_enabled = False
feat_channels = [self.in_channels] + list(feat_channels)
pfn_layers = []
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmdet,mmseg,mmdet3d
known_third_party = cv2,imageio,indoor3d_util,load_scannet_data,lyft_dataset_sdk,m2r,matplotlib,mmcv,nuimages,numba,numpy,nuscenes,pandas,plyfile,pycocotools,pyquaternion,pytest,pytorch_sphinx_theme,recommonmark,requests,scannet_utils,scipy,seaborn,shapely,skimage,tensorflow,terminaltables,torch,trimesh,ts,waymo_open_dataset
known_third_party = cv2,imageio,indoor3d_util,load_scannet_data,lyft_dataset_sdk,m2r,matplotlib,mmcv,nuimages,numba,numpy,nuscenes,pandas,plyfile,pycocotools,pyquaternion,pytest,pytorch_sphinx_theme,recommonmark,requests,scannet_utils,scipy,seaborn,shapely,skimage,sphinx,tensorflow,terminaltables,torch,trimesh,ts,waymo_open_dataset
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

Expand Down