-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdeform.py
516 lines (438 loc) · 18.9 KB
/
deform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
import functools
import math
import os
import time
from tkinter import W
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load
import torch.nn.init as init
import abc
import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
class Deformation(nn.Module):
    """HexPlane-conditioned deformation field for Gaussian attributes.

    A ``HexPlaneField`` is sampled at (x, y, z, t). The features go through a
    shared trunk (``self.feature_out``: ``D`` Linear layers of width ``W``
    separated by ReLUs), then through four heads that predict additive
    offsets for position (3), scale (3), rotation (4) and opacity (1).

    Args:
        D: Number of Linear layers in the shared trunk.
        W: Hidden width of every Linear layer.
        input_ch: Stored on the instance; not used to size any layer here.
        input_ch_time: Stored on the instance; not used to size any layer here.
        skips: Stored on the instance; no skip connections are built here.
        args: Accepted for interface compatibility; unused.
    """

    def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=None, args=None):
        super(Deformation, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_time = input_ch_time
        # Fresh list per instance: a literal `[]` default would be shared by
        # every instance that relies on it.
        self.skips = [] if skips is None else skips
        # Feature toggles: bypass the grid, or disable individual heads.
        self.no_grid = False
        self.no_ds = False
        self.no_dr = False
        self.no_do = True  # opacity is passed through unchanged by default
        self.bounds = 1.6
        self.kplanes_config = {
            'grid_dimensions': 2,
            'input_coordinate_dim': 4,
            'output_coordinate_dim': 32,
            'resolution': [64, 64, 64, 25]
        }
        self.multires = [1, 2, 4, 8]
        self.grid = HexPlaneField(self.bounds, self.kplanes_config, self.multires)
        self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_net()

    def _make_head(self, out_dim):
        """Build one prediction head: ReLU -> Linear(W, W) -> ReLU -> Linear(W, out_dim)."""
        return nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.W, self.W),
            nn.ReLU(),
            nn.Linear(self.W, out_dim),
        )

    def create_net(self):
        """Build the shared trunk and the four heads.

        The trunk is stored as ``self.feature_out``. Its input width is 4
        (xyz + t) when ``no_grid`` is set, else the grid's ``feat_dim``.

        Returns:
            Tuple ``(pos, scales, rotations, opacity)`` of ``nn.Sequential``
            heads with output sizes 3, 3, 4 and 1.
        """
        in_dim = 4 if self.no_grid else self.grid.feat_dim
        trunk = [nn.Linear(in_dim, self.W)]
        for _ in range(self.D - 1):
            trunk.append(nn.ReLU())
            trunk.append(nn.Linear(self.W, self.W))
        self.feature_out = nn.Sequential(*trunk)
        # Heads are built in this fixed order so parameter creation (and the
        # RNG draws for their default init) follows pos, scales, rotations, opacity.
        return tuple(self._make_head(out_dim) for out_dim in (3, 3, 4, 1))

    def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb):
        """Return trunk features for points at the given times.

        Only the first 3 columns of ``rays_pts_emb`` and the first column of
        ``time_emb`` are read; ``scales_emb`` and ``rotations_emb`` are ignored.
        """
        if self.no_grid:
            h = torch.cat([rays_pts_emb[:, :3], time_emb[:, :1]], -1)
        else:
            h = self.grid(rays_pts_emb[:, :3], time_emb[:, :1])
        return self.feature_out(h)

    def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity=None, time_emb=None):
        """Dispatch to the dynamic path when ``time_emb`` is given, else the static one."""
        if time_emb is None:
            return self.forward_static(rays_pts_emb[:, :3])
        return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb)

    def forward_static(self, rays_pts_emb):
        """Time-independent deformation (requires a ``static_mlp`` to be defined).

        Raises:
            NotImplementedError: This class never creates ``static_mlp``, and
                the HexPlane grid needs timestamps. Without the check below,
                the call fails deep inside ``torch.cat`` with an opaque error.
        """
        if not hasattr(self, "static_mlp"):
            raise NotImplementedError(
                "Deformation has no static_mlp; call forward() with time_emb "
                "to use the time-dependent deformation path.")
        grid_feature = self.grid(rays_pts_emb[:, :3])
        dx = self.static_mlp(grid_feature)
        return rays_pts_emb[:, :3] + dx

    def forward_dynamic(self, rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb):
        """Apply predicted offsets to position, scale, rotation and opacity.

        Each attribute is sliced to its leading columns (3, 3, 4, 1). When the
        matching ``no_*`` flag is set, that attribute is returned without an
        offset.

        Returns:
            Tuple ``(pts, scales, rotations, opacity)``.
        """
        hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float()
        pts = rays_pts_emb[:, :3] + self.pos_deform(hidden)
        if self.no_ds:
            scales = scales_emb[:, :3]
        else:
            scales = scales_emb[:, :3] + self.scales_deform(hidden)
        if self.no_dr:
            rotations = rotations_emb[:, :4]
        else:
            rotations = rotations_emb[:, :4] + self.rotations_deform(hidden)
        if self.no_do:
            opacity = opacity_emb[:, :1]
        else:
            opacity = opacity_emb[:, :1] + self.opacity_deform(hidden)
        return pts, scales, rotations, opacity

    def get_mlp_parameters(self):
        """Return every parameter whose qualified name does not contain ``"grid"``."""
        return [param for name, param in self.named_parameters() if "grid" not in name]

    def get_grid_parameters(self):
        """Return all parameters of the HexPlane grid (including its ``aabb``)."""
        return list(self.grid.parameters())
