A specialized Winograd Conv2d op #971
base: master
@@ -0,0 +1,44 @@
#include "common.cuh" | ||
|
||
|
||
#define BC 8 | ||
#define BN 32 | ||
#define BK 64 | ||
#define TW 8 | ||
#define TH 16 | ||
#define BN_p 138 | ||
|
||
__constant__ int access_f_s[2][32]; | ||
__constant__ int access_s[2][32]; | ||
__constant__ int tileid[2][32]; | ||
|
||
|
||
// access_f_s | ||
const int aux[2][32] = { | ||
{0,0,1,1,2,2,3,3,4,4,5,5,6,6, | ||
7,7,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7}, | ||
{8,8,9,9,10,10,11,11,12,12,13,13, | ||
14,14,15,15,8,8,9,9,10,10,11,11,12,12, | ||
13,13,14,14,15,15} | ||
}; | ||
// access_s | ||
const int aux2[2][32] = { | ||
{0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,2, | ||
3,2,3,2,3,2,3,2,3,2,3,2,3,2,3}, | ||
{4,5,4,5,4,5,4,5,4,5,4, | ||
5,4,5,4,5,6,7,6,7,6,7,6,7, | ||
6,7,6,7,6,7,6,7} | ||
}; | ||
|
||
const int tid[2][32] = { | ||
{0,1,4,5,8,9,12,13,16,17,20,21,24,25,28,29, | ||
0,1,4,5,8,9,12,13,16,17,20,21,24,25,28,29}, | ||
{2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31, | ||
2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31} | ||
}; | ||
|
||
|
||
|
||
void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); | ||
void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst); | ||
|
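The `.cu` file implementing the kernels is not shown in this excerpt, but the pairing of the host-side tables (`aux`, `aux2`, `tid`) with same-shaped `__constant__` arrays suggests they are uploaded once at initialization. A minimal sketch of what that upload would look like — the function name is hypothetical, and the actual call site in the PR is not shown here:

```c
// Hypothetical init helper: copy the host lookup tables into the
// __constant__ arrays declared above. cudaMemcpyToSymbol writes into the
// constant memory of the currently active device.
static void winograd_upload_tables(void) {
    cudaMemcpyToSymbol(access_f_s, aux,  sizeof(aux));
    cudaMemcpyToSymbol(access_s,   aux2, sizeof(aux2));
    cudaMemcpyToSymbol(tileid,     tid,  sizeof(tid));
}
```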
@@ -2995,6 +2995,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
+    "WINOGRAD_STAGE0",
+    "WINOGRAD_STAGE1",
 
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
 
@@ -3024,7 +3026,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3089,6 +3091,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
+    "winograd_stage0(x)",
+    "winograd_stage1(x)",
 
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
 
@@ -3118,7 +3122,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -7166,6 +7170,73 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0(
     return result;
 }
 
+// ggml_winograd
+
+// a: [OC, IC, 3, 3]
+// result: [OC, IC, 16]
+struct ggml_tensor * ggml_winograd_stage0(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
Review comment on lines +7181 to +7185:

> If #966 is merged first this will need to be removed (should be very straightforward).

Reply:

> Look forward to it...
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[3], 4, 4, a->ne[2]);
+
+    result->op   = GGML_OP_WINOGRAD_STAGE0;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_winograd
+// a: [OC, IC, 4, 4]
+// b: [1, IC, IH, IW]
+// result: [N, OC, OH, OW]
+struct ggml_tensor * ggml_winograd_stage1(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    int OW = b->ne[0];
+    int OH = b->ne[1];
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, OW, OH, a->ne[0] /* OC */, 1);
+
+    result->op   = GGML_OP_WINOGRAD_STAGE1;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_conv_2d_3x3(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(a->ne[0] == 3 && a->ne[1] == 3); // kernel must be 3x3
+    GGML_ASSERT(b->ne[3] == 1);                  // only works for a single input image
+    GGML_ASSERT(b->ne[2] == a->ne[2]);           // number of channels must match
+    if (a->ne[3] % 64 != 0 || a->ne[2] % 8 != 0) {
+        // fall back unless the number of filters is a multiple of 64
+        // and the number of channels is a multiple of 8
+        return ggml_conv_2d(ctx, a, b, 1, 1, 1, 1, 1, 1);
+    }
+
+    // struct ggml_tensor * ra = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 3, 0)); // [N, OC, OH, OW]
+    struct ggml_tensor * W      = ggml_winograd_stage0(ctx, a);
+    struct ggml_tensor * result = ggml_winograd_stage1(ctx, W, b);
+
+    return result;
+}
 
 // ggml_pool_*
 
 static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
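For context: the 4x4 tiles produced by stage 0 match the standard Winograd F(2x2, 3x3) filter transform U = G g Gᵀ, so that is presumably what the CUDA kernel computes per (OC, IC) filter slice. A minimal CPU sketch of that transform — the function and variable names here are illustrative, not taken from the PR:

```c
// Winograd F(2x2,3x3) filter transform: U = G * g * G^T, mapping a 3x3
// kernel g to a 4x4 tile U. This is the textbook transform, shown for
// reference; it is not the PR's actual kernel code.
static void winograd_filter_transform_3x3(const float g[3][3], float U[4][4]) {
    static const float G[4][3] = {
        { 1.0f,  0.0f, 0.0f },
        { 0.5f,  0.5f, 0.5f },
        { 0.5f, -0.5f, 0.5f },
        { 0.0f,  0.0f, 1.0f },
    };
    float Gg[4][3]; // intermediate product G * g
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 3; j++) {
            Gg[i][j] = G[i][0]*g[0][j] + G[i][1]*g[1][j] + G[i][2]*g[2][j];
        }
    }
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) { // U = (G * g) * G^T
            U[i][j] = Gg[i][0]*G[j][0] + Gg[i][1]*G[j][1] + Gg[i][2]*G[j][2];
        }
    }
}
```

Stage 1 then presumably covers the rest of the F(2x2, 3x3) pipeline in one op: the input transform Bᵀ d B per tile, the elementwise multiply with U accumulated over channels, and the output transform Aᵀ m A back to 2x2 output pixels. From the caller's perspective, `ggml_conv_2d_3x3(ctx, kernel, image)` is a drop-in for the 3x3/stride-1/pad-1 case of `ggml_conv_2d`.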
@@ -15124,6 +15195,23 @@ static void ggml_compute_forward_conv_transpose_1d(
     }
 }
 
+static void ggml_compute_forward_winograd_stage0(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(false && "CPU backend not implemented!");
+    return;
+}
+
+static void ggml_compute_forward_winograd_stage1(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(false && "CPU backend not implemented!");
+    return;
+}
Review comment on lines +15199 to +15213:

> If at all possible a CPU implementation should always be done, since it serves both as a fallback and as a reference implementation to test other backends against.

Reply:

> A CPU backend should be done, but I am not sure of the benefit compared to the current im2col+gemm version.
 
 // ggml_compute_forward_im2col_f32
 // src0: kernel [OC, IC, KH, KW]
 // src1: image  [N, IC, IH, IW]
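Following up on the review point above: for testing purposes the CPU reference does not have to be a Winograd implementation at all; a naive direct convolution suffices to validate the CUDA results against. A sketch with illustrative names, matching the `ggml_conv_2d_3x3` contract (3x3 kernel, stride 1, padding 1, one image):

```c
// Naive direct 3x3 convolution, usable as a CPU reference to check the
// Winograd path against. Illustrative sketch, not code from the PR.
static void conv2d_3x3_ref(const float * kernel, // [OC][IC][3][3]
                           const float * input,  // [IC][IH][IW]
                           float       * output, // [OC][IH][IW]
                           int OC, int IC, int IH, int IW) {
    for (int oc = 0; oc < OC; oc++) {
        for (int y = 0; y < IH; y++) {
            for (int x = 0; x < IW; x++) {
                float acc = 0.0f;
                for (int ic = 0; ic < IC; ic++) {
                    for (int ky = 0; ky < 3; ky++) {
                        for (int kx = 0; kx < 3; kx++) {
                            int iy = y + ky - 1; // padding = 1
                            int ix = x + kx - 1;
                            if (iy < 0 || iy >= IH || ix < 0 || ix >= IW) continue;
                            acc += kernel[((oc*IC + ic)*3 + ky)*3 + kx]
                                 * input[(ic*IH + iy)*IW + ix];
                        }
                    }
                }
                output[(oc*IH + y)*IW + x] = acc;
            }
        }
    }
}
```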
@@ -17820,6 +17908,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_transpose_1d(params, tensor);
             } break;
+        case GGML_OP_WINOGRAD_STAGE0:
+            {
+                ggml_compute_forward_winograd_stage0(params, tensor);
+            } break;
+        case GGML_OP_WINOGRAD_STAGE1:
+            {
+                ggml_compute_forward_winograd_stage1(params, tensor);
+            } break;
         case GGML_OP_IM2COL:
             {
                 ggml_compute_forward_im2col(params, tensor);
@@ -18893,6 +18989,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
+        case GGML_OP_WINOGRAD_STAGE0:
+            {
+                GGML_ABORT("fatal error"); // TODO: not implemented
+            }
+        case GGML_OP_WINOGRAD_STAGE1:
+            {
+                GGML_ABORT("fatal error"); // TODO: not implemented
+            }
         case GGML_OP_POOL_1D:
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
Review comment:

> What happens in the case of multiple GPUs? Is the constant memory duplicated across GPUs?

Reply:

> I am pretty ignorant about multi-GPU. I guess they will be duplicated; I don't have a setup to test. Plus, this kernel only works on a single GPU, I think.
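For what it's worth, `__constant__` variables in CUDA are instantiated per device: each GPU holds its own copy of the symbol, and `cudaMemcpyToSymbol` only writes to the currently active device. So if this op were ever run under a multi-GPU split, the tables would need to be uploaded once per device, along the lines of this sketch (names follow the header above; the loop itself is illustrative):

```c
// Per-device upload: cudaMemcpyToSymbol affects only the device selected
// by cudaSetDevice, so loop over all devices at init time.
static void winograd_upload_tables_all_devices(void) {
    int n_devices = 0;
    cudaGetDeviceCount(&n_devices);
    for (int id = 0; id < n_devices; id++) {
        cudaSetDevice(id);
        cudaMemcpyToSymbol(access_f_s, aux,  sizeof(aux));
        cudaMemcpyToSymbol(access_s,   aux2, sizeof(aux2));
        cudaMemcpyToSymbol(tileid,     tid,  sizeof(tid));
    }
}
```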