-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrender_utils.cpp
156 lines (129 loc) · 6.1 KB
/
render_utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include <torch/extension.h>
#include <vector>
// CUDA forward declarations
// These functions are implemented in another translation unit that is not
// visible here (presumably the companion CUDA source of this extension --
// confirm against the build setup). The C++ wrappers further down validate
// their tensor arguments with CHECK_INPUT and then forward to these.
// Backs infer_t_minmax: t_min / t_max of the ray-bbox intersection.
std::vector<torch::Tensor> infer_t_minmax_cuda(
torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
const float near, const float far);
// Backs infer_n_samples: number of points to sample on each ray from t_min/t_max.
torch::Tensor infer_n_samples_cuda(torch::Tensor t_min, torch::Tensor t_max, const float stepdist);
// Backs infer_ray_start_dir: starting point and direction of each ray.
std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min);
// Backs sample_pts_on_rays.
std::vector<torch::Tensor> sample_pts_on_rays_cuda(
torch::Tensor rays_o, torch::Tensor rays_d,
torch::Tensor xyz_min, torch::Tensor xyz_max,
const float near, const float far, const float stepdist);
// Backs sample_ndc_pts_on_rays; takes a sample count (N_samples) instead of
// near/far/stepdist.
std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
torch::Tensor rays_o, torch::Tensor rays_d,
torch::Tensor xyz_min, torch::Tensor xyz_max,
const int N_samples);
// Backs maskcache_lookup. xyz2ijk_scale / xyz2ijk_shift presumably map world
// coordinates to grid indices of `world` -- verify against the kernel.
torch::Tensor maskcache_lookup_cuda(torch::Tensor world, torch::Tensor xyz, torch::Tensor xyz2ijk_scale, torch::Tensor xyz2ijk_shift);
// Backs raw2alpha (raw values -> alpha) and its backward pass.
std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval);
torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp, torch::Tensor grad_back, const float interval);
// Backs alpha2weight (per-point alpha -> accumulated blending weight) and its
// backward pass.
std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays);
torch::Tensor alpha2weight_backward_cuda(
torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
torch::Tensor grad_weights, torch::Tensor grad_last);
// C++ interface
//
// Argument-validation helpers shared by the wrappers below. Each one raises
// through TORCH_CHECK with a message naming the offending argument.
//
// CHECK_CUDA uses Tensor::is_cuda() rather than the deprecated
// Tensor::type().is_cuda(), which triggers deprecation warnings on current
// PyTorch releases. The macro argument is parenthesized so any expression
// can be passed.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// The do { ... } while (0) wrapper makes the two checks act as a single
// statement, so CHECK_INPUT(t); is also safe under an unbraced if/else.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
// Python entry point for the ray-bbox intersection range (t_min / t_max).
//
// Requires rays_o, rays_d, xyz_min and xyz_max to be contiguous CUDA
// tensors. Checks run in argument order, so the first offending argument is
// the one reported. `near` and `far` are handed to infer_t_minmax_cuda as-is.
std::vector<torch::Tensor> infer_t_minmax(
    torch::Tensor rays_o,
    torch::Tensor rays_d,
    torch::Tensor xyz_min,
    torch::Tensor xyz_max,
    const float near,
    const float far) {
    // Ray tensors.
    CHECK_INPUT(rays_o);
    CHECK_INPUT(rays_d);
    // Bounding-box corners.
    CHECK_INPUT(xyz_min);
    CHECK_INPUT(xyz_max);

    return infer_t_minmax_cuda(
        rays_o, rays_d,
        xyz_min, xyz_max,
        near, far);
}
// Python entry point returning the number of points to sample on each ray.
//
// t_min and t_max must be contiguous CUDA tensors. stepdist is forwarded
// unchanged to infer_n_samples_cuda.
torch::Tensor infer_n_samples(torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
    // Validate both ends of the per-ray range before launching.
    CHECK_INPUT(t_min);
    CHECK_INPUT(t_max);

    return infer_n_samples_cuda(
        t_min,
        t_max,
        stepdist);
}
// Python entry point returning the starting point and shooting direction of
// each ray. All three inputs must be contiguous CUDA tensors. The result
// tensors come straight from infer_ray_start_dir_cuda.
std::vector<torch::Tensor> infer_ray_start_dir(
    torch::Tensor rays_o,
    torch::Tensor rays_d,
    torch::Tensor t_min) {
    CHECK_INPUT(rays_o);
    CHECK_INPUT(rays_d);
    CHECK_INPUT(t_min);

    return infer_ray_start_dir_cuda(
        rays_o,
        rays_d,
        t_min);
}
// Python entry point: validates inputs, then forwards to sample_pts_on_rays_cuda.
//
// Requirements:
//   - rays_o, rays_d, xyz_min and xyz_max must be contiguous CUDA tensors;
//   - rays_o must be 2-D with size 3 in its second dimension, i.e. (N, 3).
// near, far and stepdist are forwarded unchanged.
//
// Shape checks use TORCH_CHECK instead of assert(). assert() is compiled out
// whenever NDEBUG is defined, which is typical for release builds of Python
// extensions, so the previous checks could silently vanish. TORCH_CHECK
// always runs and raises a c10::Error (a RuntimeError on the Python side).
std::vector<torch::Tensor> sample_pts_on_rays(
    torch::Tensor rays_o, torch::Tensor rays_d,
    torch::Tensor xyz_min, torch::Tensor xyz_max,
    const float near, const float far, const float stepdist) {
    CHECK_INPUT(rays_o);
    CHECK_INPUT(rays_d);
    CHECK_INPUT(xyz_min);
    CHECK_INPUT(xyz_max);
    TORCH_CHECK(rays_o.dim() == 2,
                "rays_o must be a 2-D tensor, got ", rays_o.dim(), " dimensions");
    TORCH_CHECK(rays_o.size(1) == 3,
                "rays_o must have shape (N, 3), got size ", rays_o.size(1), " in dim 1");
    return sample_pts_on_rays_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far, stepdist);
}
// Python entry point for the NDC variant of ray sampling.
// Validates inputs, then forwards to sample_ndc_pts_on_rays_cuda.
//
// Requirements:
//   - rays_o, rays_d, xyz_min and xyz_max must be contiguous CUDA tensors;
//   - rays_o must be 2-D with size 3 in its second dimension, i.e. (N, 3).
// N_samples is forwarded unchanged.
//
// Shape checks use TORCH_CHECK instead of assert(), because assert() is a
// no-op when NDEBUG is defined (typical for release extension builds).
// TORCH_CHECK always runs and raises a c10::Error.
std::vector<torch::Tensor> sample_ndc_pts_on_rays(
    torch::Tensor rays_o, torch::Tensor rays_d,
    torch::Tensor xyz_min, torch::Tensor xyz_max,
    const int N_samples) {
    CHECK_INPUT(rays_o);
    CHECK_INPUT(rays_d);
    CHECK_INPUT(xyz_min);
    CHECK_INPUT(xyz_max);
    TORCH_CHECK(rays_o.dim() == 2,
                "rays_o must be a 2-D tensor, got ", rays_o.dim(), " dimensions");
    TORCH_CHECK(rays_o.size(1) == 3,
                "rays_o must have shape (N, 3), got size ", rays_o.size(1), " in dim 1");
    return sample_ndc_pts_on_rays_cuda(rays_o, rays_d, xyz_min, xyz_max, N_samples);
}
// Python entry point for the mask-cache lookup.
// Validates inputs, then forwards to maskcache_lookup_cuda.
//
// Requirements:
//   - world, xyz, xyz2ijk_scale and xyz2ijk_shift must be contiguous CUDA
//     tensors;
//   - world must be 3-D;
//   - xyz must be 2-D with size 3 in its second dimension, i.e. (N, 3).
// xyz2ijk_scale / xyz2ijk_shift presumably map xyz into index space of
// `world` -- NOTE(review): confirm against the kernel.
//
// Shape checks use TORCH_CHECK instead of assert(), because assert() is a
// no-op when NDEBUG is defined (typical for release extension builds).
torch::Tensor maskcache_lookup(torch::Tensor world, torch::Tensor xyz, torch::Tensor xyz2ijk_scale, torch::Tensor xyz2ijk_shift) {
    CHECK_INPUT(world);
    CHECK_INPUT(xyz);
    CHECK_INPUT(xyz2ijk_scale);
    CHECK_INPUT(xyz2ijk_shift);
    TORCH_CHECK(world.dim() == 3,
                "world must be a 3-D tensor, got ", world.dim(), " dimensions");
    TORCH_CHECK(xyz.dim() == 2,
                "xyz must be a 2-D tensor, got ", xyz.dim(), " dimensions");
    TORCH_CHECK(xyz.size(1) == 3,
                "xyz must have shape (N, 3), got size ", xyz.size(1), " in dim 1");
    return maskcache_lookup_cuda(world, xyz, xyz2ijk_scale, xyz2ijk_shift);
}
// Python entry point converting raw values to alpha (see the binding's
// docstring). Validates inputs, then forwards to raw2alpha_cuda.
//
// Requirements: density must be a contiguous, 1-D CUDA tensor.
// shift and interval are forwarded unchanged.
//
// The rank check uses TORCH_CHECK instead of assert(), because assert() is a
// no-op when NDEBUG is defined (typical for release extension builds).
std::vector<torch::Tensor> raw2alpha(torch::Tensor density, const float shift, const float interval) {
    CHECK_INPUT(density);
    TORCH_CHECK(density.dim() == 1,
                "density must be a 1-D tensor, got ", density.dim(), " dimensions");
    return raw2alpha_cuda(density, shift, interval);
}
// Python entry point for the backward pass of raw2alpha.
//
// `exp` and `grad_back` must be contiguous CUDA tensors. `interval` is
// passed through to raw2alpha_backward_cuda unchanged.
torch::Tensor raw2alpha_backward(
    torch::Tensor exp,
    torch::Tensor grad_back,
    const float interval) {
    CHECK_INPUT(exp);
    CHECK_INPUT(grad_back);

    return raw2alpha_backward_cuda(
        exp,
        grad_back,
        interval);
}
// Python entry point turning per-point alpha into accumulated blending
// weights. Validates inputs, then forwards to alpha2weight_cuda.
//
// Requirements:
//   - alpha and ray_id must be contiguous, 1-D CUDA tensors;
//   - alpha and ray_id must have identical shapes.
// n_rays is forwarded unchanged.
//
// Shape checks use TORCH_CHECK instead of assert(), because assert() is a
// no-op when NDEBUG is defined (typical for release extension builds).
std::vector<torch::Tensor> alpha2weight(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {
    CHECK_INPUT(alpha);
    CHECK_INPUT(ray_id);
    TORCH_CHECK(alpha.dim() == 1,
                "alpha must be a 1-D tensor, got ", alpha.dim(), " dimensions");
    TORCH_CHECK(ray_id.dim() == 1,
                "ray_id must be a 1-D tensor, got ", ray_id.dim(), " dimensions");
    TORCH_CHECK(alpha.sizes() == ray_id.sizes(),
                "alpha and ray_id must have the same shape, got ",
                alpha.sizes(), " and ", ray_id.sizes());
    return alpha2weight_cuda(alpha, ray_id, n_rays);
}
// Python entry point for the backward pass of alpha2weight.
//
// All tensor arguments must be contiguous CUDA tensors:
//   - the forward-pass tensors: alpha, weight, T, alphainv_last, i_start,
//     i_end;
//   - the incoming gradients: grad_weights, grad_last.
// n_rays is passed through unchanged. Checks run in argument order, so the
// first invalid argument is the one reported.
torch::Tensor alpha2weight_backward(
    torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
    torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
    torch::Tensor grad_weights, torch::Tensor grad_last) {
    // Tensors saved from (or derived in) the forward pass.
    CHECK_INPUT(alpha);
    CHECK_INPUT(weight);
    CHECK_INPUT(T);
    CHECK_INPUT(alphainv_last);
    // Per-ray index bounds.
    CHECK_INPUT(i_start);
    CHECK_INPUT(i_end);
    // Upstream gradients.
    CHECK_INPUT(grad_weights);
    CHECK_INPUT(grad_last);

    return alpha2weight_backward_cuda(
        alpha, weight, T, alphainv_last,
        i_start, i_end,
        n_rays,
        grad_weights, grad_last);
}
// Python bindings. Each entry exposes the same-named validating C++ wrapper
// defined in this file. The third argument becomes the function's Python
// docstring, which is why its wording matters to users calling help().
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("infer_t_minmax", &infer_t_minmax, "Infer t_min and t_max of the ray-bbox intersection");
    m.def("infer_n_samples", &infer_n_samples, "Infer the number of points to sample on each ray");
    m.def("infer_ray_start_dir", &infer_ray_start_dir, "Infer the starting point and shooting direction of each ray");
    m.def("sample_pts_on_rays", &sample_pts_on_rays, "Sample points on rays");
    m.def("sample_ndc_pts_on_rays", &sample_ndc_pts_on_rays, "Sample points on rays (NDC variant)");
    m.def("maskcache_lookup", &maskcache_lookup, "Lookup to skip known free space.");
    m.def("raw2alpha", &raw2alpha, "Raw values [-inf, inf] to alpha [0, 1].");
    m.def("raw2alpha_backward", &raw2alpha_backward, "Backward pass of raw2alpha");
    m.def("alpha2weight", &alpha2weight, "Per-point alpha to accumulated blending weight");
    m.def("alpha2weight_backward", &alpha2weight_backward, "Backward pass of alpha2weight");
}