-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
pointnet.py
300 lines (233 loc) · 8.6 KB
/
pointnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""
Title: Point cloud classification with PointNet
Author: [David Griffiths](https://dgriffiths3.github.io)
Date created: 2020/05/25
Last modified: 2024/01/09
Description: Implementation of PointNet for ModelNet10 classification.
Accelerator: GPU
"""
"""
# Point cloud classification
"""
"""
## Introduction
Classification, detection and segmentation of unordered 3D point sets i.e. point clouds
is a core problem in computer vision. This example implements the seminal point cloud
deep learning paper [PointNet (Qi et al., 2017)](https://arxiv.org/abs/1612.00593). For a
detailed introduction to PointNet, see [this blog
post](https://medium.com/@luis_gonzales/an-in-depth-look-at-pointnet-111d7efdaa1a).
"""
"""
## Setup
If using Colab, first install trimesh with `!pip install trimesh`.
"""
import os
import glob
import trimesh
import numpy as np
from tensorflow import data as tf_data
from keras import ops
import keras
from keras import layers
from matplotlib import pyplot as plt
# Fix the global random seed (42) up front so repeated runs of this script
# start from the same random state.
keras.utils.set_random_seed(seed=42)
"""
## Load dataset
We use the ModelNet10 model dataset, the smaller 10 class version of the ModelNet40
dataset. First download the data:
"""
# Download the ModelNet10 archive and unpack it.
DATA_DIR = keras.utils.get_file(
    "modelnet.zip",
    "http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip",
    extract=True,
)
# NOTE(review): depending on the installed Keras release, `get_file` with
# `extract=True` appears to return either the archive path or the directory the
# archive was extracted into -- confirm against the version in use. Probe both
# layouts and fall back to the historical "next to the archive" location.
_modelnet_candidates = (
    os.path.join(DATA_DIR, "ModelNet10"),
    os.path.join(os.path.dirname(DATA_DIR), "ModelNet10"),
)
DATA_DIR = next(
    (path for path in _modelnet_candidates if os.path.isdir(path)),
    _modelnet_candidates[-1],
)
"""
We can use the `trimesh` package to read and visualize the `.off` mesh files.
"""
# Load a single training mesh (the first chair) from the extracted dataset
# and display it.
mesh = trimesh.load(os.path.join(DATA_DIR, "chair/train/chair_0001.off"))
mesh.show()
"""
To convert a mesh file to a point cloud we first need to sample points on the mesh
surface. `.sample()` performs a uniform random sampling. Here we sample at 2048 locations
and visualize in `matplotlib`.
"""
# Sample 2048 points from the mesh loaded above.
points = mesh.sample(2048)

# Scatter the sampled points on a bare 3D axis (axis decorations hidden).
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
ax.scatter(xs, ys, zs)
ax.set_axis_off()
plt.show()
"""
To generate a `tf.data.Dataset()` we need to first parse through the ModelNet data
folders. Each mesh is loaded and sampled into a point cloud before being added to a
standard python list and converted to a `numpy` array. We also store the current
enumerate index value as the object label and use a dictionary to recall this later.
"""
def parse_dataset(num_points=2048):
    """Load every mesh under ``DATA_DIR`` and sample it into a point cloud.

    Each sub-directory of ``DATA_DIR`` is treated as one class. The files in
    its ``train`` and ``test`` folders are loaded with ``trimesh`` and
    ``num_points`` points are sampled from every mesh.

    Args:
        num_points: Number of points sampled from each mesh.

    Returns:
        A tuple ``(train_points, test_points, train_labels, test_labels,
        class_map)``. The first four items are NumPy arrays; ``class_map``
        maps each integer label to the name of its class folder.
    """
    train_points = []
    train_labels = []
    test_points = []
    test_labels = []
    class_map = {}

    # Keep directories only, which skips README.txt. The previous pattern
    # "[!README]*" is a character class: it rejects any entry whose first
    # character is one of R/E/A/D/M, which on case-insensitive matching also
    # drops class folders such as "desk" or "monitor".
    # Sorting makes the label ids independent of filesystem listing order.
    folders = sorted(
        path
        for path in glob.glob(os.path.join(DATA_DIR, "*"))
        if os.path.isdir(path)
    )

    for i, folder in enumerate(folders):
        # basename is separator-agnostic, unlike splitting on "/".
        class_name = os.path.basename(folder)
        print(f"processing class: {class_name}")
        # store folder name with ID so we can retrieve later
        class_map[i] = class_name

        # gather all files (sorted so the sampling order is repeatable)
        train_files = sorted(glob.glob(os.path.join(folder, "train/*")))
        test_files = sorted(glob.glob(os.path.join(folder, "test/*")))

        for f in train_files:
            train_points.append(trimesh.load(f).sample(num_points))
            train_labels.append(i)

        for f in test_files:
            test_points.append(trimesh.load(f).sample(num_points))
            test_labels.append(i)

    return (
        np.array(train_points),
        np.array(test_points),
        np.array(train_labels),
        np.array(test_labels),
        class_map,
    )
"""
Set the number of points to sample and batch size and parse the dataset. This can take
~5 minutes to complete.
"""
# Points sampled per mesh; also the first dimension of the model input.
NUM_POINTS = 2048
# Size of the softmax output layer (ModelNet10 has ten classes).
NUM_CLASSES = 10
# Batch size used for the train, validation and test datasets.
BATCH_SIZE = 32
# Parse the whole dataset into NumPy arrays plus the label-id -> class-name map.
train_points, test_points, train_labels, test_labels, CLASS_MAP = parse_dataset(
NUM_POINTS
)
"""
Our data can now be read into a `tf.data.Dataset()` object. We set the shuffle buffer
size to the entire size of the dataset as prior to this the data is ordered by class.
Data augmentation is important when working with point cloud data. We create an
augmentation function to jitter and shuffle the train dataset.
"""
def augment(points, label):
    """Jitter a point cloud and shuffle the order of its points.

    Args:
        points: Point cloud to perturb.
        label: Class label, passed through unchanged.

    Returns:
        The ``(points, label)`` pair with uniform noise in [-0.005, 0.005)
        added to ``points`` and its entries shuffled.
    """
    # Small uniform jitter, drawn with the same shape as the input.
    noise = keras.random.uniform(points.shape, -0.005, 0.005, dtype="float64")
    points += noise
    # Randomise the point order.
    return keras.random.shuffle(points), label
# Fraction of the training samples used for fitting; the rest is validation.
train_size = 0.8
dataset = tf_data.Dataset.from_tensor_slices((train_points, train_labels))
test_dataset = tf_data.Dataset.from_tensor_slices((test_points, test_labels))
train_dataset_size = int(len(dataset) * train_size)

# Shuffle once with a fixed order before splitting. With the default
# `reshuffle_each_iteration=True`, take()/skip() would carve a different
# train/validation split out of the data on every epoch, so validation samples
# would leak into training.
dataset = dataset.shuffle(
    len(train_points), reshuffle_each_iteration=False
).map(augment)
test_dataset = test_dataset.shuffle(len(test_points)).batch(BATCH_SIZE)

# The training split is still re-shuffled every epoch, but only within itself.
train_dataset = (
    dataset.take(train_dataset_size).shuffle(train_dataset_size).batch(BATCH_SIZE)
)
validation_dataset = dataset.skip(train_dataset_size).batch(BATCH_SIZE)
"""
### Build a model
Each convolution and fully-connected layer (with exception for end layers) consists of
Convolution / Dense -> Batch Normalization -> ReLU Activation.
"""
def conv_bn(x, filters):
    """Apply a kernel-size-1 Conv1D, batch normalisation and ReLU to ``x``.

    Args:
        x: Input tensor.
        filters: Number of output filters of the convolution.

    Returns:
        The activated output tensor.
    """
    conv_out = layers.Conv1D(filters, kernel_size=1, padding="valid")(x)
    normalised = layers.BatchNormalization(momentum=0.0)(conv_out)
    return layers.Activation("relu")(normalised)
def dense_bn(x, filters):
    """Apply a Dense layer, batch normalisation and ReLU to ``x``.

    Args:
        x: Input tensor.
        filters: Number of units of the dense layer.

    Returns:
        The activated output tensor.
    """
    dense_out = layers.Dense(filters)(x)
    normalised = layers.BatchNormalization(momentum=0.0)(dense_out)
    return layers.Activation("relu")(normalised)
"""
PointNet consists of two core components. The primary MLP network, and the transformer
net (T-net). The T-net aims to learn an affine transformation matrix by its own mini
network. The T-net is used twice. The first time to transform the input features (n, 3)
into a canonical representation. The second is an affine transformation for alignment in
feature space (n, 32). As per the original paper we constrain the transformation to be
close to an orthogonal matrix (i.e. ||X*X^T - I|| = 0).
"""
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Penalise transformation matrices that are far from orthogonal.

    The incoming activation is reshaped into one
    ``num_features x num_features`` matrix ``A`` per sample, and the penalty
    ``l2reg * sum((A @ A^T - I) ** 2)`` is summed over all samples.

    Args:
        num_features: Side length of the square transformation matrix.
        l2reg: Weight applied to the squared deviation from the identity.
    """

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        # Identity target that A @ A^T is compared against.
        self.eye = ops.eye(num_features)

    def __call__(self, x):
        """Return the scalar orthogonality penalty for activation ``x``."""
        x = ops.reshape(x, (-1, self.num_features, self.num_features))
        # Batched A @ A^T, one product per sample. A tensordot over the last
        # axes would instead contract every sample with every other sample,
        # giving a (batch, n, batch, n) result that mixes matrices across the
        # batch.
        xxt = ops.matmul(x, ops.transpose(x, (0, 2, 1)))
        return ops.sum(self.l2reg * ops.square(xxt - self.eye))

    def get_config(self):
        """Return the constructor arguments so the regularizer can be serialised."""
        return {"num_features": self.num_features, "l2reg": self.l2reg}
"""
We can then define a general function to build T-net layers.
"""
def tnet(inputs, num_features):
    """Build a T-net and apply its predicted transformation to ``inputs``.

    A small conv/dense stack predicts a ``num_features x num_features``
    matrix, which is then multiplied with ``inputs`` through a ``Dot`` layer.

    Args:
        inputs: Input tensor to be transformed.
        num_features: Side length of the predicted square matrix.

    Returns:
        The output of the ``Dot`` layer applied to ``inputs`` and the
        predicted matrix.
    """
    # Zero kernel + identity bias: the predicted matrix starts as the identity.
    identity_bias = keras.initializers.Constant(np.eye(num_features).flatten())
    orthogonal_reg = OrthogonalRegularizer(num_features)

    features = inputs
    for width in (32, 64, 512):
        features = conv_bn(features, width)
    features = layers.GlobalMaxPooling1D()(features)
    for units in (256, 128):
        features = dense_bn(features, units)

    flat_transform = layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=identity_bias,
        activity_regularizer=orthogonal_reg,
    )(features)
    feat_T = layers.Reshape((num_features, num_features))(flat_transform)

    # Apply affine transformation to input features
    return layers.Dot(axes=(2, 1))([inputs, feat_T])
"""
The main network can be then implemented in the same manner where the t-net mini models
can be dropped in as layers in the graph. Here we replicate the network architecture
published in the original paper but with half the number of weights at each layer as we
are using the smaller 10 class ModelNet dataset.
"""
# Model input: NUM_POINTS rows of 3 coordinates per sample.
inputs = keras.Input(shape=(NUM_POINTS, 3))

# Input transform followed by the first pair of per-point conv blocks.
x = tnet(inputs, 3)
for width in (32, 32):
    x = conv_bn(x, width)

# Feature transform, further conv blocks, then a global max pool.
x = tnet(x, 32)
for width in (32, 64, 512):
    x = conv_bn(x, width)
x = layers.GlobalMaxPooling1D()(x)

# Classification head: each dense block is followed by dropout.
for units in (256, 128):
    x = dense_bn(x, units)
    x = layers.Dropout(0.3)(x)

outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")
model.summary()
"""
### Train model
Once the model is defined it can be trained like any other standard classification model
using `.compile()` and `.fit()`.
"""
# Adam with a fixed learning rate; labels are integer class ids, hence the
# sparse loss and metric.
optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)

# Train for 20 epochs, evaluating on the validation split after each one.
model.fit(train_dataset, validation_data=validation_dataset, epochs=20)
"""
## Visualize predictions
We can use matplotlib to visualize our trained model performance.
"""
# Pull one batch from the test set and keep at most eight samples from it.
data = test_dataset.take(1)
points, labels = next(iter(data))
points = points[:8, ...].numpy()
labels = labels[:8, ...].numpy()

# run test data through model
# `model.predict` returns NumPy arrays, so take the argmax with NumPy rather
# than calling `.numpy()` on a backend tensor, which not every backend offers.
preds = np.argmax(model.predict(points), axis=-1)

# plot points with predicted class and label
fig = plt.figure(figsize=(15, 10))
# Iterate over the samples actually present so a short batch does not raise.
for i in range(len(points)):
    ax = fig.add_subplot(2, 4, i + 1, projection="3d")
    ax.scatter(points[i, :, 0], points[i, :, 1], points[i, :, 2])
    ax.set_title(
        "pred: {:}, label: {:}".format(
            CLASS_MAP[int(preds[i])], CLASS_MAP[int(labels[i])]
        )
    )
    ax.set_axis_off()
plt.show()