class deform_network(nn.Module):
    """Top-level wrapper that owns a small ``Deformation`` field.

    Builds an (unused in ``forward``) time-embedding MLP, the deformation
    field, and several power-of-two frequency buffers. It then
    Xavier-initializes every Linear layer via ``initialize_weights``.
    """

    def __init__(self):
        super().__init__()
        # Hard-coded hyper-parameters.
        net_width = 64
        timebase_pe = 4
        defor_depth = 1
        posbase_pe = 10
        scale_rotation_pe = 2
        opacity_pe = 2
        timenet_width = 64
        timenet_output = 32

        times_ch = 2 * timebase_pe + 1
        self.timenet = nn.Sequential(
            nn.Linear(times_ch, timenet_width),
            nn.ReLU(),
            nn.Linear(timenet_width, timenet_output),
        )

        srt_ch = 4 + 3  # rotation (4) + scale (3) channels
        self.deformation_net = Deformation(
            W=net_width,
            D=defor_depth,
            input_ch=srt_ch + (srt_ch * scale_rotation_pe) * 2,
            input_ch_time=timenet_output,
            args=None,
        )

        # Frequency tables [1, 2, 4, ...] registered in a fixed order.
        frequency_tables = (
            ('time_poc', timebase_pe),
            ('pos_poc', posbase_pe),
            ('rotation_scaling_poc', scale_rotation_pe),
            ('opacity_poc', opacity_pe),
        )
        for buffer_name, n_freqs in frequency_tables:
            self.register_buffer(buffer_name, torch.FloatTensor([2 ** k for k in range(n_freqs)]))

        self.apply(initialize_weights)

    def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
        """Run the dynamic path when ``times_sel`` is given, else the static path."""
        if times_sel is None:
            return self.forward_static(point)
        return self.forward_dynamic(point, scales, rotations, opacity, times_sel)

    def forward_static(self, points):
        """Delegate to the deformation field without a time input."""
        return self.deformation_net(points)

    def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
        """Deform every attribute using ``times_sel`` directly as the time input."""
        deformed = self.deformation_net(point, scales, rotations, opacity, times_sel)
        means3D, new_scales, new_rotations, new_opacity = deformed
        return means3D, new_scales, new_rotations, new_opacity

    def get_mlp_parameters(self):
        """Non-grid parameters of the deformation field plus the timenet parameters."""
        field_params = self.deformation_net.get_mlp_parameters()
        return field_params + list(self.timenet.parameters())

    def get_grid_parameters(self):
        """Parameters of the deformation field's HexPlane grid."""
        return self.deformation_net.get_grid_parameters()
def initialize_weights(m):
    """Xavier-uniform initialize the weight of an ``nn.Linear`` module.

    Intended for ``module.apply(initialize_weights)``. Modules that are not
    ``nn.Linear`` are left untouched.

    Args:
        m: Any ``nn.Module``.
    """
    if isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight, gain=1)
        # Biases are deliberately left at nn.Linear's default init: Xavier
        # init is undefined for 1-D tensors. A second xavier_uniform_ call on
        # the weight here only discarded the first draw.
def get_normalized_directions(directions):
    """Map direction components from [-1, 1] to [0, 1].

    SH encoding expects its inputs in the range [0, 1].

    Args:
        directions: batch of directions (tensor or scalar).
    """
    shifted = directions + 1.0
    return shifted / 2.0
def normalize_aabb(pts, aabb):
    """Affinely map ``pts`` so that ``aabb[0]`` goes to -1 and ``aabb[1]`` to +1.

    The mapping is applied per coordinate; points outside the box land
    outside [-1, 1]. Its direction follows the row order of ``aabb``, so a
    ``[max, min]`` box mirrors each axis.
    """
    origin, far_corner = aabb[0], aabb[1]
    scale = 2.0 / (far_corner - origin)
    return (pts - origin) * scale - 1.0
