diff --git a/configs/_base_/models/hv_pointpillars_secfpn_kitti.py b/configs/_base_/models/hv_pointpillars_secfpn_kitti.py
index 33c80a171c..ac46475d6e 100644
--- a/configs/_base_/models/hv_pointpillars_secfpn_kitti.py
+++ b/configs/_base_/models/hv_pointpillars_secfpn_kitti.py
@@ -34,6 +34,7 @@
         in_channels=384,
         feat_channels=384,
         use_direction_classifier=True,
+        assign_per_class=True,
         anchor_generator=dict(
             type='AlignedAnchor3DRangeGenerator',
             ranges=[
diff --git a/configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py b/configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py
index 5c5c939de5..2611e86d3a 100644
--- a/configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py
+++ b/configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py
@@ -15,21 +15,15 @@
     rate=1.0,
     prepare=dict(
         filter_by_difficulty=[-1],
-        filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)),
+        filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
     classes=class_names,
-    sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10))
+    sample_groups=dict(Car=15, Pedestrian=15, Cyclist=15))
 
 # PointPillars uses different augmentation hyper parameters
 train_pipeline = [
     dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
     dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
-    dict(type='ObjectSample', db_sampler=db_sampler),
-    dict(
-        type='ObjectNoise',
-        num_try=100,
-        translation_std=[0.25, 0.25, 0.25],
-        global_rot_range=[0.0, 0.0],
-        rot_range=[-0.15707963267, 0.15707963267]),
+    dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=False),
     dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
     dict(
         type='GlobalRotScaleTrans',
diff --git a/mmdet3d/datasets/kitti_dataset.py b/mmdet3d/datasets/kitti_dataset.py
index 632456dbe5..9a864efff2 100644
--- a/mmdet3d/datasets/kitti_dataset.py
+++ b/mmdet3d/datasets/kitti_dataset.py
@@ -153,13 +153,32 @@ def get_ann_info(self, index):
                 - gt_bboxes (np.ndarray): 2D ground truth bboxes.
                 - gt_labels (np.ndarray): Labels of ground truths.
                 - gt_names (list[str]): Class names of ground truths.
-                - difficulty (int): kitti difficulty.
+                - difficulty (int): Difficulty defined by KITTI.
+                    0, 1 and 2 represent easy, moderate and hard, respectively.
         """
         # Use index to get the annos, thus the evalhook could also use this api
         info = self.data_infos[index]
         rect = info['calib']['R0_rect'].astype(np.float32)
         Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
 
+        if 'plane' in info:
+            # convert ground plane to velodyne coordinates
+            reverse = np.linalg.inv(rect @ Trv2c)
+
+            (plane_norm_cam,
+             plane_off_cam) = (info['plane'][:3],
+                               -info['plane'][:3] * info['plane'][3])
+            plane_norm_lidar = \
+                (reverse[:3, :3] @ plane_norm_cam[:, None])[:, 0]
+            plane_off_lidar = (
+                reverse[:3, :3] @ plane_off_cam +
+                reverse[:3, 3])
+            plane_lidar = np.zeros_like(plane_norm_lidar, shape=(4, ))
+            plane_lidar[:3] = plane_norm_lidar
+            plane_lidar[3] = -plane_norm_lidar.T @ plane_off_lidar
+        else:
+            plane_lidar = None
+
         difficulty = info['annos']['difficulty']
         annos = info['annos']
         # we need other objects to avoid collision when sample
@@ -195,6 +214,7 @@ def get_ann_info(self, index):
             bboxes=gt_bboxes,
             labels=gt_labels,
             gt_names=gt_names,
+            plane=plane_lidar,
             difficulty=difficulty)
         return anns_results
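Note on the math in the hunk above: a KITTI plane file stores `(a, b, c, d)` with `a*x + b*y + c*z + d = 0` in rectified camera coordinates. The code rotates the normal into Velodyne coordinates, transforms a point on the plane (`-d * n`), and rebuilds the offset as `-n·p`. A self-contained sketch of the same math; the identity calibration and the plane value are hypothetical stand-ins, not taken from a real frame:

```python
import numpy as np

plane_cam = np.array([0., -1., 0., 1.65], dtype=np.float32)  # ground ~1.65 m below the camera
rect = np.eye(4, dtype=np.float32)     # stand-in for info['calib']['R0_rect']
Trv2c = np.eye(4, dtype=np.float32)    # stand-in for info['calib']['Tr_velo_to_cam']
reverse = np.linalg.inv(rect @ Trv2c)  # camera -> velodyne

norm_cam = plane_cam[:3]
off_cam = -plane_cam[:3] * plane_cam[3]                 # a point on the plane: -d * n
norm_lidar = reverse[:3, :3] @ norm_cam                 # normals only rotate
off_lidar = reverse[:3, :3] @ off_cam + reverse[:3, 3]  # points also translate

plane_lidar = np.concatenate([norm_lidar, [-norm_lidar @ off_lidar]])
# the transformed point must still satisfy the plane equation
assert np.isclose(plane_lidar[:3] @ off_lidar + plane_lidar[3], 0.)
```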
""" # Use index to get the annos, thus the evalhook could also use this api info = self.data_infos[index] rect = info['calib']['R0_rect'].astype(np.float32) Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32) + if 'plane' in info: + # convert ground plane to velodyne coordinates + reverse = np.linalg.inv(rect @ Trv2c) + + (plane_norm_cam, + plane_off_cam) = (info['plane'][:3], + -info['plane'][:3] * info['plane'][3]) + plane_norm_lidar = \ + (reverse[:3, :3] @ plane_norm_cam[:, None])[:, 0] + plane_off_lidar = ( + reverse[:3, :3] @ plane_off_cam[:, None][:, 0] + + reverse[:3, 3]) + plane_lidar = np.zeros_like(plane_norm_lidar, shape=(4, )) + plane_lidar[:3] = plane_norm_lidar + plane_lidar[3] = -plane_norm_lidar.T @ plane_off_lidar + else: + plane_lidar = None + difficulty = info['annos']['difficulty'] annos = info['annos'] # we need other objects to avoid collision when sample @@ -195,6 +214,7 @@ def get_ann_info(self, index): bboxes=gt_bboxes, labels=gt_labels, gt_names=gt_names, + plane=plane_lidar, difficulty=difficulty) return anns_results diff --git a/mmdet3d/datasets/pipelines/dbsampler.py b/mmdet3d/datasets/pipelines/dbsampler.py index d2e844e6d1..f0a7074441 100644 --- a/mmdet3d/datasets/pipelines/dbsampler.py +++ b/mmdet3d/datasets/pipelines/dbsampler.py @@ -189,7 +189,7 @@ def filter_by_min_points(db_infos, min_gt_points_dict): db_infos[name] = filtered_infos return db_infos - def sample_all(self, gt_bboxes, gt_labels, img=None): + def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None): """Sampling all categories of bboxes. Args: @@ -264,6 +264,15 @@ def sample_all(self, gt_bboxes, gt_labels, img=None): gt_labels = np.array([self.cat2label[s['name']] for s in sampled], dtype=np.long) + + if ground_plane is not None: + xyz = sampled_gt_bboxes[:, :3] + dz = (ground_plane[:3][None, :] * + xyz).sum(-1) + ground_plane[3] + sampled_gt_bboxes[:, 2] -= dz + for i, s_points in enumerate(s_points_list): + s_points.tensor[:, 2].sub_(dz[i]) + ret = { 'gt_labels_3d': gt_labels, diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index f0ca81f353..6269dc217b 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -268,14 +268,17 @@ class ObjectSample(object): sample_2d (bool): Whether to also paste 2D image patch to the images This should be true when applying multi-modality cut-and-paste. Defaults to False. + use_ground_plane (bool): Whether to use gound plane to adjust the + 3D labels. 
""" - def __init__(self, db_sampler, sample_2d=False): + def __init__(self, db_sampler, sample_2d=False, use_ground_plane=False): self.sampler_cfg = db_sampler self.sample_2d = sample_2d if 'type' not in db_sampler.keys(): db_sampler['type'] = 'DataBaseSampler' self.db_sampler = build_from_cfg(db_sampler, OBJECTSAMPLERS) + self.use_ground_plane = use_ground_plane @staticmethod def remove_points_in_boxes(points, boxes): @@ -306,6 +309,11 @@ def __call__(self, input_dict): gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_labels_3d = input_dict['gt_labels_3d'] + if self.use_ground_plane and 'plane' in input_dict['ann_info']: + ground_plane = input_dict['ann_info']['plane'] + input_dict['plane'] = ground_plane + else: + ground_plane = None # change to float for blending operation points = input_dict['points'] if self.sample_2d: @@ -319,7 +327,10 @@ def __call__(self, input_dict): img=img) else: sampled_dict = self.db_sampler.sample_all( - gt_bboxes_3d.tensor.numpy(), gt_labels_3d, img=None) + gt_bboxes_3d.tensor.numpy(), + gt_labels_3d, + img=None, + ground_plane=ground_plane) if sampled_dict is not None: sampled_gt_bboxes_3d = sampled_dict['gt_bboxes_3d'] diff --git a/mmdet3d/models/voxel_encoders/pillar_encoder.py b/mmdet3d/models/voxel_encoders/pillar_encoder.py index 45b8d53886..c91cf282ad 100644 --- a/mmdet3d/models/voxel_encoders/pillar_encoder.py +++ b/mmdet3d/models/voxel_encoders/pillar_encoder.py @@ -15,7 +15,6 @@ class PillarFeatureNet(nn.Module): The network prepares the pillar features and performs forward pass through PFNLayers. - Args: in_channels (int, optional): Number of input features, either x, y, z or x, y, z, r. Defaults to 4. @@ -54,7 +53,7 @@ def __init__(self, if with_cluster_center: in_channels += 3 if with_voxel_center: - in_channels += 2 + in_channels += 3 if with_distance: in_channels += 1 self._with_distance = with_distance @@ -84,8 +83,10 @@ def __init__(self, # Need pillar (voxel) size and x/y offset in order to calculate offset self.vx = voxel_size[0] self.vy = voxel_size[1] + self.vz = voxel_size[2] self.x_offset = self.vx / 2 + point_cloud_range[0] self.y_offset = self.vy / 2 + point_cloud_range[1] + self.z_offset = self.vz / 2 + point_cloud_range[2] self.point_cloud_range = point_cloud_range @force_fp32(out_fp16=True) @@ -97,7 +98,6 @@ def forward(self, features, num_points, coors): (N, M, C). num_points (torch.Tensor): Number of points in each pillar. coors (torch.Tensor): Coordinates of each voxel. - Returns: torch.Tensor: Features of pillars. 
""" @@ -114,21 +114,27 @@ def forward(self, features, num_points, coors): dtype = features.dtype if self._with_voxel_center: if not self.legacy: - f_center = torch.zeros_like(features[:, :, :2]) + f_center = torch.zeros_like(features[:, :, :3]) f_center[:, :, 0] = features[:, :, 0] - ( coors[:, 3].to(dtype).unsqueeze(1) * self.vx + self.x_offset) f_center[:, :, 1] = features[:, :, 1] - ( coors[:, 2].to(dtype).unsqueeze(1) * self.vy + self.y_offset) + f_center[:, :, 2] = features[:, :, 2] - ( + coors[:, 1].to(dtype).unsqueeze(1) * self.vz + + self.z_offset) else: - f_center = features[:, :, :2] + f_center = features[:, :, :3] f_center[:, :, 0] = f_center[:, :, 0] - ( coors[:, 3].type_as(features).unsqueeze(1) * self.vx + self.x_offset) f_center[:, :, 1] = f_center[:, :, 1] - ( coors[:, 2].type_as(features).unsqueeze(1) * self.vy + self.y_offset) + f_center[:, :, 2] = f_center[:, :, 2] - ( + coors[:, 1].type_as(features).unsqueeze(1) * self.vz + + self.z_offset) features_ls.append(f_center) if self._with_distance: @@ -177,6 +183,8 @@ class DynamicPillarFeatureNet(PillarFeatureNet): Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01). mode (str, optional): The mode to gather point features. Options are 'max' or 'avg'. Defaults to 'max'. + legacy (bool, optional): Whether to use the new behavior or + the original behavior. Defaults to True. """ def __init__(self, @@ -188,7 +196,8 @@ def __init__(self, voxel_size=(0.2, 0.2, 4), point_cloud_range=(0, -40, -3, 70.4, 40, 1), norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), - mode='max'): + mode='max', + legacy=True): super(DynamicPillarFeatureNet, self).__init__( in_channels, feat_channels, @@ -198,7 +207,8 @@ def __init__(self, voxel_size=voxel_size, point_cloud_range=point_cloud_range, norm_cfg=norm_cfg, - mode=mode) + mode=mode, + legacy=legacy) self.fp16_enabled = False feat_channels = [self.in_channels] + list(feat_channels) pfn_layers = [] diff --git a/mmdet3d/utils/setup_env.py b/mmdet3d/utils/setup_env.py index 282ad491e7..98bcc8853f 100644 --- a/mmdet3d/utils/setup_env.py +++ b/mmdet3d/utils/setup_env.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -import cv2 import os import platform import warnings + +import cv2 from torch import multiprocessing as mp diff --git a/tests/test_utils/test_setup_env.py b/tests/test_utils/test_setup_env.py index 08233efa71..0c070c9f0e 100644 --- a/tests/test_utils/test_setup_env.py +++ b/tests/test_utils/test_setup_env.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -import cv2 import multiprocessing as mp import os import platform + +import cv2 from mmcv import Config from mmdet3d.utils import setup_multi_processes diff --git a/tools/data_converter/kitti_data_utils.py b/tools/data_converter/kitti_data_utils.py index 206d50d680..8e3dba6f35 100644 --- a/tools/data_converter/kitti_data_utils.py +++ b/tools/data_converter/kitti_data_utils.py @@ -4,6 +4,7 @@ from os import path as osp from pathlib import Path +import mmcv import numpy as np from skimage import io