forked from IDEA-Research/Grounded-Segment-Anything
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
142 lines (111 loc) · 6.24 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import numpy as np
import torch
import torch.nn as nn
from .models.data_processor import DataProcessor
from .models.mean_vfe import MeanVFE
from .models.spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
from .models.voxelnext_head import VoxelNeXtHead
from .utils.image_projection import _proj_voxel_image
from segment_anything import SamPredictor, sam_model_registry
class VoxelNeXt(nn.Module):
    """Container for the VoxelNeXt components used at inference time.

    Built from ``model_cfg``: a ``DataProcessor`` (constructed with
    ``training=False``), a ``MeanVFE`` voxel encoder, the sparse 3D backbone
    and the ``VoxelNeXtHead``. This class defines no ``forward``; callers
    invoke the sub-modules one by one.
    """

    def __init__(self, model_cfg):
        super().__init__()

        # Point preprocessing. POINT_CLOUD_RANGE, DATA_PROCESSOR and
        # USED_FEATURE_LIST are read by attribute access with no default, so
        # they must be present in ``model_cfg``.
        pc_range = np.array(model_cfg.POINT_CLOUD_RANGE, dtype=np.float32)
        processor_cfg = model_cfg.DATA_PROCESSOR
        n_point_features = len(model_cfg.USED_FEATURE_LIST)
        self.data_processor = DataProcessor(
            processor_cfg,
            point_cloud_range=pc_range,
            training=False,
            num_point_features=n_point_features,
        )

        # Optional settings, each falling back to the default shown.
        n_input_channels = model_cfg.get('INPUT_CHANNELS', 5)
        voxel_grid = np.array(model_cfg.get('GRID_SIZE', [1440, 1440, 40]))
        names = model_cfg.get('CLASS_NAMES')
        head_kernel = model_cfg.get('KERNEL_SIZE_HEAD', 1)
        # Stored as plain tensor attributes (not parameters/buffers), so
        # ``Module.to(device)`` does not move them.
        self.point_cloud_range = torch.Tensor(
            model_cfg.get('POINT_CLOUD_RANGE', [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0])
        )
        self.voxel_size = torch.Tensor(model_cfg.get('VOXEL_SIZE', [0.075, 0.075, 0.2]))
        names_per_head = model_cfg.get('CLASS_NAMES_EACH_HEAD')
        separate_head_cfg = model_cfg.get('SEPARATE_HEAD_CFG')
        post_processing_cfg = model_cfg.get('POST_PROCESSING')

        # Network stages, built in pipeline order.
        self.voxelization = MeanVFE()
        self.backbone_3d = VoxelResBackBone8xVoxelNeXt(n_input_channels, voxel_grid)
        self.dense_head = VoxelNeXtHead(
            names, self.point_cloud_range, self.voxel_size, head_kernel,
            names_per_head, separate_head_cfg, post_processing_cfg,
        )