def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Bilinearly sample ``grid`` at ``coords`` with border padding.

    Args:
        grid: Feature grid ``[B, C, *reso]``. If it has exactly one more dim
            than ``coords.shape[-1]``, a leading batch dim is added.
        coords: Sample points whose last dim is the grid dimensionality; a
            2-D ``[n, grid_dim]`` input gets a leading batch dim added.
        align_corners: Forwarded to ``F.grid_sample``.

    Returns:
        Features reshaped to ``[B, n, C]`` and then ``squeeze()``d, so every
        size-1 dimension (e.g. ``B == 1``) is dropped.

    Raises:
        NotImplementedError: If ``coords.shape[-1]`` is neither 2 nor 3.
    """
    ndim = coords.shape[-1]
    if grid.dim() == ndim + 1:
        grid = grid.unsqueeze(0)  # add the missing batch dimension
    if coords.dim() == 2:
        coords = coords.unsqueeze(0)
    if ndim not in (2, 3):
        raise NotImplementedError(f"Grid-sample was called with {ndim}D data but is only "
                                  f"implemented for 2 and 3D data.")
    # Insert singleton dims so coords become [B, 1, ..., n, ndim] as grid_sample expects.
    coords = coords.view([coords.shape[0], *([1] * (ndim - 1)), *coords.shape[1:]])
    batch, channels = grid.shape[:2]
    n_points = coords.shape[-2]
    sampled = F.grid_sample(
        grid,
        coords,
        align_corners=align_corners,
        mode='bilinear',
        padding_mode='border',
    )
    return sampled.view(batch, channels, n_points).transpose(-1, -2).squeeze()
def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one learnable feature plane per ``grid_nd``-subset of input axes.

    Args:
        grid_nd: Number of axes spanned by each plane (2 for HexPlane).
        in_dim: Number of input coordinates; must equal ``len(reso)``.
        out_dim: Number of feature channels per plane.
        reso: Resolution of each input axis.
        a, b: Bounds of the uniform init for planes without a time axis.

    Returns:
        ``nn.ParameterList`` with one ``[1, out_dim, *reso(reversed comb)]``
        tensor per combination, in ``itertools.combinations`` order. When
        ``in_dim == 4``, planes that include axis 3 (time) are initialized
        to 1; all others are uniform in ``[a, b]``.

    Raises:
        ValueError: If ``len(reso) != in_dim`` or ``grid_nd > in_dim``.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if in_dim != len(reso):
        raise ValueError("Resolution must have same number of elements as input-dimension")
    if grid_nd > in_dim:
        raise ValueError(f"grid_nd ({grid_nd}) must not exceed in_dim ({in_dim})")
    has_time_planes = in_dim == 4
    grid_coefs = nn.ParameterList()
    for coo_comb in itertools.combinations(range(in_dim), grid_nd):
        # Spatial sizes are reversed: F.grid_sample reads the first coordinate
        # as the last (W) grid dimension.
        new_grid_coef = nn.Parameter(torch.empty(
            [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]]
        ))
        if has_time_planes and 3 in coo_comb:  # time planes start at 1
            nn.init.ones_(new_grid_coef)
        else:
            nn.init.uniform_(new_grid_coef, a=a, b=b)
        grid_coefs.append(new_grid_coef)
    return grid_coefs
def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Sample multi-scale plane features at ``pts``.

    Within each level, the features sampled from every plane are multiplied
    together elementwise. Levels are then concatenated along the last dim
    (``concat_features=True``) or summed.

    Args:
        pts: Points whose last dim is the number of coordinates.
        ms_grids: One sequence of planes per level, ordered like
            ``itertools.combinations(range(pts.shape[-1]), grid_dimensions)``.
        grid_dimensions: Number of axes each plane spans.
        concat_features: Concatenate levels if True, else sum them.
        num_levels: Use only the first ``num_levels`` levels; ``None`` means all.
    """
    plane_axes = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))
    levels = len(ms_grids) if num_levels is None else num_levels
    per_level = []
    summed = 0.
    for level_planes in ms_grids[:levels]:
        level_feat = 1.
        for plane_idx, axes in enumerate(plane_axes):
            plane = level_planes[plane_idx]
            channels = plane.shape[1]  # plane: [1, out_dim, *reso]
            sampled = grid_sample_wrapper(plane, pts[..., axes]).view(-1, channels)
            level_feat = level_feat * sampled
        if concat_features:
            per_level.append(level_feat)
        else:
            summed = summed + level_feat
    if concat_features:
        return torch.cat(per_level, dim=-1)
    return summed
