-
Notifications
You must be signed in to change notification settings - Fork 2
/
minkunet.py
130 lines (119 loc) · 5.22 KB
/
minkunet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
import os
import argparse
import numpy as np
from urllib.request import urlretrieve
from collections import defaultdict
try:
    import open3d as o3d
except ImportError as err:
    # Chain the original error: ImportError is also raised when open3d is
    # installed but one of its native dependencies fails to load, and the
    # underlying cause is what the user needs to see in that case.
    raise ImportError("Please install open3d with `pip install open3d`.") from err
import torch
import MinkowskiEngine as ME
from models.minkunet import MinkUNet14, MinkUNet18, MinkUNet34, MinkUNet50
from models.common import Timer
# Fetch the sample room scan (the default --file_name) if it is not present.
# NOTE(review): this runs at import time and only checks "1.ply", so it
# downloads even when a different --file_name is passed on the command line.
if not os.path.isfile("1.ply"):
    print("Downloading the room ply file...")
    # Download under a temporary name and rename only after the transfer
    # completes. urlretrieve can leave truncated data behind when it fails,
    # and a truncated "1.ply" would pass the isfile() check above on every
    # later run and never be re-fetched.
    _partial_download = "1.ply.part"
    try:
        urlretrieve(
            "http://cvgl.stanford.edu/data2/minkowskiengine/1.ply", _partial_download
        )
        os.replace(_partial_download, "1.ply")
    finally:
        # No-op after a successful rename; removes leftovers after a failure.
        if os.path.exists(_partial_download):
            os.remove(_partial_download)
# Command-line options: the point-cloud file to read, and a switch that
# forces CPU execution even when CUDA is available.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--file_name",
    dest="file_name",
    default="1.ply",
    type=str,
)
parser.add_argument(
    "--use_cpu",
    dest="use_cpu",
    action="store_true",
)
def load_file(file_name):
    """Read a point cloud with Open3D.

    Args:
        file_name: Path handed to ``o3d.io.read_point_cloud``.

    Returns:
        ``(coords, colors, pcd)``: the cloud's ``points`` and ``colors``
        each converted to a NumPy array, followed by the Open3D point-cloud
        object itself.
    """
    point_cloud = o3d.io.read_point_cloud(file_name)
    xyz, rgb = map(np.array, (point_cloud.points, point_cloud.colors))
    return xyz, rgb, point_cloud
if __name__ == "__main__":
config = parser.parse_args()
device = torch.device(
"cuda" if (torch.cuda.is_available() and not config.use_cpu) else "cpu"
)
print(f"Using {device}")
# Define a model and load the weights
model = MinkUNet18(3, 20).to(device)
model.eval()
print(model)
num_conv_layers = defaultdict(int)
for l in model.modules():
if isinstance(l, ME.MinkowskiConvolution) or isinstance(
l, ME.MinkowskiConvolutionTranspose
):
num_conv_layers[l.kernel_generator.kernel_size[0]] += 1
print(num_conv_layers)
voxel_size = 0.02
timer = Timer()
coords, colors, pcd = load_file(config.file_name)
batch_sizes = [1, 2, 4, 6, 8, 10, 12, 14, 16, 20]
if ME.__version__.split(".")[1] == "5":
# Measure time
for batch_size in batch_sizes:
timer = Timer()
coordinates = ME.utils.batched_coordinates(
[coords / voxel_size for i in range(batch_size)], dtype=torch.float32
)
features = torch.rand(len(coordinates), 3).float()
with torch.no_grad():
for i in range(10):
timer.tic()
# Feed-forward pass and get the prediction
in_field = ME.TensorField(
features=features,
coordinates=coordinates,
quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
# minkowski_algorithm=ME.MinkowskiAlgorithm.MEMORY_EFFICIENT,
allocator_type=ME.GPUMemoryAllocatorType.PYTORCH,
device=device,
)
# Convert to a sparse tensor
sinput = in_field.sparse()
# Output sparse tensor
soutput = model(sinput)
# get the prediction on the input tensor field
out_field = soutput.slice(in_field)
timer.toc()
print(batch_size, soutput.shape, timer.min_time)
elif ME.__version__.split(".")[1] == "4":
# Measure time
for batch_size in batch_sizes:
timer = Timer()
coordinates = ME.utils.batched_coordinates(
[coords / voxel_size for i in range(batch_size)]
)
features = torch.rand(len(coordinates), 3).float()
with torch.no_grad():
for i in range(10):
timer.tic()
# Feed-forward pass and get the prediction
sinput = ME.SparseTensor(features.to(device), coords=coordinates,)
# Output sparse tensor
soutput = model(sinput)
# get the prediction on the input tensor field
out_field = soutput.slice(sinput)
timer.toc()
print(batch_size, timer.min_time)
else:
raise NotImplementedError