-
Notifications
You must be signed in to change notification settings - Fork 11
/
wmma_naive.cu
340 lines (261 loc) · 11.2 KB
/
wmma_naive.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
// https://github.com/parallel-forall/code-samples/blob/master/posts/tensor-cores
#include <benchmark/benchmark.h>
#include "gemm/args.hpp"
#include "init/init.hpp"
#include "utils/utils.hpp"
#include <mma.h>
using namespace nvcuda;
#ifndef WARP_SIZE
#define WARP_SIZE (32)
#endif // WARP_SIZE

// WMMA tile shape used by every fragment in this file:
//   M = rows of an A / C tile, N = columns of a B / C tile, K = shared inner
//   dimension of one mma_sync step.
// (16, 16, 16), (32, 8, 16), and (8, 32, 16) are currently supported shapes.
static constexpr int M = 16;
static constexpr int N = 16;
static constexpr int K = 16;

// Launch-shape constants: a thread block is BLOCK_ROW_TILES warps wide in x
// (blockDim.x = BLOCK_ROW_TILES * WARP_SIZE) and BLOCK_COL_TILES threads tall
// in y, i.e. one output tile per warp.
static constexpr int BLOCK_ROW_TILES = 4;
static constexpr int BLOCK_COL_TILES = 4;
// Performs an M_GLOBAL x N_GLOBAL x K_GLOBAL GEMM (C = alpha*A*B + beta*C)
// using WMMA tensor-core fragments (fp16 inputs, fp32 accumulation and C),
// assuming:
//   1) Matrices are packed in memory in column-major order
//      (lda = M_GLOBAL, ldb = K_GLOBAL, ldc = M_GLOBAL).
//   2) M_GLOBAL, N_GLOBAL and K_GLOBAL are multiples of the tile sizes
//      M, N and K; partial tiles are not handled.
//   3) Neither A nor B are transposed.
// Launch layout: every warp computes one M x N tile of C. The tile row comes
// from the x dimension ((blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE)
// and the tile column from the y dimension (blockIdx.y * blockDim.y +
// threadIdx.y). blockDim.x is assumed to be a multiple of WARP_SIZE so that
// these coordinates are uniform within each warp, as the collective wmma::*
// calls require. Warps whose tile lies outside C exit without doing any work.
// WMMA requires a device of compute capability 7.0 or newer.
// Note: This is NOT a high performance example but is for demonstration
// purposes only. For a high performance code please use the GEMM provided in
// cuBLAS.
static __global__ void compute_gemm_naive(const half *__restrict__ a,
                                          const half *__restrict__ b, float *c,
                                          int M_GLOBAL, int N_GLOBAL, int K_GLOBAL,
                                          float alpha, float beta) {
  // Leading dimensions. Packed column-major with no transpositions.
  const int lda = M_GLOBAL;
  const int ldb = K_GLOBAL;
  const int ldc = M_GLOBAL;

  // Global warp coordinates, in units of output tiles.
  const int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;
  const int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Top-left element of this warp's C tile. The same row offset selects the A
  // rows and the same column offset selects the B columns this warp reads.
  const int cRow = warpM * M;
  const int cCol = warpN * N;

  // Bounds check against the GLOBAL problem size (not the tile size M/N).
  // The condition is warp-uniform, so whole warps leave together.
  if (cRow >= M_GLOBAL || cCol >= N_GLOBAL) {
    return;
  }

  // Declare the fragments
  wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::col_major> a_frag;
  wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, M, N, K, float> acc_frag;
  wmma::fragment<wmma::accumulator, M, N, K, float> c_frag;
  wmma::fill_fragment(acc_frag, zero<float>());

  // Accumulate A[cRow:cRow+M, k:k+K] * B[k:k+K, cCol:cCol+N] over the inner
  // dimension. Offsets are formed in size_t so large matrices do not overflow
  // 32-bit int arithmetic.
  for (int k = 0; k < K_GLOBAL; k += K) {
    wmma::load_matrix_sync(a_frag, a + cRow + static_cast<size_t>(k) * lda, lda);
    wmma::load_matrix_sync(b_frag, b + k + static_cast<size_t>(cCol) * ldb, ldb);
    wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
  }

  // Load the current C tile, scale it by beta and add alpha times the product.
  float *c_tile = c + cRow + static_cast<size_t>(cCol) * ldc;
  wmma::load_matrix_sync(c_frag, c_tile, ldc, wmma::mem_col_major);
  for (int i = 0; i < c_frag.num_elements; i++) {
    c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
  }
  // Store the output
  wmma::store_matrix_sync(c_tile, c_frag, ldc, wmma::mem_col_major);
}
// Half-precision GEMM: C = alpha*A*B + beta*C with fp16 A, B, C, fp16
// accumulation and fp16 alpha/beta, using WMMA tensor-core fragments.
// Preconditions (not checked here):
//   1) Matrices are packed column-major (lda = M_GLOBAL, ldb = K_GLOBAL,
//      ldc = M_GLOBAL), with no transpositions.
//   2) M_GLOBAL, N_GLOBAL and K_GLOBAL are multiples of the tile sizes
//      M, N and K; partial tiles are not handled.
// Launch layout: every warp computes one M x N tile of C. The tile row comes
// from (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE and the tile column
// from blockIdx.y * blockDim.y + threadIdx.y. blockDim.x is assumed to be a
// multiple of WARP_SIZE so that these coordinates are uniform within each
// warp, as the collective wmma::* calls require. Warps whose tile lies outside
// C exit without doing any work.
// WMMA requires a device of compute capability 7.0 or newer.
// Demonstration code only; not tuned for performance.
static __global__ void compute_hgemm_naive(const half *__restrict__ a,
                                           const half *__restrict__ b, half *c,
                                           int M_GLOBAL, int N_GLOBAL, int K_GLOBAL,
                                           half alpha, half beta) {
  // Leading dimensions. Packed column-major with no transpositions.
  const int lda = M_GLOBAL;
  const int ldb = K_GLOBAL;
  const int ldc = M_GLOBAL;

  // Global warp coordinates, in units of output tiles.
  const int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;
  const int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Top-left element of this warp's C tile. The same row offset selects the A
  // rows and the same column offset selects the B columns this warp reads.
  const int cRow = warpM * M;
  const int cCol = warpN * N;

  // Bounds check against the GLOBAL problem size (not the tile size M/N).
  // The condition is warp-uniform, so whole warps leave together.
  if (cRow >= M_GLOBAL || cCol >= N_GLOBAL) {
    return;
  }

  // Declare the fragments
  wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::col_major> a_frag;
  wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, M, N, K, half> acc_frag;
  wmma::fragment<wmma::accumulator, M, N, K, half> c_frag;
  wmma::fill_fragment(acc_frag, zero<half>());

  // Accumulate A[cRow:cRow+M, k:k+K] * B[k:k+K, cCol:cCol+N] over the inner
  // dimension. Offsets are formed in size_t so large matrices do not overflow
  // 32-bit int arithmetic.
  for (int k = 0; k < K_GLOBAL; k += K) {
    wmma::load_matrix_sync(a_frag, a + cRow + static_cast<size_t>(k) * lda, lda);
    wmma::load_matrix_sync(b_frag, b + k + static_cast<size_t>(cCol) * ldb, ldb);
    wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
  }

  // Load the current C tile, scale it by beta and add alpha times the product.
  half *c_tile = c + cRow + static_cast<size_t>(cCol) * ldc;
  wmma::load_matrix_sync(c_frag, c_tile, ldc, wmma::mem_col_major);
  for (int i = 0; i < c_frag.num_elements; i++) {
    c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
  }
  // Store the output
  wmma::store_matrix_sync(c_tile, c_frag, ldc, wmma::mem_col_major);
}
// Google Benchmark driver for compute_gemm_naive (fp16 A/B, fp32 C).
//
// state.range(0), state.range(1) and state.range(2) give M_GLOBAL, N_GLOBAL
// and K_GLOBAL. A and B are generated as uniform fp32 values with cuRAND
// (seed 1337) and converted to fp16 on the device; C is generated directly in
// fp32. Each iteration launches the kernel once and reports the GPU time
// between two CUDA events through state.SetIterationTime (manual timing).
// The kernel updates C in place, so C changes across iterations; only the
// timing is meaningful, not the final values.
// Sizes that are not positive multiples of the WMMA tile sizes are skipped,
// because the kernel does not handle partial tiles.
static void CUDA_WMMA_GEMM_NAIVE(benchmark::State &state) {
  /* if (!has_cuda) { */
  /* state.SkipWithError(fmt::format("CUDA_WMMA_GEMM_NAIVE no CUDA device
   * found")); */
  /* return; */
  /* } */
  const auto M_GLOBAL = state.range(0);
  const auto N_GLOBAL = state.range(1);
  const auto K_GLOBAL = state.range(2);

  // compute_gemm_naive loads/stores whole M x N x K tiles; any other size would
  // make the WMMA accesses run past the end of the device buffers. A zero size
  // would also produce an invalid (empty) launch configuration.
  if (M_GLOBAL <= 0 || N_GLOBAL <= 0 || K_GLOBAL <= 0 || M_GLOBAL % M != 0 ||
      N_GLOBAL % N != 0 || K_GLOBAL % K != 0) {
    state.SkipWithError(
        "CUDA_WMMA_GEMM_NAIVE requires positive M, N, K that are multiples of the WMMA tile size");
    return;
  }

  const float alpha = 1.1f;
  const float beta = 1.2f;
  float *a_fp32;
  float *b_fp32;
  float *c;
  half *a_fp16;
  half *b_fp16;
  curandGenerator_t gen;

  // Use tensor cores
  PRINT_IF_ERROR(cudaMalloc((void **) &a_fp32, M_GLOBAL * K_GLOBAL * sizeof(float)));
  PRINT_IF_ERROR(cudaMalloc((void **) &b_fp32, K_GLOBAL * N_GLOBAL * sizeof(float)));
  PRINT_IF_ERROR(cudaMalloc((void **) &a_fp16, M_GLOBAL * K_GLOBAL * sizeof(half)));
  PRINT_IF_ERROR(cudaMalloc((void **) &b_fp16, K_GLOBAL * N_GLOBAL * sizeof(half)));
  PRINT_IF_ERROR(cudaMalloc((void **) &c, M_GLOBAL * N_GLOBAL * sizeof(float)));

  PRINT_IF_ERROR(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
  PRINT_IF_ERROR(curandSetPseudoRandomGeneratorSeed(gen, 1337ULL));
  PRINT_IF_ERROR(curandGenerateUniform(gen, a_fp32, M_GLOBAL * K_GLOBAL));
  PRINT_IF_ERROR(curandGenerateUniform(gen, b_fp32, K_GLOBAL * N_GLOBAL));
  // curand doesn't currently support fp16 so we generate in fp32 and convert to
  // fp16.
  PRINT_IF_LAUNCH_ERROR((convertFp32ToFp16<<<(M_GLOBAL * K_GLOBAL + 255) / 256, 256>>>(
      a_fp16, a_fp32, M_GLOBAL * K_GLOBAL)));
  PRINT_IF_LAUNCH_ERROR((convertFp32ToFp16<<<(K_GLOBAL * N_GLOBAL + 255) / 256, 256>>>(
      b_fp16, b_fp32, K_GLOBAL * N_GLOBAL)));
  PRINT_IF_ERROR(curandGenerateUniform(gen, c, M_GLOBAL * N_GLOBAL));
  PRINT_IF_ERROR(curandDestroyGenerator(gen));

  cudaEvent_t start, stop;
  PRINT_IF_ERROR(cudaEventCreate(&start));
  PRINT_IF_ERROR(cudaEventCreate(&stop));

  // One warp per output tile: BLOCK_ROW_TILES warps along x and
  // BLOCK_COL_TILES rows of threads along y per block.
  // (Named grid/block to avoid shadowing the gridDim/blockDim built-ins.)
  dim3 grid;
  dim3 block;
  block.x = BLOCK_ROW_TILES * WARP_SIZE;
  block.y = BLOCK_COL_TILES;
  grid.x = (M_GLOBAL + (M * BLOCK_ROW_TILES - 1)) / (M * BLOCK_ROW_TILES);
  grid.y = (N_GLOBAL + N * BLOCK_COL_TILES - 1) / (N * BLOCK_COL_TILES);

  for (auto _ : state) {
    PRINT_IF_ERROR(cudaEventRecord(start));
    compute_gemm_naive<<<grid, block>>>(a_fp16, b_fp16, c, M_GLOBAL, N_GLOBAL, K_GLOBAL,
                                        alpha, beta);
    // Kernel launches return no status; surface launch-configuration errors
    // here without adding a host synchronization inside the timed region.
    PRINT_IF_ERROR(cudaGetLastError());
    PRINT_IF_ERROR(cudaEventRecord(stop));
    PRINT_IF_ERROR(cudaEventSynchronize(stop));

    state.PauseTiming();
    float msecTotal = 0.0f;
    PRINT_IF_ERROR(cudaEventElapsedTime(&msecTotal, start, stop));
    state.SetIterationTime(msecTotal / 1000);
    state.ResumeTiming();
  }

  // Report the problem size (not the fixed WMMA tile size) and the achieved
  // flop rate (2*M*N*K flops per GEMM).
  state.counters.insert({{"M", M_GLOBAL},
                         {"N", N_GLOBAL},
                         {"K", K_GLOBAL},
                         {"num_elements", M_GLOBAL * N_GLOBAL * K_GLOBAL},
                         {"flops",
                          {state.iterations() * 2.0 * M_GLOBAL * N_GLOBAL * K_GLOBAL,
                           benchmark::Counter::kAvgThreadsRate}}});

  PRINT_IF_ERROR(cudaEventDestroy(start));
  PRINT_IF_ERROR(cudaEventDestroy(stop));
  PRINT_IF_ERROR(cudaFree(a_fp32));
  PRINT_IF_ERROR(cudaFree(b_fp32));
  PRINT_IF_ERROR(cudaFree(a_fp16));
  PRINT_IF_ERROR(cudaFree(b_fp16));
  PRINT_IF_ERROR(cudaFree(c));
  PRINT_IF_ERROR(cudaDeviceReset());
}
// Google Benchmark driver for compute_hgemm_naive (fp16 A, B and C).
//
// state.range(0), state.range(1) and state.range(2) give M_GLOBAL, N_GLOBAL
// and K_GLOBAL. A, B and C are generated as uniform fp32 values with cuRAND
// (seed 1337) and converted to fp16 on the device. Each iteration launches the
// kernel once and reports the GPU time between two CUDA events through
// state.SetIterationTime (manual timing). The kernel updates C in place, so C
// changes across iterations; only the timing is meaningful.
// Sizes that are not positive multiples of the WMMA tile sizes are skipped,
// because the kernel does not handle partial tiles.
static void CUDA_WMMA_HGEMM_NAIVE(benchmark::State &state) {
  const auto M_GLOBAL = state.range(0);
  const auto N_GLOBAL = state.range(1);
  const auto K_GLOBAL = state.range(2);

  // compute_hgemm_naive loads/stores whole M x N x K tiles; any other size
  // would make the WMMA accesses run past the end of the device buffers. A
  // zero size would also produce an invalid (empty) launch configuration.
  if (M_GLOBAL <= 0 || N_GLOBAL <= 0 || K_GLOBAL <= 0 || M_GLOBAL % M != 0 ||
      N_GLOBAL % N != 0 || K_GLOBAL % K != 0) {
    state.SkipWithError(
        "CUDA_WMMA_HGEMM_NAIVE requires positive M, N, K that are multiples of the WMMA tile size");
    return;
  }

  const __half alpha = approx_float_to_half(1.1f);
  const __half beta = approx_float_to_half(1.2f);
  float *a_fp32;
  float *b_fp32;
  float *c_fp32;
  half *a_fp16;
  half *b_fp16;
  half *c_fp16;
  curandGenerator_t gen;

  // Use tensor cores
  PRINT_IF_ERROR(cudaMalloc((void **) &a_fp32, M_GLOBAL * K_GLOBAL * sizeof(float)));
  PRINT_IF_ERROR(cudaMalloc((void **) &b_fp32, K_GLOBAL * N_GLOBAL * sizeof(float)));
  PRINT_IF_ERROR(cudaMalloc((void **) &a_fp16, M_GLOBAL * K_GLOBAL * sizeof(half)));
  PRINT_IF_ERROR(cudaMalloc((void **) &b_fp16, K_GLOBAL * N_GLOBAL * sizeof(half)));
  PRINT_IF_ERROR(cudaMalloc((void **) &c_fp32, M_GLOBAL * N_GLOBAL * sizeof(float)));
  PRINT_IF_ERROR(cudaMalloc((void **) &c_fp16, M_GLOBAL * N_GLOBAL * sizeof(half)));

  PRINT_IF_ERROR(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
  PRINT_IF_ERROR(curandSetPseudoRandomGeneratorSeed(gen, 1337ULL));
  PRINT_IF_ERROR(curandGenerateUniform(gen, a_fp32, M_GLOBAL * K_GLOBAL));
  PRINT_IF_ERROR(curandGenerateUniform(gen, b_fp32, K_GLOBAL * N_GLOBAL));
  // C is M x N: fill exactly M*N values (filling K*N would overrun c_fp32 when
  // K > M and leave part of it uninitialized when K < M).
  PRINT_IF_ERROR(curandGenerateUniform(gen, c_fp32, M_GLOBAL * N_GLOBAL));
  // curand doesn't currently support fp16 so we generate in fp32 and convert to
  // fp16.
  PRINT_IF_LAUNCH_ERROR((convertFp32ToFp16<<<(M_GLOBAL * K_GLOBAL + 255) / 256, 256>>>(
      a_fp16, a_fp32, M_GLOBAL * K_GLOBAL)));
  PRINT_IF_LAUNCH_ERROR((convertFp32ToFp16<<<(K_GLOBAL * N_GLOBAL + 255) / 256, 256>>>(
      b_fp16, b_fp32, K_GLOBAL * N_GLOBAL)));
  PRINT_IF_LAUNCH_ERROR((convertFp32ToFp16<<<(M_GLOBAL * N_GLOBAL + 255) / 256, 256>>>(
      c_fp16, c_fp32, M_GLOBAL * N_GLOBAL)));
  PRINT_IF_ERROR(curandDestroyGenerator(gen));

  cudaEvent_t start, stop;
  PRINT_IF_ERROR(cudaEventCreate(&start));
  PRINT_IF_ERROR(cudaEventCreate(&stop));

  // One warp per output tile: BLOCK_ROW_TILES warps along x and
  // BLOCK_COL_TILES rows of threads along y per block.
  // (Named grid/block to avoid shadowing the gridDim/blockDim built-ins.)
  dim3 grid;
  dim3 block;
  block.x = BLOCK_ROW_TILES * WARP_SIZE;
  block.y = BLOCK_COL_TILES;
  grid.x = (M_GLOBAL + (M * BLOCK_ROW_TILES - 1)) / (M * BLOCK_ROW_TILES);
  grid.y = (N_GLOBAL + N * BLOCK_COL_TILES - 1) / (N * BLOCK_COL_TILES);

  for (auto _ : state) {
    PRINT_IF_ERROR(cudaEventRecord(start));
    compute_hgemm_naive<<<grid, block>>>(a_fp16, b_fp16, c_fp16, M_GLOBAL, N_GLOBAL,
                                         K_GLOBAL, alpha, beta);
    // Kernel launches return no status; surface launch-configuration errors
    // here without adding a host synchronization inside the timed region.
    PRINT_IF_ERROR(cudaGetLastError());
    PRINT_IF_ERROR(cudaEventRecord(stop));
    PRINT_IF_ERROR(cudaEventSynchronize(stop));

    state.PauseTiming();
    float msecTotal = 0.0f;
    PRINT_IF_ERROR(cudaEventElapsedTime(&msecTotal, start, stop));
    state.SetIterationTime(msecTotal / 1000);
    state.ResumeTiming();
  }

  // Report the problem size (not the fixed WMMA tile size) and the achieved
  // flop rate (2*M*N*K flops per GEMM).
  state.counters.insert({{"M", M_GLOBAL},
                         {"N", N_GLOBAL},
                         {"K", K_GLOBAL},
                         {"num_elements", M_GLOBAL * N_GLOBAL * K_GLOBAL},
                         {"flops",
                          {state.iterations() * 2.0 * M_GLOBAL * N_GLOBAL * K_GLOBAL,
                           benchmark::Counter::kAvgThreadsRate}}});

  PRINT_IF_ERROR(cudaEventDestroy(start));
  PRINT_IF_ERROR(cudaEventDestroy(stop));
  PRINT_IF_ERROR(cudaFree(a_fp32));
  PRINT_IF_ERROR(cudaFree(b_fp32));
  PRINT_IF_ERROR(cudaFree(a_fp16));
  PRINT_IF_ERROR(cudaFree(b_fp16));
  PRINT_IF_ERROR(cudaFree(c_fp32));
  PRINT_IF_ERROR(cudaFree(c_fp16));
  PRINT_IF_ERROR(cudaDeviceReset());
}
// Register both drivers with Google Benchmark. ARGS() is a project macro
// (presumably from gemm/args.hpp) that supplies the (M, N, K) argument triples
// the drivers read via state.range(0..2) — TODO confirm. UseManualTime makes
// the reported time the value each driver passes to state.SetIterationTime()
// (the CUDA-event measurement) instead of host wall-clock time.
BENCHMARK(CUDA_WMMA_GEMM_NAIVE)->ARGS()->UseManualTime();
BENCHMARK(CUDA_WMMA_HGEMM_NAIVE)->ARGS()->UseManualTime();