class HexPlaneField(nn.Module):
    """Multi-resolution HexPlane feature field over (x, y, z, t).

    One set of planes is built per multiplier in ``multires``. Only the
    first three (spatial) resolutions are scaled by the multiplier; the rest
    are kept as configured. Level features are concatenated, so
    ``feat_dim`` is the sum of ``output_coordinate_dim`` over levels.

    Args:
        bounds: Half-extent of the initial cubic bounding box.
        planeconfig: Dict with ``grid_dimensions``, ``input_coordinate_dim``,
            ``output_coordinate_dim`` and ``resolution``.
        multires: Resolution multipliers, one per level.
    """

    def __init__(
        self,
        bounds,
        planeconfig,
        multires
    ) -> None:
        super().__init__()
        # Row 0 is the +bounds corner, row 1 the -bounds corner.
        corners = torch.tensor([[bounds, bounds, bounds],
                                [-bounds, -bounds, -bounds]])
        self.aabb = nn.Parameter(corners, requires_grad=False)
        self.grid_config = [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True

        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for multiplier in self.multiscale_res_multipliers:
            level_cfg = self.grid_config[0].copy()
            spatial_res = level_cfg["resolution"][:3]
            other_res = level_cfg["resolution"][3:]
            level_cfg["resolution"] = [r * multiplier for r in spatial_res] + other_res
            planes = init_grid_param(
                grid_nd=level_cfg["grid_dimensions"],
                in_dim=level_cfg["input_coordinate_dim"],
                out_dim=level_cfg["output_coordinate_dim"],
                reso=level_cfg["resolution"],
            )
            level_channels = planes[-1].shape[1]
            if self.concat_features:
                self.feat_dim += level_channels
            else:
                self.feat_dim = level_channels
            self.grids.append(planes)
        print("feature_dim:", self.feat_dim)

    def set_aabb(self, xyz_max, xyz_min):
        """Replace the bounding box with rows ``[xyz_max, xyz_min]``.

        NOTE(review): unlike ``__init__``, this creates the box with
        ``requires_grad=True``, so an optimizer over this module's parameters
        would update it. Confirm this is intended.
        """
        self.aabb = nn.Parameter(torch.tensor([xyz_max, xyz_min]), requires_grad=True)
        print("Voxel Plane: set aabb=", self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Sample the multi-resolution plane features at ``pts`` and ``timestamps``.

        ``timestamps`` is concatenated as an extra coordinate, so passing
        ``None`` fails in ``torch.cat``. The result is flattened to one row
        per point.
        """
        normed = normalize_aabb(pts, self.aabb)
        coords = torch.cat((normed, timestamps), dim=-1)
        coords = coords.reshape(-1, coords.shape[-1])
        features = interpolate_ms_features(
            coords,
            ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features,
            num_levels=None,
        )
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)
        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Alias for ``get_density``."""
        return self.get_density(pts, timestamps)
def compute_plane_tv(t):
    """Total-variation penalty for a 4-D tensor ``[batch, c, h, w]``.

    Returns ``2 * (mean squared difference along h + mean squared difference
    along w)``. Each mean divides by the element count of its difference
    tensor, so it averages (not sums) over batch and channels.
    """
    n, ch, rows, cols = t.shape
    row_diff = t[..., 1:, :] - t[..., :rows - 1, :]
    col_diff = t[..., :, 1:] - t[..., :, :cols - 1]
    row_mean = row_diff.square().sum() / (n * ch * (rows - 1) * cols)
    col_mean = col_diff.square().sum() / (n * ch * rows * (cols - 1))
    return 2 * (row_mean + col_mean)
def compute_plane_smoothness(t):
    """Mean squared second difference of ``t`` along dim 2 (``h``).

    ``t`` must be 4-D (``[batch, c, h, w]``); the shape is unpacked to
    enforce this.
    """
    _batch, _chan, _h, _w = t.shape

    def _step(x):
        # Forward difference along the h axis.
        return x[..., 1:, :] - x[..., :-1, :]

    curvature = _step(_step(t))
    # abs() is a no-op before squaring for real tensors; kept for identical results.
    return torch.square(torch.abs(curvature)).mean()
class Regularizer():
    """Base class for weighted loss regularizers.

    Subclasses implement ``_regularize``. ``regularize`` scales that result
    by ``weight`` and keeps a detached copy for ``report``.

    Args:
        reg_type: Name used as the key into the dict passed to ``report``.
        initialization: Weight; any value ``float()`` accepts.
    """

    def __init__(self, reg_type, initialization):
        self.reg_type = reg_type
        self.initialization = initialization
        self.weight = float(initialization)
        self.last_reg = None

    def step(self, global_step):
        """Per-step hook; does nothing by default."""

    def report(self, d):
        """Call ``d[self.reg_type].update(value)`` with the last regularization value.

        Does nothing until ``regularize`` has been called at least once.
        """
        if self.last_reg is None:
            return
        d[self.reg_type].update(self.last_reg.item())

    def regularize(self, *args, **kwargs) -> torch.Tensor:
        """Return ``weight * _regularize(...)`` and remember it (detached)."""
        weighted = self._regularize(*args, **kwargs) * self.weight
        self.last_reg = weighted.detach()
        return weighted

    @abc.abstractmethod
    def _regularize(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError()

    def __str__(self):
        return f"Regularizer({self.reg_type}, weight={self.weight})"
class PlaneTV(Regularizer):
    """Total-variation regularizer over the planes of a multi-resolution grid.

    NOTE: for every level the spatial planes are penalized twice — once in
    the explicit spatial loop and again in the loop over all planes. This
    preserves the existing loss weighting; confirm before changing.

    Args:
        initial_value: Regularizer weight.
        what: ``'field'`` (uses ``model.field.grids``) or
            ``'proposal_network'`` (uses each ``model.proposal_networks[i].grids``).
    """

    def __init__(self, initial_value, what: str = 'field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        super().__init__(f'planeTV-{what[:2]}', initial_value)
        self.what = what

    def step(self, global_step):
        pass

    def _regularize(self, model, **kwargs):
        multi_res_grids: Sequence[nn.ParameterList]
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [net.grids for net in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)

        total = 0
        for level in multi_res_grids:
            # 3 planes => purely spatial; otherwise planes 0, 1, 3 are xy, xz, yz.
            spatial_ids = (0, 1, 2) if len(level) == 3 else (0, 1, 3)
            for plane_id in spatial_ids:
                total += compute_plane_tv(level[plane_id])
            for plane in level:  # each plane: [1, c, h, w]
                total += compute_plane_tv(plane)
        return total
class TimeSmoothness(Regularizer):
    """Penalize curvature along h of the spatio-temporal planes (ids 2, 4, 5).

    Levels with exactly 3 planes (no time axis) contribute nothing.

    Args:
        initial_value: Regularizer weight.
        what: ``'field'`` or ``'proposal_network'``.
    """

    def __init__(self, initial_value, what: str = 'field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        super().__init__(f'time-smooth-{what[:2]}', initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        multi_res_grids: Sequence[nn.ParameterList]
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [net.grids for net in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)

        total = 0
        # Presumably each level holds 6 planes of shape [1, C, reso, reso] — confirm against the grid builder.
        for level in multi_res_grids:
            if len(level) == 3:
                continue
            for plane_id in (2, 4, 5):
                total += compute_plane_smoothness(level[plane_id])
        return torch.as_tensor(total)
class L1ProposalNetwork(Regularizer):
    """Sum of mean absolute values of every proposal-network plane."""

    def __init__(self, initial_value):
        super().__init__('l1-proposal-network', initial_value)

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        per_network = [net.grids for net in model.proposal_networks]
        total = 0.0
        for planes in per_network:
            for plane in planes:
                total += plane.abs().mean()
        return torch.as_tensor(total)
class DepthTV(Regularizer):
    """Total variation of the rendered depth map.

    ``model_out['depth']`` is reshaped to exactly 64x64, so it must hold
    4096 elements.
    """

    def __init__(self, initial_value):
        super().__init__('tv-depth', initial_value)

    def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:
        depth_image = model_out['depth'].reshape(64, 64)
        return compute_plane_tv(depth_image[None, None, :, :])
class L1TimePlanes(Regularizer):
    """Penalize deviation of the spatio-temporal planes (ids 2, 4, 5) from 1.

    Levels with exactly 3 planes (no time axis) are skipped.

    Args:
        initial_value: Regularizer weight.
        what: ``'field'`` or ``'proposal_network'``.
    """

    def __init__(self, initial_value, what='field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        super().__init__(f'l1-time-{what[:2]}', initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        multi_res_grids: Sequence[nn.ParameterList]
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [net.grids for net in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)

        total = 0.0
        for level in multi_res_grids:
            if len(level) == 3:
                continue
            for plane_id in (2, 4, 5):  # spatio-temporal planes
                total += torch.abs(1 - level[plane_id]).mean()
        return torch.as_tensor(total)