forked from ayanc/mdepth
-
Notifications
You must be signed in to change notification settings - Fork 0
/
postMAP.cu
91 lines (64 loc) · 1.96 KB
/
postMAP.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
/*
--Ayan Chakrabarti <ayanc@ttic.edu>
*/
#include "mex.h"
#include "gpu/mxGPUArray.h"
#include <stdint.h>
// Element type of every array the kernel reads or writes (32-bit float).
typedef float F;
// Threads per block used for the postMAP launch in mexFunction.
static const int NUMT = 1024;
/*
 * postMAP kernel: per-element selection of the lowest-cost candidate value.
 *
 * Layout (column-major, y fastest): a linear index i in [0, W*H*K) maps to
 *   y = i % H,  x = (i / H) % W,  k = i / (W*H).
 * For each (y, x, k) the kernel scans B candidates j = 0..B-1:
 *   cost_j  = pred[i + j*W*H*K]
 *   value_j = bins[k + j*K]
 * If beta > 0, the current value dcur is read from der at the padded position
 * (y+crop, x+crop, k) of an (H+2*crop) x (W+2*crop) x K array, and then
 *   cost_j  += beta/(1+beta) * (value_j - dcur)^2
 *   value_j  = (value_j + beta*dcur) / (1+beta)
 * The value of the lowest-cost candidate (the first one on ties) is written
 * back to der at that same padded position. The border of width crop is
 * never written.
 *
 * Launch: any 1-D grid/block shape works because the grid-stride loop covers
 * every element. No shared memory is used. Each i owns a distinct der
 * element, so the in-place read-then-write is race-free.
 *
 * Preconditions (not checked here): B >= 1, crop >= 0, and der / pred / bins
 * hold at least (H+2crop)(W+2crop)K, W*H*K*B and K*B floats respectively.
 */
void __global__ postMAP(F * der, const F * pred, const F * bins, F beta,
                        int W, int H, int K, int B, int crop) {
  // 64-bit index arithmetic: W*H*K*B and the padded offsets can exceed INT_MAX
  // for large inputs, where int math would silently wrap around.
  const long long HW   = (long long) W * H;
  const long long N    = HW * K;                            // outputs; also pred's candidate stride
  const long long H2   = (long long) H + 2LL * crop;        // padded height
  const long long W2H2 = ((long long) W + 2LL * crop) * H2; // padded slice size
  // float literals keep the math in single precision (a bare 1.0 is double).
  const F btp1 = 1.0f + beta;
  const F brat = beta / btp1;
  const bool blend = beta > 0.0f;
  const long long stride = (long long) blockDim.x * gridDim.x;

  for (long long i = (long long) blockIdx.x * blockDim.x + threadIdx.x;
       i < N; i += stride) {
    const long long k  = i / HW;
    const long long yx = i - k * HW;  // == i % HW
    const long long y  = yx % H;
    const long long x  = yx / H;
    const long long dIdx = (y + crop) + (x + crop) * H2 + k * W2H2;

    // Candidate 0. pred's (y, x, k) offset y + x*H + k*W*H is exactly i.
    F dcur = 0.0f;  // only used when blend is true
    F cmin = pred[i];
    F dmin = bins[k];
    if (blend) {
      dcur = der[dIdx];
      cmin = cmin + brat * (dmin - dcur) * (dmin - dcur);
      dmin = (dmin + beta * dcur) / btp1;
    }
    for (int j = 1; j < B; j++) {
      F cj = pred[i + (long long) j * N];
      F dj = bins[k + (long long) j * K];
      if (blend) {
        cj = cj + brat * (dj - dcur) * (dj - dcur);
        dj = (dj + beta * dcur) / btp1;
      }
      // Strict '<': the first minimum wins on ties.
      if (cj < cmin) { cmin = cj; dmin = dj; }
    }
    der[dIdx] = dmin;
  }
}
/*
 * Returns a device pointer to the data of the caller-workspace variable `name`.
 *
 * The variable must exist and must be a real, single-precision gpuArray; any
 * other case is a hard MEX error.
 * - For a non-gpu input, mxGPUCreateFromMxArray would copy it into a temporary
 *   GPU buffer. mxGPUDestroyGPUArray below frees that buffer, so the returned
 *   pointer would dangle.
 * - The kernel reads the data as F (float), so any other class or complexity
 *   would be misinterpreted.
 *
 * Only the temporary mxGPUArray wrapper is destroyed here. For a gpuArray
 * input the wrapper shares the caller's data, so the pointer presumably stays
 * valid while the caller's variable exists. NOTE(review): confirm against the
 * mxGPUCreateFromMxArray docs.
 *
 * NOTE(review): the pointer comes from mxGPUGetDataReadOnly, but mexFunction
 * passes it to a kernel that writes through it (in-place update of X). This
 * bypasses MATLAB's value semantics. Confirm that X's data is not shared with
 * another variable.
 */