class Model(nn.Module):
    """Lift a SAM 2D mask, prompted by image points, to a VoxelNeXt 3D box.

    SAM segments the image around the prompt; VoxelNeXt predicts 3D boxes from
    a point-cloud sample; ``generate_3D_box`` picks the box whose projected
    voxel best matches the mask.
    """

    def __init__(self, model_cfg, device="cuda"):
        """Build SAM and VoxelNeXt and load their checkpoints.

        Args:
            model_cfg: config object accepted by ``VoxelNeXt``; the optional keys
                ``SAM_TYPE``, ``SAM_CHECKPOINT`` and ``VOXELNEXT_CHECKPOINT``
                override the defaults below.
            device: torch device both networks are moved to.
        """
        super().__init__()
        sam_type = model_cfg.get('SAM_TYPE', "vit_b")
        sam_checkpoint = model_cfg.get('SAM_CHECKPOINT', "/data/sam_vit_b_01ec64.pth")
        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint).to(device=device)
        self.sam_predictor = SamPredictor(sam)

        voxelnext_checkpoint = model_cfg.get('VOXELNEXT_CHECKPOINT', "/data/voxelnext_nuscenes_kernel1.pth")
        # Deserialize on CPU so a checkpoint saved from a GPU also loads on
        # CPU-only hosts or another device; load_state_dict copies the weights
        # into parameters that already live on `device`.
        state_dict = torch.load(voxelnext_checkpoint, map_location="cpu")
        self.voxelnext = VoxelNeXt(model_cfg).to(device=device)
        self.voxelnext.load_state_dict(state_dict)
        # Inference-only wrapper (the data processor is built with
        # training=False): put any BatchNorm/Dropout layers in inference mode.
        self.voxelnext.eval()

        # image_id -> backbone output dict. Never evicted: one entry per id.
        self.point_features = {}
        self.device = device

    def image_embedding(self, image):
        """Hand ``image`` to ``SamPredictor.set_image`` for later ``predict`` calls.

        NOTE(review): presumably an HxWx3 uint8 RGB array, per SAM's
        ``set_image`` contract; confirm against the SAM version in use.
        """
        self.sam_predictor.set_image(image)

    def point_embedding(self, data_dict, image_id):
        """Run VoxelNeXt on a point-cloud sample and return its predictions.

        Backbone output is cached per ``image_id``. On a cache hit,
        ``data_dict`` is not preprocessed at all and the cached features are
        reused.

        Returns:
            ``(pred_dicts, voxel_coords)``. ``pred_dicts`` is the head output.
            ``voxel_coords`` holds the ``out_voxels`` rows selected by
            ``pred_dicts[0]['voxel_ids']``, scaled by the head's
            ``feature_map_stride``.
        """
        # Inference only. Without no_grad every cached entry would also keep
        # its autograd graph (and all saved activations) alive.
        with torch.no_grad():
            if image_id not in self.point_features:
                self.point_features[image_id] = self._encode_points(data_dict)
            features = self.point_features[image_id]
            pred_dicts = self.voxelnext.dense_head(features)
            voxel_ids = pred_dicts[0]['voxel_ids'].squeeze(-1)
            voxel_coords = features['out_voxels'][voxel_ids] * self.voxelnext.dense_head.feature_map_stride
        return pred_dicts, voxel_coords

    def _encode_points(self, data_dict):
        """Voxelize ``data_dict`` and run the sparse 3D backbone on it as a batch of one."""
        data_dict = self.voxelnext.data_processor.forward(data_dict=data_dict)
        for key in ('voxels', 'voxel_num_points', 'voxel_coords'):
            data_dict[key] = torch.as_tensor(data_dict[key], dtype=torch.float32, device=self.device)
        data_dict = self.voxelnext.voxelization(data_dict)
        coords = data_dict['voxel_coords']
        # Prepend a column of zeros (batch index 0, matching batch_size = 1).
        batch_idx = coords.new_zeros((coords.shape[0], 1))
        data_dict['voxel_coords'] = torch.cat([batch_idx, coords], dim=1)
        data_dict['batch_size'] = 1
        return self.voxelnext.backbone_3d(data_dict)

    @staticmethod
    def _points_in_mask(points, mask):
        """Return a bool array marking which integer pixels ``(x, y)`` hit a set cell of ``mask``.

        ``points`` is an (N, 2) array of (column, row) pixel positions. Points
        outside the mask bounds are reported as ``False``.
        """
        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        mask = np.asarray(mask)
        points = np.asarray(points).reshape(-1, 2)
        xs, ys = points[:, 0], points[:, 1]
        inside = (xs >= 0) & (ys >= 0) & (xs < mask.shape[1]) & (ys < mask.shape[0])
        hits = np.zeros(len(points), dtype=bool)
        hits[inside] = mask[ys[inside], xs[inside]].astype(bool)
        return hits

    @staticmethod
    def _mask_center(mask):
        """Return the ``(x, y)`` centroid of the set pixels of ``mask``, or ``None`` if none are set."""
        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        rows, cols = np.nonzero(np.asarray(mask))
        if rows.size == 0:
            return None
        return float(cols.mean()), float(rows.mean())

    def generate_3D_box(self, lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=0.1):
        """Pick the predicted 3D box that best matches a 2D image mask.

        Voxels are projected into the image with ``lidar2img_rt``. Only boxes
        with score above ``quality_score`` are considered. Among those, the
        highest-scoring box whose projected voxel lands inside ``mask`` is
        chosen. If no such box exists, the box whose projected voxel is nearest
        the mask centroid is chosen instead.

        Returns:
            The selected row of ``pred_dicts[0]['pred_boxes']``, or ``None``
            when no box clears ``quality_score`` or ``mask`` is empty.
        """
        device = voxel_coords.device
        points_image, depth = _proj_voxel_image(
            voxel_coords, lidar2img_rt,
            self.voxelnext.voxel_size.to(device), self.voxelnext.point_cloud_range.to(device),
        )
        # Vectorised mask lookup; replaces the former per-voxel Python loop.
        points = points_image.permute(1, 0).int().cpu().numpy()
        in_mask = self._points_in_mask(points, mask)
        selected_voxels = torch.zeros_like(depth).squeeze(0)
        selected_voxels[torch.from_numpy(in_mask).to(selected_voxels.device)] = 1

        scores = pred_dicts[0]['pred_scores']
        boxes = pred_dicts[0]['pred_boxes']
        mask_extra = scores > quality_score
        if mask_extra.sum() == 0:
            print("no high quality 3D box related.")
            return None
        selected_voxels *= mask_extra
        if selected_voxels.sum() > 0:
            candidates = selected_voxels.bool()
            selected_box_id = scores[candidates].argmax()
            return boxes[candidates][selected_box_id]

        # No confident voxel falls inside the mask: fall back to the confident
        # box whose projected voxel is closest to the mask centroid.
        center = self._mask_center(mask)
        if center is None:
            print("empty 2D mask, no 3D box related.")
            return None
        mask_center = torch.tensor(center, dtype=torch.float32, device=points_image.device).unsqueeze(1)
        dist = ((points_image - mask_center) ** 2).sum(0)
        selected_id = dist[mask_extra].argmin()
        return boxes[mask_extra][selected_id]

    def forward(self, image, point_dict, prompt_point, lidar2img_rt, image_id, quality_score=0.1):
        """Segment ``image`` around ``prompt_point`` with SAM and lift the mask to a 3D box.

        Args:
            image: image passed to ``SamPredictor.set_image``.
            point_dict: point-cloud sample passed to ``point_embedding``.
            prompt_point: (N, 2) array of foreground click coordinates for SAM.
            lidar2img_rt: projection passed through to ``generate_3D_box``.
            image_id: cache key for the backbone features of ``point_dict``.
            quality_score: minimum predicted score for a 3D box to be considered.

        Returns:
            ``(mask, box3d)``: the first mask SAM returns and the selected box
            (or ``None``).
        """
        self.image_embedding(image)
        pred_dicts, voxel_coords = self.point_embedding(point_dict, image_id)
        prompt_point = np.asarray(prompt_point)
        # One foreground label (1) per prompt point, so multi-point prompts work.
        point_labels = np.ones(prompt_point.shape[0], dtype=np.int64)
        masks, _, _ = self.sam_predictor.predict(point_coords=prompt_point, point_labels=point_labels)
        mask = masks[0]
        box3d = self.generate_3D_box(lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=quality_score)
        return mask, box3d
if __name__ == '__main__':
    # Development smoke test.
    # NOTE(review): `cfg`, `cfg_from_yaml_file` and `NuScenesDataset` are not
    # imported in this module (presumably they come from OpenPCDet's config and
    # nuScenes dataset modules); add those imports before running. Because of
    # the relative `.models` / `.utils` imports, this file must also be run as a
    # package module (`python -m <package>.model`), not as a plain script.
    # NOTE(review): both YAML files are loaded into the same `cfg` object; if
    # `cfg_from_yaml_file` merges into it, dataset and model settings share one
    # namespace. Confirm this is intended.
    cfg_dataset = 'nuscenes_dataset.yaml'
    cfg_model = 'config.yaml'
    dataset_cfg = cfg_from_yaml_file(cfg_dataset, cfg)
    model_cfg = cfg_from_yaml_file(cfg_model, cfg)
    nuscenes_dataset = NuScenesDataset(dataset_cfg)
    model = Model(model_cfg)
    index = 0
    data_dict = nuscenes_dataset._get_points(index)
    # point_embedding requires a cache key; use the sample index.
    pred_dicts, voxel_coords = model.point_embedding(data_dict, image_id=index)