-
Notifications
You must be signed in to change notification settings - Fork 0
/
check.h
37 lines (33 loc) · 918 Bytes
/
check.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <cmath>
#include <cstdio>
#include <vector>

#include "cuda_fp16.h"
// Host-side verification of the Q projection: recomputes q_ref = in * Wq on
// the CPU and compares it element-wise against the provided `q`.
//
// Layouts implied by the indexing below (all row-major):
//   in : N x d,   Wq : d x d,   q : N x d
// Every pointer is dereferenced on the host, so it must point to
// host-readable memory (e.g. device results already copied back).
//
// Only `in`, `Wq` and `q` are read. Wk, Wv, k, v, p, s, o, Wo and out are
// accepted for interface compatibility but are NOT checked here.
// TODO(review): extend the reference check to the remaining outputs.
//
// Params:
//   N, d - matrix dimensions; non-positive values mean there is nothing to
//          check and the function returns true.
// Returns true when every |q_ref - q| is within the absolute tolerance.
// On the first mismatch (including a NaN difference) it prints the
// coordinates and both values, then returns false.
inline bool check(half *in,
                  half *Wq, half *Wk, half *Wv,
                  half *q, half *k, half *v,
                  half *p, half *s,
                  half *o, half *Wo,
                  half *out,
                  int N, int d)
{
    // Absolute tolerance; kept at the historical (very loose) value so
    // existing callers see the same pass/fail threshold.
    constexpr float kAbsTol = 1e2f;

    printf("checking...\n");
    if (N <= 0 || d <= 0) {
        return true;
    }
    const size_t cols = static_cast<size_t>(d);

    // Reference accumulated in float: summing d products in half loses
    // precision and overflows past 65504, which made the reference itself
    // unreliable. The vector also owns the buffer, so nothing leaks.
    std::vector<float> check_q(static_cast<size_t>(N) * cols, 0.0f);
    for (int i = 0; i < N; ++i) {
        const size_t row = static_cast<size_t>(i) * cols;
        // i-kk-j order: each check_q[row + j] still sums over kk = 0..d-1 in
        // increasing order, but Wq is walked row-contiguously.
        for (size_t kk = 0; kk < cols; ++kk) {
            const float a = static_cast<float>(in[row + kk]);
            const size_t wrow = kk * cols;
            for (size_t j = 0; j < cols; ++j) {
                check_q[row + j] += a * static_cast<float>(Wq[wrow + j]);
            }
        }
    }

    for (int i = 0; i < N; ++i) {
        const size_t row = static_cast<size_t>(i) * cols;
        for (int j = 0; j < d; ++j) {
            const float ref = check_q[row + j];
            const float got = static_cast<float>(q[row + j]);
            const float err = std::fabs(ref - got);
            // Written as !(err <= tol) so that a NaN difference counts as a
            // failure; `err > tol` would silently accept NaN.
            if (!(err <= kAbsTol)) {
                printf("i = %d, j = %d, check_q = %f, q = %f\n", i, j, ref, got);
                return false;
            }
        }
    }
    return true;
}