F * getGPUmem(const char * name) {
  const mxArray * var = mexGetVariablePtr("caller", name);
  if (var == NULL)
    mexErrMsgIdAndTxt("postMAP:missingVar",
                      "%s is not defined in the caller workspace.", name);
  if (!mxIsGPUArray(var))
    mexErrMsgIdAndTxt("postMAP:notGPU", "%s is not on gpu!", name);

  const mxGPUArray * tmp = mxGPUCreateFromMxArray(var);
  if (mxGPUGetClassID(tmp) != mxSINGLE_CLASS ||
      mxGPUGetComplexity(tmp) != mxREAL) {
    mxGPUDestroyGPUArray(tmp);  // release the wrapper before the error unwinds
    mexErrMsgIdAndTxt("postMAP:notSingle",
                      "%s must be a real single-precision gpuArray.", name);
  }
  F * dptr = (F *) mxGPUGetDataReadOnly(tmp);
  mxGPUDestroyGPUArray(tmp);
  return dptr;
}
/*
function postMAP(beta,crop,H,W,K,B)
X, pred and bins must be present in the caller workspace as gpuArrays whose
data is read as single-precision (float) values; X is updated in place.
*/
/*
 * Number of elements of the gpuArray `name` in the caller workspace.
 * Raises a MEX error if the variable is missing or is not a gpuArray.
 */
static size_t postMAPNumel(const char * name) {
  const mxArray * var = mexGetVariablePtr("caller", name);
  if (var == NULL)
    mexErrMsgIdAndTxt("postMAP:missingVar",
                      "%s is not defined in the caller workspace.", name);
  if (!mxIsGPUArray(var))
    mexErrMsgIdAndTxt("postMAP:notGPU", "%s is not on gpu!", name);
  const mxGPUArray * tmp = mxGPUCreateFromMxArray(var);
  const size_t n = (size_t) mxGPUGetNumberOfElements(tmp);
  mxGPUDestroyGPUArray(tmp);
  return n;
}

/*
 * MEX gateway: postMAP(beta, crop, H, W, K, B).
 *
 * Reads X, pred and bins from the caller workspace and launches the postMAP
 * kernel, which overwrites X in place. Arguments are validated before launch
 * so that inconsistent dims cannot drive the kernel out of bounds.
 * - B must be >= 1.
 * - crop, H, W and K must be >= 0.
 * - X needs at least (W+2crop)(H+2crop)K elements.
 * - pred needs at least W*H*K*B elements.
 * - bins needs at least K*B elements.
 *
 * Launch errors are reported via cudaGetLastError. Faults during kernel
 * execution surface at MATLAB's next synchronizing GPU call.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  if (nrhs < 6)
    mexErrMsgIdAndTxt("postMAP:nrhs", "Usage: postMAP(beta,crop,H,W,K,B)");

  const F beta  = (F) mxGetScalar(prhs[0]);
  const int crop = (int) mxGetScalar(prhs[1]);
  const int H    = (int) mxGetScalar(prhs[2]);
  const int W    = (int) mxGetScalar(prhs[3]);
  const int K    = (int) mxGetScalar(prhs[4]);
  const int B    = (int) mxGetScalar(prhs[5]);

  if (crop < 0 || H < 0 || W < 0 || K < 0 || B < 1)
    mexErrMsgIdAndTxt("postMAP:badDims",
                      "crop, H, W, K must be non-negative and B must be positive.");

  const size_t nOut = (size_t) W * (size_t) H * (size_t) K;
  if (nOut == 0)
    return;  // nothing to update; a zero-block launch would be a CUDA error

  if (mxInitGPU() != MX_GPU_SUCCESS)
    mexErrMsgIdAndTxt("postMAP:gpuInit", "Could not initialize the MATLAB GPU library.");

  const size_t nX    = ((size_t) W + 2 * (size_t) crop) *
                       ((size_t) H + 2 * (size_t) crop) * (size_t) K;
  const size_t nPred = nOut * (size_t) B;
  const size_t nBins = (size_t) K * (size_t) B;
  if (postMAPNumel("X") < nX || postMAPNumel("pred") < nPred ||
      postMAPNumel("bins") < nBins)
    mexErrMsgIdAndTxt("postMAP:size",
                      "X, pred or bins is smaller than implied by crop, H, W, K, B.");

  F * der  = getGPUmem("X");
  F * pred = getGPUmem("pred");
  F * bins = getGPUmem("bins");

  // The kernel uses a grid-stride loop, so capping the grid stays correct and
  // keeps gridDim.x within limits on every device.
  const size_t maxBlocks = 65535;
  size_t blocks = (nOut + NUMT - 1) / NUMT;
  if (blocks > maxBlocks)
    blocks = maxBlocks;

  postMAP<<<(unsigned int) blocks, NUMT>>>(der, pred, bins, beta, W, H, K, B, crop);
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
    mexErrMsgIdAndTxt("postMAP:launch", "postMAP kernel launch failed: %s",
                      cudaGetErrorString(err));
}