-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsuperpoint_maker.py
53 lines (40 loc) · 1.71 KB
/
superpoint_maker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Generate superpoint segmentations for ScanRefer scans and save them to disk.
import segmentator
import open3d as o3d
import torch
import os
from six.moves import cPickle
import numpy as np
# Read back a sequence of objects stored in a pickle file.
def unpickle_data(file_name, python2_to_3=False):
    """Restore data previously saved with pickle_data().

    The file is read as a pickled integer count followed by that many pickled
    objects; the objects are yielded one by one, in order.

    Args:
        file_name: Path of the pickle file to read.
        python2_to_3: If True, every load uses ``encoding='latin1'`` so that
            pickles written under Python 2 can be read under Python 3.

    Yields:
        Each stored object after the leading count.

    Warning:
        Unpickling can execute arbitrary code -- only use this on trusted files.
    """
    # Pass the latin1 encoding only when requested, exactly like the
    # original two-branch form.
    load_kwargs = {'encoding': 'latin1'} if python2_to_3 else {}
    # `with` guarantees the handle is closed even if the caller abandons the
    # generator early or a load raises (the manual close() leaked in both cases).
    with open(file_name, 'rb') as in_file:
        size = cPickle.load(in_file, **load_kwargs)
        for _ in range(size):
            yield cPickle.load(in_file, **load_kwargs)
def generate_superpoint(data_path, data_path_scannet, split):
    """Compute superpoints for every scan of a split and save them to disk.

    For each scan listed in ``{data_path}/{split}_v3scans.pkl`` this:

    1. reads the mesh ``{data_path_scannet}/{split}/{scan}_vh_clean_2.ply``;
    2. segments it with ``segmentator.segment_mesh``;
    3. keeps only the entries at the scan's ``choices`` indices;
    4. writes the resulting NumPy array with ``torch.save`` to
       ``{data_path}/superpoints/{split}/{scan}_superpoint.pth``.

    Args:
        data_path: ScanRefer directory; holds the scans pickle and receives
            the ``superpoints/{split}`` output folder.
        data_path_scannet: ScanNetv2 directory containing ``{split}/*.ply``.
        split: Split name used in both input and output paths (e.g. ``'train'``).

    Raises:
        FileNotFoundError: If a scan's ``.ply`` mesh file does not exist.
    """
    # The first object stored in the pickle is indexed by scan id.
    # NOTE(review): each entry is assumed to expose `.choices` as integer
    # vertex indices -- confirm against the code that wrote the pickle.
    scans = list(unpickle_data(f'{data_path}/{split}_v3scans.pkl'))[0]

    out_dir = os.path.join(data_path, "superpoints", split)
    # torch.save does not create parent directories; without this the very
    # first save fails on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)

    for scan in scans:
        mesh_file = os.path.join(data_path_scannet, split, scan + "_vh_clean_2.ply")
        # Open3D returns an empty mesh (with only a warning) for a missing
        # file, which would otherwise surface later as an obscure indexing
        # error -- fail early with a clear message instead.
        if not os.path.isfile(mesh_file):
            raise FileNotFoundError(f"Mesh file not found: {mesh_file}")
        mesh = o3d.io.read_triangle_mesh(mesh_file)
        vertices = torch.tensor(np.array(mesh.vertices), dtype=torch.float32)
        faces = torch.tensor(np.array(mesh.triangles), dtype=torch.int64)
        superpoint = segmentator.segment_mesh(vertices, faces)
        # index_select needs an integer index tensor; make the dtype explicit.
        select_idx = torch.tensor(scans[scan].choices, dtype=torch.long)
        superpoint = torch.index_select(superpoint, 0, select_idx).numpy()
        torch.save(superpoint, os.path.join(out_dir, scan + "_superpoint.pth"))
        print("Saving " + scan)
    print("Done.")
if __name__ == '__main__':
    import argparse

    # Paths are command-line options whose defaults are the original
    # placeholders, so running with no arguments behaves exactly as before,
    # but real paths no longer require editing this file.
    parser = argparse.ArgumentParser(
        description="Generate superpoint files for ScanRefer scans.")
    parser.add_argument("--data-path", default=r"/path/to/scanrefer",
                        help="ScanRefer path")
    parser.add_argument("--scannet-path", default=r"/path/to/scannetv2",
                        help="ScanNetv2 path")
    parser.add_argument("--split", default='train',
                        help="dataset split to process (default: train)")
    args = parser.parse_args()

    generate_superpoint(data_path=args.data_path,
                        data_path_scannet=args.scannet_path,
                        split=args.split)