Skip to content

Commit

Permalink
add GTDrop / PointDrop 3D augmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
poodarchu committed Dec 25, 2023
1 parent f7c6fb4 commit 784e329
Showing 1 changed file with 88 additions and 1 deletion.
89 changes: 88 additions & 1 deletion efg/data/augmentations/extend_3d.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging

import numpy as np
import random

from efg.data.registry import PROCESSORS
from efg.data.samplers.gt_database_sampler import DataBaseSampler
from efg.data.utils.misc import _dict_select
from efg.data.utils.voxel_generator import VoxelGenerator
from efg.geometry.box_ops import ( # mask_boxes_outside_range,; mask_boxes_outside_range_bev,
from efg.geometry.box_ops import (
mask_boxes_outside_range_bev_z_bound,
mask_boxes_outside_range_center,
mask_points_by_range,
Expand Down Expand Up @@ -441,3 +442,89 @@ def apply_point_clouds(self, points: np.ndarray, center_offset: np.ndarray) -> n
def apply_coords(self, coords: np.ndarray, center_offset: np.ndarray) -> np.ndarray:
coords[:, :2] -= center_offset
return coords


@PROCESSORS.register()
class PointDrop(AugmentationBase):
def __init__(self, ratio=[0.0, 0.2]):
super().__init__()
self._init(locals())

def __call__(self, points, info):
ratio = random.uniform(self.ratio[0], self.ratio[1])
mask = np.random.choice([0, 1], size=(points.shape[0],), p=[ratio, 1 - ratio])
return points[mask.astype(np.bool)], info


@PROCESSORS.register()
class GTDropByCat(AugmentationBase):
def __init__(
self,
ratio=[0.0, [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]],
categories=[
"car",
"truck",
"construction_vehicle",
"bus",
"trailer",
"barrier",
"motorcycle",
"bicycle",
"pedestrian",
"traffic_cone",
],
):
super().__init__()
self._init(locals())

def __call__(self, points, info):
assert "annotations" in info

gt_boxes = info["annotations"]["gt_boxes"]
gt_names = info["annotations"]["gt_names"]

cat_masks = []
cat_ratios = []
for cati, cat in enumerate(self.categories):
cat_mask = gt_names == cat
cat_boxes = gt_boxes[cat_mask]
cat_ratio = random.uniform(self.ratio[0], self.ratio[1][cati])
cat_ratios.append(cat_ratio)
cat_keep_mask = np.random.choice(
[0, 1],
size=(cat_boxes.shape[0],),
p=[cat_ratio, 1 - cat_ratio],
).astype(np.bool_)
cat_mask[np.nonzero(cat_mask)[0][~cat_keep_mask]] = False
cat_masks.append(cat_mask)

mask = np.zeros(gt_names.shape[0]).astype(np.bool)
for m in cat_masks:
mask = mask | m

_dict_select(info["annotations"], mask)
gt_boxes_to_drop = gt_boxes[~mask]
point_indices = points_in_rbbox(points, gt_boxes_to_drop)

return points[~point_indices.sum(axis=-1).astype(np.bool)], info


@PROCESSORS.register()
class GTDrop(AugmentationBase):
def __init__(self, ratio=[0.0, 0.2]):
super().__init__()
self._init(locals())

def __call__(self, points, info):
assert "annotations" in info

gt_boxes = info["annotations"]["gt_boxes"]
ratio = random.uniform(self.ratio[0], self.ratio[1])
mask = np.random.choice([0, 1], size=(gt_boxes.shape[0],), p=[ratio, 1 - ratio]).astype(np.bool_)

_dict_select(info["annotations"], mask)

gt_boxes_to_drop = gt_boxes[~mask]
point_indices = points_in_rbbox(points, gt_boxes_to_drop)

return points[~point_indices.sum(axis=-1).astype(np.bool)], info

0 comments on commit 784e329

Please sign in to comment.