From a1c004ef2e056cdeffcd47aaac196883bb123a3a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 17:42:55 +0200 Subject: [PATCH 001/121] ggml : add ggml_flash_attn_ext API --- ggml-metal.m | 50 +++++++ ggml-metal.metal | 29 ++++ ggml.c | 298 ++++++++++++++++++++++++++++++++++++- ggml.h | 9 ++ llama.cpp | 80 +++++----- tests/test-backend-ops.cpp | 28 ++++ 6 files changed, 456 insertions(+), 38 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 912ddc83f7d9c..6d88d5c36a8ad 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,6 +147,7 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -511,6 +512,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -665,6 +667,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_PAD: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -2161,6 +2164,53 @@ static bool ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + size_t offs_src2 = 0; + size_t offs_src3 = 0; + + id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + id id_src3 = src3 ? 
ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + + // TODO: extend if necessary + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml-metal.metal b/ggml-metal.metal index 029578dc54dbd..b79a1ba5634a7 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,35 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? 
src0[tpig] : src0[tpig] * slope; } +kernel void kernel_flash_attn_ext_f16( + device const half * q, + device const half * k, + device const half * v, + device const half * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & scale, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // TODO: implement +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml.c b/ggml.c index cbf2d4bddddb8..e01d938ceb681 100644 --- a/ggml.c +++ b/ggml.c @@ -1650,6 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "WIN_PART", @@ -1674,7 +1675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1736,6 +1737,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "win_part(x)", @@ -1760,7 +1762,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5678,6 +5680,46 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13212,6 +13254,251 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (mask) { + const float * mp = (float *)((char *) mask->data + 
(ir%mask->ne[1])*mask->nb[1]); + ggml_vec_acc_f32(M, S, mp); + } + + // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -14717,6 +15004,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); @@ -15713,6 +16004,7 @@ static void ggml_compute_backward(struct 
ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -16438,6 +16730,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -16769,6 +17062,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); diff --git a/ggml.h b/ggml.h index de8162b8135f3..d76fe9d5c48c9 100644 --- a/ggml.h +++ b/ggml.h @@ -452,6 +452,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_WIN_PART, @@ -1619,6 +1620,14 @@ extern "C" { struct ggml_tensor * v, bool masked); + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index d28382f7d47b7..cec23c23f1dce 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4205,38 +4205,6 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } - // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -4246,8 +4214,49 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + // TODO: determine if we can use flash attention + const bool supports_flash_attn = true; + + struct ggml_tensor * kqv; + + if (supports_flash_attn) { + kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add 
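                // the fallback path below reproduces attention op-by-op: kq = K^T*Q is
                // scaled, optionally ALiBi-biased, masked and softmax-ed; note that the
                // non-ALiBi branch fuses scale + mask + softmax via ggml_soft_max_ext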
+ kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } + + kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + } struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); @@ -9490,8 +9499,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, - cparams.n_ctx, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 55ce14e0d902c..5693c2197c7c5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1384,6 +1384,32 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_FLASH_ATTN_EXT +struct test_flash_attn_ext : public test_case { + const ggml_type typeq; + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nt; // tokens + + std::string vars() override { + return VARS_TO_STR5(typeq, hs, nh, kv, nt); + } + + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + return out; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1650,6 +1676,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024)); From fa7ebcca993ec0d47f6ed6a47a8d5ac4f7407262 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jan 2024 20:06:26 +0200 Subject: [PATCH 002/121] ggml : fix GQA support in ggml_flash_attn_ext --- ggml-metal.metal | 8 ++++---- ggml.c | 23 +++++++++++++++-------- llama.cpp | 4 ++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b79a1ba5634a7..28847794cb5d8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const half * mask, + 
device const half * q, + device const half * k, + device const half * v, + device const float * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, diff --git a/ggml.c b/ggml.c index e01d938ceb681..9cf4784ce4759 100644 --- a/ggml.c +++ b/ggml.c @@ -13307,6 +13307,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + if (params->type == GGML_TASK_INIT) { return; } @@ -13347,8 +13354,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { for (int64_t ic = 0; ic < nek1; ++ic) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13362,8 +13369,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( } else { for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13452,8 +13459,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16(nev0, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), @@ -13468,8 +13475,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16_unroll(nev0, nbv1, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), diff --git a/llama.cpp b/llama.cpp index cec23c23f1dce..d4bebe5203e9a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4220,6 +4220,10 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv; if (supports_flash_attn) { + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); From a9681febd65cbd3f372badc5f4a4d8bc1336d2d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 12:26:49 +0200 Subject: [PATCH 003/121] ggml : online attention (CPU) --- ggml-metal.m | 8 +- ggml-metal.metal | 3 +- ggml.c | 249 ++++++++++++++++++------------------- ggml.h | 5 + llama.cpp | 124 ++++++++++-------- tests/test-backend-ops.cpp | 14 +-- 6 files changed, 218 insertions(+), 185 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6d88d5c36a8ad..4d85dd3ddb319 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2207,9 +2207,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + const int nwarps = 4; + + // each warp needs n_embd_head elements + GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; + 
const int nth = MIN(1024, ne0); - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 28847794cb5d8..a1e1755a3a605 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1981,7 +1981,8 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & scale, + constant float & scale, + threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { diff --git a/ggml.c b/ggml.c index 9cf4784ce4759..e64a328fadb1f 100644 --- a/ggml.c +++ b/ggml.c @@ -817,7 +817,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; 
} inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); float params[] = { scale }; ggml_set_op_params(result, params, sizeof(params)); @@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); + GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); @@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(neq0 == D); GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + float S = 0.0f; + float M = -INFINITY; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + memset(V16, 0, D*sizeof(ggml_fp16_t)); - // S indices - const int i1 = ik1; + const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; - // S indices - const int i1 = ik1; + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? 
mp[ic] : 0.0f; + if (mv == -INFINITY) { + continue; } - } - // scale - ggml_vec_scale_f32(nek1, S, scale); + float s; - if (mask) { - const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); - ggml_vec_acc_f32(M, S, mp); - } + ggml_vec_dot_f16(D, + &s, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. - // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + s = s*scale + mv; - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + const float Mold = M; - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; + float ms = 1.0f; + float vs = 1.0f; - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); - sump[j] += (ggml_float)val; - SS[j] = val; - } - } - } + if (s > M) { + M = s; + ms = expf(Mold - M); - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); } - assert(sum > 0.0); + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif + S = S*ms + vs; } - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; } - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). 
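        // the new online update in this hunk follows the streaming softmax from
        // the reference above (https://arxiv.org/pdf/2112.05682.pdf); per score s:
        //
        //   M' = max(M, s)
        //   S' = S*exp(M - M') + exp(s - M')
        //   V' = V*exp(M - M') + v*exp(s - M')
        //
        // so V/S equals softmax(scores)*V without materializing the full score row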
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; - - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); } } @@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: - case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); @@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index d76fe9d5c48c9..7bca02f2a2c48 100644 --- a/ggml.h +++ b/ggml.h @@ -1620,6 +1620,11 @@ extern "C" { struct ggml_tensor * v, bool masked); + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch, 1, 1] + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index f0a63afef0087..4e6c9f9cc75ea 100644 --- a/llama.cpp +++ b/llama.cpp @@ -95,6 +95,8 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 +#define LLAMA_FLASH_ATTN + // // logging // @@ -4167,23 +4169,34 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - // compute the transposed [n_tokens, n_embd] V matrix - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + +#if defined(LLAMA_FLASH_ATTN) + // NOTE: the V cache is not transposed when using FLASH attention !! 
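    // keeping V row-major (one contiguous row per KV position) lets the
    // flash-attention kernels accumulate V rows directly (ggml_vec_mad_f16 on
    // the CPU path), whereas the op-by-op path needs the transposed layout
    // for the final V*softmax(KQ) matmul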
+ struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + + GGML_UNUSED(n_ctx); +#else + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +#endif } static struct ggml_tensor * llm_build_norm( @@ -4343,68 +4356,77 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - // split cached v into n_head heads + struct ggml_tensor * cur; + +#if defined(LLAMA_FLASH_ATTN) + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), 0); cb(v, "v", il); - // TODO: determine if we can use flash attention - const bool supports_flash_attn = true; + cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); - struct ggml_tensor * kqv; + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); +#else + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); - if (supports_flash_attn) { - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); - } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + if (max_alibi_bias > 0.0f) { + // temporary branch until we 
figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } + // split cached v into n_head heads (transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - } + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); +#endif cur = ggml_mul_mat(ctx, wo, cur); if (wo_b) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5693c2197c7c5..a56c0d6c59a64 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case { const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size - const int64_t nt; // tokens + const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nt); + return VARS_TO_STR5(typeq, hs, nh, kv, nb); } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } From 1173f49c3bbe30810af4aeb77219eba7e05f658d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 17:32:28 +0200 
Subject: [PATCH 004/121] metal : initial implementation --- ggml-metal.m | 75 +++++++++++++------- ggml-metal.metal | 138 ++++++++++++++++++++++++++++++++++--- ggml.c | 2 +- tests/test-backend-ops.cpp | 4 ++ 4 files changed, 183 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4d85dd3ddb319..556c53482a75e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -278,6 +278,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } else { GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); @@ -316,13 +320,12 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ //[options setFastMathEnabled:false]; ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } } - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } } // print MTL GPU family: @@ -396,6 +399,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ @@ -2171,12 +2177,28 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + size_t offs_src2 = 0; size_t offs_src3 = 0; - id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + GGML_ASSERT(src2); + id id_src2 = ggml_metal_get_buffer(ctx, src2, &offs_src2); + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? 
src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + float scale; memcpy(&scale, dst->op_params, sizeof(float)); @@ -2197,25 +2219,28 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&scale length:sizeof( float) atIndex:21]; - - const int nwarps = 4; - - // each warp needs n_embd_head elements - GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + const int nwarps = 1; + + GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index a1e1755a3a605..5986bcb427f4b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const float * mask, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1973,20 +1973,138 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, constant int64_t & ne0, 
constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, constant float & scale, threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - // TODO: implement + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + + if (iq1 >= ne01) { + return; + } + + const int64_t D = ne00; + + // TODO: can we move this to the stack? + threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + + // initialize with zeros + for (int64_t d = 0; d < D; ++d) { + V16[d] = 0.0h; + } + + threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + + half S = 0.0h; + half M = -INFINITY; + + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; + + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; + + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; + + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; + + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; + + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + // load Q to shared memory + for (int64_t d = 0; d < D; ++d) { + pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + } + + for (int64_t ic = 0; ic < ne11; ++ic) { + const half mv = mp ? 
mp[ic] : 0.0h; + if (mv == -INFINITY) { + continue; + } + + half s = 0.0f; + + //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t d = 0; d < D; ++d) { + s += pk[d] * pq[d]; + } + + s = s*scale + mv; + + const half Mold = M; + + half ms = 1.0f; + half vs = 1.0f; + + if (s > M) { + M = s; + ms = exp(Mold - M); + + // V = V*exp(Mold - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] *= ms; + } + } else { + vs = exp(s - M); + } + + device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + + // V += v*exp(s - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] += pv[d] * vs; + } + + S = S*ms + vs; + } + + for (int64_t d = 0; d < D; ++d) { + V16[d] /= S; + } + + // dst indices + const int64_t i1 = iq1; + const int64_t i2 = iq2; + const int64_t i3 = iq3; + + for (int64_t d = 0; d < D; ++d) { + dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + } } kernel void kernel_cpy_f16_f16( diff --git a/ggml.c b/ggml.c index e64a328fadb1f..10df03c9c619b 100644 --- a/ggml.c +++ b/ggml.c @@ -13419,8 +13419,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ik2 = iq2 / rk2; // v indices - const int iv2 = iq2 / rv2; const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; // online softmax / attention // loop over n_kv and n_head_kv diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a56c0d6c59a64..51a33c662da56 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1396,6 +1396,10 @@ struct test_flash_attn_ext : public test_case { return VARS_TO_STR5(typeq, hs, nh, kv, nb); } + double max_nmse_err() override { + return 5e-4; + } + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} From 528da7515ef874ab1188ab8f691c36d3e9e0cb20 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:13:24 +0200 Subject: [PATCH 005/121] metal : f16 precision --- ggml-metal.m | 6 ++++-- ggml-metal.metal | 40 ++++++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 556c53482a75e..e67a7c4ef892b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2237,8 +2237,10 @@ static bool ggml_metal_graph_compute( const int nwarps = 1; - GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + const size_t shalf = sizeof(float)/2; + + GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5986bcb427f4b..e4e89b5b3f7bf 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1988,7 +1988,7 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne2, constant int64_t & ne3, constant float & scale, - threadgroup float * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]], @@ -2003,16 +2003,17 @@ kernel void 
kernel_flash_attn_ext_f16( } const int64_t D = ne00; + const int64_t D4 = D/4; // TODO: can we move this to the stack? - threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); // initialize with zeros - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] = 0.0h; } - threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); half S = 0.0h; half M = -INFINITY; @@ -2045,8 +2046,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load Q to shared memory - for (int64_t d = 0; d < D; ++d) { - pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + for (int64_t d = 0; d < D4; ++d) { + pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; } for (int64_t ic = 0; ic < ne11; ++ic) { @@ -2055,15 +2056,16 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half s = 0.0f; + half4 s4 = 0.0f; - //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t d = 0; d < D; ++d) { - s += pk[d] * pq[d]; + for (int64_t d = 0; d < D4; ++d) { + s4 += pk4[d] * pq4[d]; } + half s = s4.x + s4.y + s4.z + s4.w; + s = s*scale + mv; const half Mold = M; @@ -2076,24 +2078,24 @@ kernel void kernel_flash_attn_ext_f16( ms = exp(Mold - M); // V = V*exp(Mold - M) - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] *= ms; } } else { vs = exp(s - M); } - device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); // V += v*exp(s - M) - for (int64_t d = 0; d < D; ++d) { - V16[d] += pv[d] * vs; + for (int64_t d = 0; d < D4; ++d) { + V16[d] += pv4[d] * vs; } S = S*ms + vs; } - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] /= S; } @@ -2102,8 +2104,10 @@ kernel void kernel_flash_attn_ext_f16( const int64_t i2 = iq2; const int64_t i3 = iq3; - for (int64_t d = 0; d < D; ++d) { - dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + device float4 * dst4 = (device float4 *) dst; + + for (int64_t d = 0; d < D4; ++d) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; } } From 52ae085750afd37affc4ed18fe092d92c9ccdc5f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:38:17 +0200 Subject: [PATCH 006/121] metal : reduce branches --- ggml-metal.metal | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e4e89b5b3f7bf..f3a7efafa6613 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2056,40 +2056,26 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half4 s4 = 0.0f; + device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - device const half4 * pk4 = (device const half4 *) 
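// ---------------------------------------------------------------------------
// The half4 rewrite above performs four multiply-adds per loop iteration and
// a single horizontal add at the end. The same shape of computation with a
// plain 4-lane struct in C++ (f4 is a stand-in for Metal's native half4):

struct f4 { float x, y, z, w; };

// dot product over D4 = D/4 vector elements
static float dot_vec4(const f4 *q4, const f4 *k4, int D4) {
    f4 s4 = {0.0f, 0.0f, 0.0f, 0.0f};
    for (int i = 0; i < D4; ++i) {
        s4.x += q4[i].x*k4[i].x;
        s4.y += q4[i].y*k4[i].y;
        s4.z += q4[i].z*k4[i].z;
        s4.w += q4[i].w*k4[i].w;
    }
    return s4.x + s4.y + s4.z + s4.w; // horizontal reduction
}
// ---------------------------------------------------------------------------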
((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + half4 s4 = 0.0h; for (int64_t d = 0; d < D4; ++d) { s4 += pk4[d] * pq4[d]; } - half s = s4.x + s4.y + s4.z + s4.w; - - s = s*scale + mv; + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; const half Mold = M; - half ms = 1.0f; - half vs = 1.0f; - - if (s > M) { - M = s; - ms = exp(Mold - M); - - // V = V*exp(Mold - M) - for (int64_t d = 0; d < D4; ++d) { - V16[d] *= ms; - } - } else { - vs = exp(s - M); - } + M = max(M, s); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - // V += v*exp(s - M) for (int64_t d = 0; d < D4; ++d) { - V16[d] += pv4[d] * vs; + V16[d] = V16[d]*ms + pv4[d]*vs; } S = S*ms + vs; From b97325800a7727244e737715fa7b5e2bc41afb21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:01:55 +0200 Subject: [PATCH 007/121] metal : specialize for head size --- ggml-metal.m | 259 +++++++++++++++++++++++++---------------------- ggml-metal.metal | 42 +++++++- 2 files changed, 179 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e67a7c4ef892b..046643146b3f3 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,7 +147,9 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -412,125 +414,127 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - 
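// ---------------------------------------------------------------------------
// The "reduce branches" patch above relies on exp(0) == 1: after
// M = max(M, s), exactly one of exp(Mold - M) and exp(s - M) collapses to 1,
// so the two branches of the earlier version fold into straight-line code.
// A scalar C++ sketch of the branch-free step (hypothetical helper):

#include <cmath>

static void softmax_step(int D, float s, float *M, float *S,
                         float *acc, const float *v) {
    const float Mold = *M;
    *M = fmaxf(Mold, s);
    const float ms = expf(Mold - *M); // == 1 when s <= Mold
    const float vs = expf(s    - *M); // == 1 when s >  Mold
    for (int d = 0; d < D; ++d) {
        acc[d] = acc[d]*ms + v[d]*vs; // matches V16[d] = V16[d]*ms + pv4[d]*vs
    }
    *S = *S*ms + vs;
}
// ---------------------------------------------------------------------------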
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, 
mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, 
ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, 
flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } return ctx; @@ -2172,6 +2176,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_FLASH_ATTN_EXT: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(src0->type == GGML_TYPE_F16); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; @@ -2202,7 +2207,19 @@ static bool ggml_metal_graph_compute( float scale; memcpy(&scale, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + id pipeline = nil; + + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } // TODO: extend if necessary [encoder setComputePipelineState:pipeline]; diff --git a/ggml-metal.metal b/ggml-metal.metal index f3a7efafa6613..d97952f2b0871 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,43 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? 
src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]); + +template // head size kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2002,7 +2039,6 @@ kernel void kernel_flash_attn_ext_f16( return; } - const int64_t D = ne00; const int64_t D4 = D/4; // TODO: can we move this to the stack? @@ -2097,6 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( } } +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From 8cde449b8be4e481db2a8790d9320c743b3ed65e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:23:22 +0200 Subject: [PATCH 008/121] wip : 8 rows per simd group --- ggml-metal.m | 10 +-- ggml-metal.metal | 173 ++++++++++++++++++++++++++++++++++++----------- 2 files changed, 139 insertions(+), 44 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 046643146b3f3..0b1119c4eb467 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 1; + const int64_t nwarps = 2; - const size_t shalf = sizeof(float)/2; + const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); - GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index d97952f2b0871..789b19bad6b93 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - 
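// ---------------------------------------------------------------------------
// The head-size specialization above compiles one kernel per supported D so
// the loop bounds become compile-time constants, and the host picks a
// pipeline with a switch on ne00. The same pattern in plain C++ (placeholder
// body, illustrative names only):

#include <cstdint>
#include <cstdio>

template <int64_t D> // head size fixed at compile time, loops can unroll
static void flash_attn_ref(const float *q, float *out) {
    for (int64_t d = 0; d < D; ++d) {
        out[d] = q[d]; // placeholder for the real per-head computation
    }
}

static void dispatch(int64_t ne00, const float *q, float *out) {
    switch (ne00) {
        case  64: flash_attn_ref< 64>(q, out); break;
        case  80: flash_attn_ref< 80>(q, out); break;
        case 128: flash_attn_ref<128>(q, out); break;
        default:
            fprintf(stderr, "unsupported head size: %lld\n", (long long) ne00);
            break;
    }
}
// ---------------------------------------------------------------------------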
const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - if (iq1 >= ne01) { - return; - } + //const int64_t iq3 = tgpig[2]; + //const int64_t iq2 = tgpig[1]; + //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - const int64_t D4 = D/4; + const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - // TODO: can we move this to the stack? - threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq1 = tgpig[0]; - // initialize with zeros - for (int64_t d = 0; d < D4; ++d) { - V16[d] = 0.0h; + if (iq2 >= ne02) { + return; } - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); - - half S = 0.0h; - half M = -INFINITY; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; - // assume K and V are same shape const int64_t ne22 = ne12; const int64_t ne23 = ne13; @@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; - // load Q to shared memory - for (int64_t d = 0; d < D4; ++d) { - pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + +// const int64_t D4 = D/4; +// +// // TODO: can we move this to the stack? +// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); +// +// // initialize with zeros +// for (int64_t d = 0; d < D4; ++d) { +// +// } +// +// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); +// +// // load Q to shared memory +// for (int64_t d = 0; d < D4; ++d) { +// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; +// } +// +// half S = 0.0h; +// half M = -INFINITY; +// +// for (int64_t ic = 0; ic < ne11; ++ic) { +// const half mv = mp ? 
mp[ic] : 0.0h; +// if (mv == -INFINITY) { +// continue; +// } +// +// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); +// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); +// +// half4 s4 = 0.0h; +// +// for (int64_t d = 0; d < D4; ++d) { +// s4 += pk4[d] * pq4[d]; +// } +// +// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; +// +// const half Mold = M; +// +// M = max(M, s); +// +// const half ms = exp(Mold - M); +// const half vs = exp(s - M); +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] = V16[d]*ms + pv4[d]*vs; +// } +// +// S = S*ms + vs; +// } +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] /= S; +// } +// +// // dst indices +// const int64_t i1 = iq1; +// const int64_t i2 = iq2; +// const int64_t i3 = iq3; +// +// device float4 * dst4 = (device float4 *) dst; +// +// for (int64_t d = 0; d < D4; ++d) { +// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; +// } + + const int64_t D4 = D/4; + + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + + const uint tiih = tiisg%4; // thread index in head + const uint hiisg = tiisg/4; // head index in simdgroup + + // load 8 heads from Q to shared memory + for (int64_t i = 0; i < D4/4; ++i) { + pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; + ps4[hiisg*D4 + 4*i + tiih] = 0.0h; } + simdgroup_barrier(mem_flags::mem_threadgroup); + + half S = 0.0h; + half M = -INFINITY; + for (int64_t ic = 0; ic < ne11; ++ic) { const half mv = mp ? 
mp[ic] : 0.0h; if (mv == -INFINITY) { @@ -2097,30 +2170,52 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t d = 0; d < D4; ++d) { - s4 += pk4[d] * pq4[d]; + for (int64_t i = 0; i < D4/4; ++i) { + s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + ss4[hiisg*4 + tiih] = s4; + + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; - const half Mold = M; + const half Mold = M; - M = max(M, s); + M = max(M, s); - const half ms = exp(Mold - M); - const half vs = exp(s - M); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - for (int64_t d = 0; d < D4; ++d) { - V16[d] = V16[d]*ms + pv4[d]*vs; + S = S*ms + vs; + + ss[2*hiisg + 0] = ms; + ss[2*hiisg + 1] = vs; } - S = S*ms + vs; + simdgroup_barrier(mem_flags::mem_threadgroup); + + const half ms = ss[2*hiisg + 0]; + const half vs = ss[2*hiisg + 1]; + + for (int64_t i = 0; i < D4/4; ++i) { + ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; + } } - for (int64_t d = 0; d < D4; ++d) { - V16[d] /= S; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] /= S; + } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // dst indices const int64_t i1 = iq1; const int64_t i2 = iq2; @@ -2128,8 +2223,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t d = 0; d < D4; ++d) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; + for (int64_t i = 0; i < D4/4; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; } } From f31955f5d12da67f35aa459996a171975fdf269b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:01:28 +0200 Subject: [PATCH 009/121] wip : 4 rows per simd group --- ggml-metal.m | 6 +++--- ggml-metal.metal | 39 +++++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0b1119c4eb467..abb96d6ec6e44 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 2; + const int64_t nwarps = 4; - const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 789b19bad6b93..6fdd7fdad4326 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2038,7 +2038,7 @@ kernel void kernel_flash_attn_ext_f16( const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 
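// ---------------------------------------------------------------------------
// Rough size check for the threadgroup allocation above: per simdgroup the
// kernel reserves 4 Q rows plus 4 accumulator rows of half (2*4*ne00
// elements) and a 128-element score scratch. For one plausible shape
// (assumed values, not taken from the patch):

#include <cstdio>

int main() {
    const long ne00   = 128;             // head size (assumed)
    const long nwarps = 4;               // simdgroups per threadgroup
    const long shalf  = sizeof(float)/2; // sizeof(half) == 2 bytes
    const long smem   = nwarps*(2*4*ne00 + 128)*shalf;
    printf("smem = %ld bytes\n", smem);  // 9216, well under a 32 KiB limit
    return 0;
}
// ---------------------------------------------------------------------------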
+2140,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - const uint tiih = tiisg%4; // thread index in head - const uint hiisg = tiisg/4; // head index in simdgroup + const uint tiih = tiisg%8; // thread index in head + const uint hiisg = tiisg/8; // head index in simdgroup // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/4; ++i) { - pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; - ps4[hiisg*D4 + 4*i + tiih] = 0.0h; + for (int64_t i = 0; i < D4/8; ++i) { + pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; + ps4[hiisg*D4 + 8*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,16 +2170,18 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t i = 0; i < D4/4; ++i) { - s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; } - ss4[hiisg*4 + tiih] = s4; + ss4[hiisg*8 + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + + ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2201,8 +2203,9 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; - for (int64_t i = 0; i < D4/4; ++i) { - ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; } } @@ -2223,8 +2226,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/4; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; + for (int64_t i = 0; i < D4/8; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; } } From a4b6341c7b2a1977c29e79b17a0e5de3e31a5420 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:24:13 +0200 Subject: [PATCH 010/121] wip : template for rows per warp --- ggml-metal.m | 7 ++++--- ggml-metal.metal | 54 +++++++++++++++++++++++++----------------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index abb96d6ec6e44..d521df43ab302 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder 
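// ---------------------------------------------------------------------------
// The dispatch a few lines below covers the ne02 heads with a fixed number of
// heads per threadgroup using the usual ceil-division idiom; a small C++
// sketch with assumed numbers:

static long ceil_div(long n, long per_group) {
    return (n + per_group - 1)/per_group;
}
// e.g. ne02 = 32 heads, 4 heads per warp, nwarps = 4:
//      ceil_div(32, 4*4) == 2 threadgroups along grid dimension 1
// ---------------------------------------------------------------------------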
setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 4; + const int64_t nwarps = 8; + const int64_t nhpw = 4; // heads per warp - const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 6fdd7fdad4326..c9876c1033f1f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size +template // head size, rows per warp kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2036,9 +2036,10 @@ kernel void kernel_flash_attn_ext_f16( //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; + const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2141,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - const uint tiih = tiisg%8; // thread index in head - const uint hiisg = tiisg/8; // head index in simdgroup + const uint tiih = tiisg%tph; // thread index in head + const uint hiisg = tiisg/tph; // head index in simdgroup - // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/8; ++i) { - pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; - ps4[hiisg*D4 + 8*i + tiih] = 0.0h; + // load R heads from Q to shared memory + for (int64_t i = 0; i < D4/tph; ++i) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,18 +2171,20 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*8 + tiih] = s4; + ss4[hiisg*tph + tiih] = s4; 
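// ---------------------------------------------------------------------------
// In the templated variant above, tph = N_SIMDWIDTH/R lanes cooperate on one
// head: each lane accumulates a strided slice of the Q.K dot product into the
// ss scratch and lane 0 (tiih == 0) sums the partials. A serial C++ sketch of
// that split (the lane loop runs concurrently on the GPU; buffers are
// hypothetical):

static float strided_dot(const float *q, const float *k, int D4, int tph) {
    float partial[32] = {0.0f}; // one slot per lane, tph <= 32
    for (int lane = 0; lane < tph; ++lane) {
        for (int i = lane; i < D4; i += tph) { // same index set as tph*i + tiih
            partial[lane] += q[i]*k[i];
        }
    }
    float s = 0.0f; // the reduction lane 0 performs after the barrier
    for (int lane = 0; lane < tph; ++lane) {
        s += partial[lane];
    }
    return s;
}
// ---------------------------------------------------------------------------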
simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + - ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; + s4 = 0.0h; + + for (int64_t i = 0; i < tph; ++i) { + s4 += ss4[hiisg*tph + i]; + } half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2203,9 +2206,8 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } } @@ -2226,14 +2228,14 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/8; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; kernel void kernel_cpy_f16_f16( device const half * src0, From 77d08f3272c62900b40d110bf0de7f4466675c71 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 21:04:15 +0200 Subject: [PATCH 011/121] metal : parallelize across KV size --- ggml-metal.m | 8 +-- ggml-metal.metal | 137 +++++++++++++++++------------------------------ 2 files changed, 52 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d521df43ab302..a60dd779a6f09 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,15 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 8; - const int64_t nhpw = 4; // heads per warp + const int64_t nwarps = 16; + const int64_t nhptg = 4; // heads per threadgroup - const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); + const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index c9876c1033f1f..539e26c91c34a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint 
sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per warp +template // head size, rows per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,15 +2031,11 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - //const int64_t iq3 = tgpig[2]; - //const int64_t iq2 = tgpig[1]; - //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; + const int64_t iq2 = tgpig[1]*R + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2073,94 +2069,30 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; -// const int64_t D4 = D/4; -// -// // TODO: can we move this to the stack? -// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); -// -// // initialize with zeros -// for (int64_t d = 0; d < D4; ++d) { -// -// } -// -// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); -// -// // load Q to shared memory -// for (int64_t d = 0; d < D4; ++d) { -// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; -// } -// -// half S = 0.0h; -// half M = -INFINITY; -// -// for (int64_t ic = 0; ic < ne11; ++ic) { -// const half mv = mp ? mp[ic] : 0.0h; -// if (mv == -INFINITY) { -// continue; -// } -// -// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); -// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); -// -// half4 s4 = 0.0h; -// -// for (int64_t d = 0; d < D4; ++d) { -// s4 += pk4[d] * pq4[d]; -// } -// -// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; -// -// const half Mold = M; -// -// M = max(M, s); -// -// const half ms = exp(Mold - M); -// const half vs = exp(s - M); -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] = V16[d]*ms + pv4[d]*vs; -// } -// -// S = S*ms + vs; -// } -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] /= S; -// } -// -// // dst indices -// const int64_t i1 = iq1; -// const int64_t i2 = iq2; -// const int64_t i3 = iq3; -// -// device float4 * dst4 = (device float4 *) dst; -// -// for (int64_t d = 0; d < D4; ++d) { -// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; -// } - const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup // load R heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { - pq4[hiisg*D4 + tph*i + tiih] = ((device 
const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + if (sgitg == 0) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + } + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } - simdgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); half S = 0.0h; half M = -INFINITY; - for (int64_t ic = 0; ic < ne11; ++ic) { + for (int64_t ic = sgitg; ic < ne11; ic += nsg) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { continue; @@ -2175,18 +2107,18 @@ kernel void kernel_flash_attn_ext_f16( s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*tph + tiih] = s4; + ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = 0.0h; + half s = 0.0h; for (int64_t i = 0; i < tph; ++i) { - s4 += ss4[hiisg*tph + i]; + s += ss[hiisg*tph + i]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + s = s*scale + mv; const half Mold = M; @@ -2211,9 +2143,34 @@ kernel void kernel_flash_attn_ext_f16( } } - simdgroup_barrier(mem_flags::mem_threadgroup); - if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // reduce the warps + if (sgitg == 0 && tiih == 0) { + for (int64_t sg = 1; sg < nsg; ++sg) { + const half S0 = S; + const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + + const half M0 = M; + const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + } + } + for (int64_t i = 0; i < D4; ++i) { ps4[hiisg*D4 + i] /= S; } @@ -2228,8 +2185,10 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + if (sgitg == 0) { + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + } } } From 17720fad669eed6171ddf17184da5bab50adeb72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 22:44:41 +0200 Subject: [PATCH 012/121] metal : parallel reduce across heads --- ggml-metal.m | 4 ++-- ggml-metal.metal | 32 ++++++++++++++++++++------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a60dd779a6f09..fdfb50d3d03f4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 16; - const int64_t nhptg = 4; // heads per threadgroup + const int64_t nwarps = 32; + const int64_t nhptg = 2; // heads per threadgroup const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index 539e26c91c34a..919119c8d55af 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } @@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16( if (tiih == 0) { half s = 
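// ---------------------------------------------------------------------------
// Once the KV range is split across simdgroups, every warp ends with partial
// (S, M) statistics and a partial accumulator, and the reduction above merges
// them pairwise: rescale both sides to the common maximum, then add. A scalar
// C++ sketch of that merge rule (one accumulator element shown; the kernel
// keeps D of them in the ps4 rows):

#include <cmath>

struct softmax_state {
    float S; // partial denominator
    float M; // partial max
    float V; // one element of the partial output accumulator
};

static softmax_state merge(softmax_state a, softmax_state b) {
    softmax_state r;
    r.M = fmaxf(a.M, b.M);
    const float ms0 = expf(a.M - r.M);
    const float ms1 = expf(b.M - r.M);
    r.S = a.S*ms0 + b.S*ms1;
    r.V = a.V*ms0 + b.V*ms1;
    return r;
}
// ---------------------------------------------------------------------------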
0.0h; +#pragma unroll for (int64_t i = 0; i < tph; ++i) { s += ss[hiisg*tph + i]; } s = s*scale + mv; - const half Mold = M; + const half m = M; M = max(M, s); - const half ms = exp(Mold - M); + const half ms = exp(m - M); const half vs = exp(s - M); S = S*ms + vs; @@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } @@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - if (sgitg == 0 && tiih == 0) { + if (sgitg == 0) { for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = S; + const half S0 = ss[ 2*hiisg + 0]; const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; - const half M0 = M; + const half M0 = ss[ 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; M = max(M0, M1); @@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16( S = S0*ms0 + S1*ms1; - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; } } - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] /= S; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; } } @@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; kernel void kernel_cpy_f16_f16( device const half * src0, From 1446a12b29f422a0c0040e62c16715a3fb7ce1cb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Jan 2024 18:27:54 +0200 Subject: [PATCH 013/121] metal : efficient flash_attn_f16 implementation --- ggml-metal.m | 14 +- ggml-metal.metal | 279 +++++++++++++++++++++++-------------- tests/test-backend-ops.cpp | 6 +- 3 files changed, 188 insertions(+), 111 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index fdfb50d3d03f4..7b161c69d5801 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2183,6 +2183,7 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src3 = gf->nodes[i]->src[3]; GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); size_t offs_src2 = 0; size_t offs_src3 = 0; @@ -2252,15 +2253,20 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 32; - const int64_t nhptg = 2; // heads per threadgroup + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 < 
4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) - const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 919119c8d55af..9b6ceec4e1066 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per threadgroup +template // head size, heads per threadgroup, queries per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,178 +2031,247 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*R + tiisg/tph; - const int64_t iq1 = tgpig[0]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*Q; if (iq2 >= ne02) { return; } - // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; + const int64_t D4 = D/4; + const int64_t N4 = N_SIMDWIDTH; + const int64_t L4 = (D4 + N4 - 1)/N4; + const int64_t D8 = D/8; + + const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half + const int64_t T4 = T/4; // shared memory size per query in half4 + + threadgroup half * pq = (threadgroup half *) (shared + 0*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); + threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + + for (int64_t i = 0; i < L4; ++i) { + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + if (iq1 + j < ne01) { + pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + } else { + pq4[j*T4 + N4*i + tiisg] = 0.0h; + } + } - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; + // zero out shared memory + for (int64_t j = 0; j < Q; ++j) { + ps4[j*T4 + N4*i + tiisg] = 0.0h; + } + } - // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; + if (tiisg < C) { + for (int64_t j = 0; j < Q; ++j) { + ss[j*T + 0 + tiisg] = 0.0h; + } + } - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; + 
threadgroup_barrier(mem_flags::mem_threadgroup); - // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; + { + half S[Q] = { 0.0h }; + half M[Q] = { -INFINITY }; - // v indices - const int64_t iv2 = iq2 / rv2; - const int64_t iv3 = iq3 / rv3; + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; - const int64_t D4 = D/4; + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; - const uint tiih = tiisg%tph; // thread index in head - const uint hiisg = tiisg/tph; // head index in simdgroup + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; - // load R heads from Q to shared memory - for (int64_t i = 0; i < D4/tph; ++i) { - if (sgitg == 0) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; - } + simdgroup_half8x8 mq[D8]; - ps4[hiisg*D4 + tph*i + tiih] = 0.0h; - } + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[i], pq + i*8, T); + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // TODO: this can be improved + device const float * mp[Q]; - half S = 0.0h; - half M = -INFINITY; + { + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - for (int64_t ic = sgitg; ic < ne11; ic += nsg) { - const half mv = mp ? mp[ic] : 0.0h; - if (mv == -INFINITY) { - continue; + for (int64_t j = 0; j < Q; ++j) { + if (iq1 + j < ne01) { + mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); + } else { + mp[j] = nullptr; + } + } } - device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { + // skip -INF blocks + // TODO: double-check this + { + float smc = -INFINITY; - half4 s4 = 0.0h; + for (int64_t j = 0; j < Q; ++j) { + const float mc = mp[j] ? 
mp[j][iic + tiisg] : -INFINITY; + smc = simd_max(max(smc, mc)); + } -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; - } + if (smc == -INFINITY) { + continue; + } + } + + // Q*K^T + { + simdgroup_half8x8 mk; - ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - simdgroup_barrier(mem_flags::mem_threadgroup); + device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - if (tiih == 0) { - half s = 0.0h; + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mk, pk + i*8, nb11/2, 0, true); -#pragma unroll - for (int64_t i = 0; i < tph; ++i) { - s += ss[hiisg*tph + i]; + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, T, 0, false); + } } - s = s*scale + mv; + // online softmax + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - const half m = M; + const half s = ss[j*T + p]*scale + (mp[j][iic + p]); - M = max(M, s); + half m = M[j]; - const half ms = exp(m - M); - const half vs = exp(s - M); + M[j] = simd_max(max(M[j], s)); - S = S*ms + vs; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - ss[2*hiisg + 0] = ms; - ss[2*hiisg + 1] = vs; - } + S[j] = S[j]*ms + simd_sum(vs); + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] *= ms; + } + + ss[j*T + p] = vs; + } + + // (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_half8x8 mp[C/8]; + simdgroup_half8x8 mqkv; - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mqkv, ps + i*8, T, 0, false); - const half ms = ss[2*hiisg + 0]; - const half vs = ss[2*hiisg + 1]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + + for (int cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; + simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + + simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + } + + simdgroup_store(mqkv, ps + i*8, T, 0, false); + } + } } - } - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + for (int64_t j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } } threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps + // TODO: try parallel reduce if (sgitg == 0) { + half S = { 0.0h }; + half M = { -INFINITY }; + for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = ss[ 2*hiisg + 0]; - const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - const half M0 = ss[ 2*hiisg + 1]; - const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - M = max(M0, M1); + M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); - S = S0*ms0 + S1*ms1; + S = S0*ms0 + S1*ms1; - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; - } + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } - for (int64_t i = 0; i < D4/tph; ++i) { - 
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } } } - - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; - } } simdgroup_barrier(mem_flags::mem_threadgroup); - // dst indices - const int64_t i1 = iq1; - const int64_t i2 = iq2; - const int64_t i3 = iq3; - device float4 * dst4 = (device float4 *) dst; if (sgitg == 0) { - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = 0; i < L4; ++i) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 51a33c662da56..41ddfcca5b687 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1397,7 +1397,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-4; + return 5e-5; } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, @@ -1680,7 +1680,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From d917746ddb053b73e868fd6e1854ac17b62bd863 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:00:49 +0200 Subject: [PATCH 014/121] metal : avoid redundant loads of the attention --- ggml-metal.metal | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9b6ceec4e1066..785a60e50eba8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } + simdgroup_barrier(mem_flags::mem_none); + // (Q*K^T)*V { simdgroup_half8x8 mv; + simdgroup_half8x8 mp[C/8]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mp[C/8]; simdgroup_half8x8 mqkv; 
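// note: the mp[] tiles hold the current P block ("the P matrix from the paper") and do
// not depend on i, so they are loaded once per KV block just above, rather than once per
// head-dimension iteration; only the accumulator mqkv remains per-i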
simdgroup_load(mqkv, ps + i*8, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } - for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); From 432ad04ffaa445a3837b92dce1c03513009ab4ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:47:52 +0200 Subject: [PATCH 015/121] metal : scale and mask in matrix form --- ggml-metal.metal | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 785a60e50eba8..ae8f5caeaa75f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,6 +2127,9 @@ kernel void kernel_flash_attn_ext_f16( } } + // prepare diagonal scale matrix + simdgroup_half8x8 mscale(scale); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { // skip -INF blocks // TODO: double-check this @@ -2153,11 +2156,16 @@ kernel void kernel_flash_attn_ext_f16( device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/2, 0, true); + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } + // mqk = mqk*scale + mask + simdgroup_float8x8 mm; + simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } @@ -2166,7 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + const half s = ss[j*T + p]; half m = M[j]; @@ -2203,7 +2212,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); } From 40ea8cd1aca61294e1987bcb1051317827f1b145 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 16:31:39 +0200 Subject: [PATCH 016/121] metal : fix comment --- ggml-metal.metal | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ae8f5caeaa75f..9ab9e16c3915a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, heads per threadgroup, queries per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; const int64_t iq2 = tgpig[1]; const int64_t iq1 = tgpig[0]*Q; - if (iq2 >= ne02) { - return; - } - const int64_t D4 = D/4; const int64_t N4 = N_SIMDWIDTH; const int64_t L4 = (D4 + N4 - 1)/N4; From 
f9ca5dcbe86a10cfa873814d5f754b7c9108f339 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:46:07 +0200 Subject: [PATCH 017/121] llama : avoid ggml_cast, use F32 query --- ggml-metal.m | 4 ++-- ggml-metal.metal | 3 ++- ggml.c | 31 +++++++++++++++++++++++++++---- ggml.h | 4 ++++ llama.cpp | 3 ++- tests/test-backend-ops.cpp | 16 +++++++--------- 6 files changed, 44 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b161c69d5801..7b6762e6d9158 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute( case GGML_OP_FLASH_ATTN_EXT: { GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F32); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9ab9e16c3915a..c9e4dcfe99cd4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; } else { pq4[j*T4 + N4*i + tiisg] = 0.0h; } diff --git a/ggml.c b/ggml.c index 10df03c9c619b..5e515c03fdb9d 100644 --- a/ggml.c +++ b/ggml.c @@ -4178,6 +4178,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5781,6 +5783,16 @@ struct ggml_tensor * ggml_flash_attn_ext( return result; } +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13347,7 +13359,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); @@ -13408,6 +13420,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( float M = -INFINITY; float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); memset(V16, 0, D*sizeof(ggml_fp16_t)); @@ -13433,10 +13446,19 @@ static void ggml_compute_forward_flash_attn_ext_f16( float s; + // convert Q to 
F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + ggml_vec_dot_f16(D, &s, (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + Q16); s = s*scale + mv; @@ -13488,13 +13510,14 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (q->type) { - case GGML_TYPE_F16: + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: { ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; default: { + // TODO: implement F32 precision GGML_ASSERT(false); } break; } diff --git a/ggml.h b/ggml.h index 7bca02f2a2c48..e2f74412fde1e 100644 --- a/ggml.h +++ b/ggml.h @@ -1633,6 +1633,10 @@ extern "C" { struct ggml_tensor * mask, float scale); + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 4e6c9f9cc75ea..550caced4ae57 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4368,7 +4368,8 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 41ddfcca5b687..db1244876ce06 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case { // GGML_OP_FLASH_ATTN_EXT struct test_flash_attn_ext : public test_case { - const ggml_type typeq; const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nb); + return VARS_TO_STR4(hs, nh, kv, nb); } double max_nmse_err() override { return 5e-5; } - test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); @@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); - 
test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 6fea843b246409a3c4b26156745a89e4ba01029b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:59:41 +0200 Subject: [PATCH 018/121] metal : add parallel reduce version (disabled) --- ggml-metal.m | 2 +- ggml-metal.metal | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b6762e6d9158..cf7880c822db5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index c9e4dcfe99cd4..6eb2825df558b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2230,7 +2230,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - // TODO: try parallel reduce +#if 1 if (sgitg == 0) { half S = { 0.0h }; half M = { -INFINITY }; @@ -2261,6 +2261,46 @@ kernel void kernel_flash_attn_ext_f16( } } } +#else + // parallel reduce + // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps + { + half S = { 0.0h }; + half M = { -INFINITY }; + + for (int64_t sg = nsg/2; sg > 0; sg /= 2) { + if (sgitg >= sg) { + continue; + } + + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +#endif simdgroup_barrier(mem_flags::mem_threadgroup); From 77f6976a87f6d034cf0f7a77e14a011da7901911 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 13:15:00 +0200 Subject: [PATCH 019/121] metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments --- ggml-metal.m | 12 +-- ggml-metal.metal | 220 ++++++++++++++++++++++------------------------- 2 files changed, 110 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index eabc16f416645..a7e126bff5318 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,14 +2213,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; 
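As a sanity check on these launch arguments, the shared-memory requirement can be reproduced with scalar host math. The sketch below is illustrative only: the ne00/ne01 values are assumed examples, the smem formula is the one in effect at this point in the series, and nsg uses the values this patch settles on.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne00  = 128;               // head size (assumed example)
        const int64_t ne01  = 1;                 // query rows in the batch (assumed example)
        const int64_t nqptg = 8;                 // queries per threadgroup
        const int64_t ncpsg = 32;                // cache values per simdgroup
        const int64_t nsg   = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps)

        // shared memory per threadgroup in bytes; the buffers are half precision,
        // hence the sizeof(float)/2 factor used throughout the host code
        const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);

        // must satisfy smem <= device.maxThreadgroupMemoryLength
        std::printf("nsg = %lld, smem = %zu bytes\n", (long long) nsg, smem);
        return 0;
    }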
[encoder setBytes:&scale length:sizeof( float) atIndex:27]; - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + const int64_t ncpsg = 32; // cache values per simdgroup + + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); - const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index 6eb2825df558b..b564f014de2b6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,6 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); +// ref: https://arxiv.org/pdf/2307.08691.pdf template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, @@ -2038,39 +2039,45 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iq1 = tgpig[0]*Q; const int64_t D4 = D/4; - const int64_t N4 = N_SIMDWIDTH; - const int64_t L4 = (D4 + N4 - 1)/N4; const int64_t D8 = D/8; + const int64_t NW = N_SIMDWIDTH; + const int64_t L4 = (D4 + NW - 1)/NW; + const int64_t SH = (C + Q); // shared memory per simdgroup in (half) - const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half - const int64_t T4 = T/4; // shared memory size per query in half4 + const int64_t T = D + nsg*SH; // shared memory size per query in (half) + const int64_t T4 = T/4; // shared memory size per query in (half4) - threadgroup half * pq = (threadgroup half *) (shared + 0*D); - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); - threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; + sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; } else { - pq4[j*T4 + N4*i + tiisg] = 0.0h; + sq4[j*T4 + NW*i + tiisg] = 0.0h; } } + } - // zero out shared memory - for (int64_t j = 0; j < Q; 
++j) { - ps4[j*T4 + N4*i + tiisg] = 0.0h; - } + // zero out lo + for (int64_t i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); } + // zero out shared memory SH if (tiisg < C) { for (int64_t j = 0; j < Q; ++j) { - ss[j*T + 0 + tiisg] = 0.0h; + ss[j*T + tiisg] = 0.0h; + if (tiisg < Q) { + ss[j*T + C + tiisg] = 0.0h; + } } } @@ -2103,46 +2110,24 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; + // load the queries from shared memory into local memory simdgroup_half8x8 mq[D8]; for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], pq + i*8, T); + simdgroup_load(mq[i], sq + i*8, T); } - // TODO: this can be improved - device const float * mp[Q]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - { - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - for (int64_t j = 0; j < Q; ++j) { - if (iq1 + j < ne01) { - mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); - } else { - mp[j] = nullptr; - } - } - } + // pointer to the mask + device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix simdgroup_half8x8 mscale(scale); - for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { - // skip -INF blocks - // TODO: double-check this - { - float smc = -INFINITY; - - for (int64_t j = 0; j < Q; ++j) { - const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY; - smc = simd_max(max(smc, mc)); - } - - if (smc == -INFINITY) { - continue; - } - } - + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { simdgroup_half8x8 mk; @@ -2150,7 +2135,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); @@ -2160,65 +2145,77 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask simdgroup_float8x8 mm; - simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } + // used to detect blocks full of -INF + half smax = -INFINITY; + // online softmax for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); const half s = ss[j*T + p]; - half m = M[j]; - + smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); + const half m = M[j]; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); const half vs = s == -INFINITY ? 
0.0h : exp(s - M[j]); S[j] = S[j]*ms + simd_sum(vs); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] *= ms; + // create an 8x8 diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; } + // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } - simdgroup_barrier(mem_flags::mem_none); + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } - // (Q*K^T)*V + // O = diag(ms)*O { - simdgroup_half8x8 mv; + simdgroup_half8x8 mm; - simdgroup_half8x8 mp[C/8]; - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } + simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mqkv; + simdgroup_multiply(lo[i], mm, lo[i]); + } + } - simdgroup_load(mqkv, ps + i*8, T, 0, false); + // O = O + (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mp; + simdgroup_load(mp, ss + 8*cc, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t i = 0; i < D8; ++i) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); } - - simdgroup_store(mqkv, ps + i*8, T, 0, false); } } } + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int64_t j = 0; j < Q; ++j) { if (tiisg == 0) { ss[j*T + 0] = S[j]; @@ -2227,58 +2224,30 @@ kernel void kernel_flash_attn_ext_f16( } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps -#if 1 - if (sgitg == 0) { + // reduce the warps sequentially + for (int64_t sg = 1; sg < nsg; ++sg) { half S = { 0.0h }; half M = { -INFINITY }; - for (int64_t sg = 1; sg < nsg; ++sg) { - for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - - M = max(M0, M1); - - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; - } + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; - } + // each simdgroup stores its output to shared memory, reusing sq4 + if (sgitg == sg) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } - } -#else - // parallel reduce - // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps - { - half S = { 0.0h }; - half M = { -INFINITY }; - for (int64_t sg = nsg/2; sg > 0; sg /= 2) { - if (sgitg >= sg) { - continue; - } + threadgroup_barrier(mem_flags::mem_threadgroup); + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; M = max(M0, M1); @@ -2290,28 +2259,47 @@ kernel void 
kernel_flash_attn_ext_f16( if (tiisg == 0) { ss[j*T + 0] = S; ss[j*T + 1] = M; - } - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } } - threadgroup_barrier(mem_flags::mem_threadgroup); + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_half8x8 ms0; + simdgroup_half8x8 ms1; + + simdgroup_load(ms0, ss + C, T, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } } } -#endif - simdgroup_barrier(mem_flags::mem_threadgroup); + // store result to shared memory (reuse sq4) + if (sgitg == 0) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } device float4 * dst4 = (device float4 *) dst; + // final rescale with 1/S and store to global memory if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; } } } From ecc466a460abc7ad73df3b22a3e0957170bcf7b9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 15:42:57 +0200 Subject: [PATCH 020/121] metal : add tests, fix scaling, support C > 32 --- ggml-metal.m | 6 ++-- ggml-metal.metal | 62 ++++++++++++++++++++------------------ tests/test-backend-ops.cpp | 14 ++++++--- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a7e126bff5318..484ef89398e7a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,12 +2213,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + const int64_t nsg = ne01 <= nqptg ? 
MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index b564f014de2b6..7b604eb61a177 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2041,7 +2041,6 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; const int64_t NW = N_SIMDWIDTH; - const int64_t L4 = (D4 + NW - 1)/NW; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) const int64_t T = D + nsg*SH; // shared memory size per query in (half) @@ -2054,14 +2053,15 @@ kernel void kernel_flash_attn_ext_f16( // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[D8]; - for (int64_t i = 0; i < L4; ++i) { - // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { - sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; + sq4[j*T4 + i] = (half4) q4[i]; } else { - sq4[j*T4 + NW*i + tiisg] = 0.0h; + sq4[j*T4 + i] = 0.0h; } } } @@ -2072,12 +2072,9 @@ kernel void kernel_flash_attn_ext_f16( } // zero out shared memory SH - if (tiisg < C) { - for (int64_t j = 0; j < Q; ++j) { - ss[j*T + tiisg] = 0.0h; - if (tiisg < Q) { - ss[j*T + C + tiisg] = 0.0h; - } + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = tiisg; i < SH; i += NW) { + ss[j*T + i] = 0.0h; } } @@ -2157,27 +2154,34 @@ kernel void kernel_flash_attn_ext_f16( // online softmax for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; - - const half s = ss[j*T + p]; + const half m = M[j]; - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - const half m = M[j]; + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 
0.0h : exp(s - M[j]); + const half ms = exp(m - M[j]); - S[j] = S[j]*ms + simd_sum(vs); + S[j] = S[j]*ms; // create an 8x8 diagonal matrix for rescaling the output - if (p == j) { + if (tiisg == j) { ss[j*T + C + j] = ms; } - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } // skip -INF blocks @@ -2231,7 +2235,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); - // each simdgroup stores its output to shared memory, reusing sq4 + // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2284,7 +2288,7 @@ kernel void kernel_flash_attn_ext_f16( } } - // store result to shared memory (reuse sq4) + // store result to shared memory (reuse sq) if (sgitg == 0) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2298,8 +2302,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; + for (int64_t i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4c98bef7cf3a6..4093a52f2eef1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1395,7 +1395,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-5; + return 5e-4; } test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) @@ -1677,9 +1677,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 3a428a10973a751af72b55b9ef396de9c305c6ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 17:47:22 +0200 Subject: [PATCH 021/121] metal : improve precision --- ggml-metal.metal | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 7b604eb61a177..b6b5fd997b93a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2120,7 +2120,7 @@ kernel void kernel_flash_attn_ext_f16( 
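The -INFINITY guards introduced in this hunk are easiest to follow in scalar form. Below is a minimal single-threaded C++ sketch of the online softmax recurrence the kernel implements; the function and variable names are illustrative, not from the codebase.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // streaming softmax-weighted sum over scores s[i] and values v[i]:
    // M is the running maximum, S the running normalizer, O the running output
    float online_softmax_mix(const std::vector<float> & s, const std::vector<float> & v) {
        float M = -INFINITY;
        float S = 0.0f;
        float O = 0.0f;

        for (size_t i = 0; i < s.size(); ++i) {
            const float m = M;
            M = std::max(M, s[i]);

            // without the guards, a fully-masked prefix gives m - M = -INF - -INF = NaN
            const float ms = m    == -INFINITY ? 0.0f : std::exp(m    - M);
            const float vs = s[i] == -INFINITY ? 0.0f : std::exp(s[i] - M);

            S = S*ms + vs;
            O = O*ms + v[i]*vs;
        }

        return O/S; // the kernel applies this division once, when writing dst
    }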
device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); + simdgroup_float8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2163,7 +2163,7 @@ kernel void kernel_flash_attn_ext_f16( M[j] = simd_max(max(M[j], s)); } - const half ms = exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); S[j] = S[j]*ms; @@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = exp(s - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); S[j] = S[j] + simd_sum(vs); @@ -2255,8 +2255,8 @@ kernel void kernel_flash_attn_ext_f16( M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); S = S0*ms0 + S1*ms1; From 8612864108760897261d0d10101f68355899b03f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 18:10:16 +0200 Subject: [PATCH 022/121] ggml : fix f16 mad --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 6bba840d93d0c..fc0886aecf5a1 100644 --- a/ggml.c +++ b/ggml.c @@ -1344,12 +1344,12 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const // leftovers for (int i = np; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #else // scalar for (int i = 0; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #endif } From 134c81c78dfdeaca988ea2505cc6f0c0aec2d243 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 22:23:40 +0200 Subject: [PATCH 023/121] metal : minor --- ggml-metal.metal | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b6b5fd997b93a..ad6a4a318f4c3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,15 +2127,14 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { - simdgroup_half8x8 mk; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } @@ -2192,7 +2191,6 @@ kernel void kernel_flash_attn_ext_f16( // O = diag(ms)*O { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { @@ -2202,8 +2200,6 @@ kernel void kernel_flash_attn_ext_f16( // O = O + (Q*K^T)*V { - simdgroup_half8x8 mv; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mp; simdgroup_load(mp, ss + 8*cc, T, 0, false); @@ -2211,6 +2207,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < D8; ++i) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mv; simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); From 
1db22d7032fd55a612e400164cb70ad238bbc055 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 23:08:31 +0200 Subject: [PATCH 024/121] metal : support Q > 8 --- examples/batched-bench/batched-bench.cpp | 2 +- ggml-metal.m | 7 ++- ggml-metal.metal | 80 +++++++++++++++--------- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 7924db267401c..4992b57f6f9db 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -104,7 +104,7 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = 512; + ctx_params.n_batch = 2048; ctx_params.mul_mat_q = mmq; ctx_params.n_threads = params.n_threads; diff --git a/ggml-metal.m b/ggml-metal.m index ef799ef57b643..a0dd1d0df5bcb 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2206,8 +2206,11 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index ad6a4a318f4c3..08c000cc4c027 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2040,6 +2040,7 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; + const int64_t Q8 = Q/8; const int64_t NW = N_SIMDWIDTH; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) @@ -2051,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; + simdgroup_half8x8 lo[Q8][D8]; // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { @@ -2067,8 +2068,10 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (int64_t i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + lo[j][i] = make_filled_simdgroup_matrix(0.0h); + } } // zero out shared memory SH @@ -2108,10 +2111,12 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; + simdgroup_half8x8 mq[Q8][D8]; - for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); + } } const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; @@ -2128,7 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); + simdgroup_half8x8 mqk[Q8]; + for (int64_t j = 0; j < Q8; ++j) { + mqk[j] = 
make_filled_simdgroup_matrix(0.h); + } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2136,15 +2144,19 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); + } } // mqk = mqk*scale + mask - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_float8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - simdgroup_store(mqk, ss + 8*cc, T, 0, false); + simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + } } } @@ -2166,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms; - // create an 8x8 diagonal matrix for rescaling the output + // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } @@ -2189,28 +2201,30 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); + simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); + simdgroup_multiply(lo[j][i], mm, lo[j][i]); } } // O = O + (Q*K^T)*V { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mp; - simdgroup_load(mp, ss + 8*cc, T, 0, false); + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); for (int64_t i = 0; i < D8; ++i) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_half8x8 mv; - simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 mv; + simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); - simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); + simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); + } } } } @@ -2234,8 +2248,10 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } @@ -2267,19 +2283,19 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 t; simdgroup_half8x8 ms0; simdgroup_half8x8 ms1; - simdgroup_load(ms0, ss + C, T, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); } } } @@ -2287,8 +2303,10 @@ kernel void kernel_flash_attn_ext_f16( 
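In scalar form, the sequential reduction above merges two partial accumulators that were computed over disjoint slices of the KV cache. The following is a minimal C++ sketch with illustrative names; the kernel performs the same rescaling through 8x8 diagonal matrices, one per query row.

    #include <cmath>

    struct Acc {
        float S; // running normalizer
        float M; // running maximum
        float O; // partial output (one value per head element in practice)
    };

    // combine two flash-attention partials computed over disjoint KV ranges
    Acc merge(const Acc & a, const Acc & b) {
        Acc r;
        r.M = std::fmax(a.M, b.M);

        const float ms0 = a.M == -INFINITY ? 0.0f : std::exp(a.M - r.M);
        const float ms1 = b.M == -INFINITY ? 0.0f : std::exp(b.M - r.M);

        r.S = a.S*ms0 + b.S*ms1; // S = S0*ms0 + S1*ms1
        r.O = a.O*ms0 + b.O*ms1; // scalar form of O_0 = diag(ms0)*O_0 + diag(ms1)*O_1

        return r;
    }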
// store result to shared memory (reuse sq) if (sgitg == 0) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } From 4794821a31d5778b3398b8375d29fa63a539c8c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 16:44:55 +0200 Subject: [PATCH 025/121] tests : add ATTN tests --- tests/test-backend-ops.cpp | 70 +++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c18ff07ea4d21..0ce498e9e7dd4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1418,6 +1418,48 @@ struct test_flash_attn_ext : public test_case { } }; +// Attention +struct test_attn : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string op_desc(ggml_tensor * t) override { + return "ATTN"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + + struct ggml_tensor * cur; + + cur = ggml_mul_mat (ctx, k, q); + cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_mul_mat (ctx, v, cur); + cur = ggml_permute (ctx, cur, 0, 2, 1, 3); + cur = ggml_cont_2d (ctx, cur, hs*nh, nb); + + return cur; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1684,15 +1726,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); + test_cases.emplace_back(new test_attn(64, 32, 512, 8)); + test_cases.emplace_back(new test_attn(64, 32, 512, 7)); + test_cases.emplace_back(new test_attn(64, 32, 512, 1)); + test_cases.emplace_back(new test_attn(80, 32, 512, 8)); + test_cases.emplace_back(new test_attn(80, 32, 512, 7)); + test_cases.emplace_back(new test_attn(80, 32, 512, 1)); + test_cases.emplace_back(new test_attn(128, 32, 512, 8)); + test_cases.emplace_back(new test_attn(128, 32, 512, 7)); + test_cases.emplace_back(new test_attn(128, 32, 512, 1)); + + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); + test_cases.emplace_back(new 
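test_attn pins down the semantics the fused kernel must reproduce: the same attention computed as an unfused chain of ggml_mul_mat, ggml_soft_max_ext and ggml_mul_mat, compared against ggml_flash_attn_ext within max_nmse_err() = 5e-4. A minimal host-side version of what that graph computes for a single head (illustrative only, not part of the test harness):

#include <algorithm>
#include <cmath>
#include <vector>

// out = softmax(Q*K^T * scale + mask) * V for one head
// Q: [nb, hs], K: [kv, hs], V: [kv, hs], mask: [nb, kv], out: [nb, hs]
static std::vector<float> naive_attention(const std::vector<float> & Q, const std::vector<float> & K,
                                          const std::vector<float> & V, const std::vector<float> & mask,
                                          int nb, int kv, int hs) {
    const float scale = 1.0f/std::sqrt((float) hs);
    std::vector<float> out(nb*hs, 0.0f);
    std::vector<float> p(kv);
    for (int q = 0; q < nb; ++q) {
        float m = -INFINITY;
        for (int c = 0; c < kv; ++c) {
            float s = 0.0f;
            for (int d = 0; d < hs; ++d) s += Q[q*hs + d]*K[c*hs + d];
            p[c] = s*scale + mask[q*kv + c];
            m = std::max(m, p[c]);
        }
        float sum = 0.0f;
        for (int c = 0; c < kv; ++c) { p[c] = std::exp(p[c] - m); sum += p[c]; }
        for (int c = 0; c < kv; ++c) {
            for (int d = 0; d < hs; ++d) out[q*hs + d] += (p[c]/sum)*V[c*hs + d];
        }
    }
    return out;
}

Note that the test builds V pre-transposed (see the "// transposed" comment) so the second ggml_mul_mat can consume the softmax output directly; the sketch uses a plain row-major [kv, hs] layout for clarity.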
test_flash_attn_ext(64, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From abeaf0d90ee82096a0aba20785f1e37bd1f3aa41 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:12:24 +0200 Subject: [PATCH 026/121] metal : disable buffer allocation logs --- ggml-metal.m | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a0dd1d0df5bcb..a637f04875dbe 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2421,10 +2421,13 @@ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buff UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2434,10 +2437,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2471,8 +2479,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2549,7 +2556,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2572,7 +2579,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2581,8 +2589,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return 
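The logging change above is a small consolidation: the allocation message and the device watermark now come from one helper that also receives size_aligned, so a single #ifndef GGML_METAL_NDEBUG region compiles every allocation log out in release builds. The underlying pattern, sketched with a made-up LOG_ALLOC macro:

#include <cstdio>

#ifndef GGML_METAL_NDEBUG
#define LOG_ALLOC(...) fprintf(stderr, __VA_ARGS__)
#else
#define LOG_ALLOC(...) ((void) 0)
#endif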
ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } From c6c1132e5e6658b3c209433ed5ef75067ef31a2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:22:28 +0200 Subject: [PATCH 027/121] tests : more --- ggml-metal.m | 9 +++++++++ ggml-metal.metal | 3 +++ ggml.c | 5 ----- tests/test-backend-ops.cpp | 29 ++++++++++------------------- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a637f04875dbe..4b5fd0bb8fc58 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -137,7 +137,10 @@ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -505,7 +508,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute( switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/ggml-metal.metal b/ggml-metal.metal index 08c000cc4c027..be059d78f505a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16( template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; template 
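Adding head sizes 96, 112 and 256 above follows the established recipe: one kernel name in the enum, one GGML_METAL_ADD_KERNEL line, one case in the ne00 switch, and one template instantiation in the .metal file. The host-side shape of that dispatch, as a sketch (flash_attn_kernel and kernel_for_head_size are illustrative names):

template <int D> void flash_attn_kernel() { /* device code elided */ }

using kernel_fn = void (*)();

static kernel_fn kernel_for_head_size(int d) {
    switch (d) {
        case  64: return flash_attn_kernel<64>;
        case  80: return flash_attn_kernel<80>;
        case  96: return flash_attn_kernel<96>;
        case 112: return flash_attn_kernel<112>;
        case 128: return flash_attn_kernel<128>;
        case 256: return flash_attn_kernel<256>;
        default:  return nullptr; // unsupported head size, caller must handle
    }
}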
[[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/ggml.c b/ggml.c index e8a5fcfa485c1..57271a1ad43e3 100644 --- a/ggml.c +++ b/ggml.c @@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; - const int64_t P = nek1 - N; GGML_ASSERT(ne0 == D); GGML_ASSERT(ne2 == N); - GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); @@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted @@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( float scale = 1.0f; memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0ce498e9e7dd4..f57e8ab1a853e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_attn(64, 32, 512, 8)); - test_cases.emplace_back(new test_attn(64, 32, 512, 7)); - test_cases.emplace_back(new test_attn(64, 32, 512, 1)); - test_cases.emplace_back(new test_attn(80, 32, 512, 8)); - test_cases.emplace_back(new test_attn(80, 32, 512, 7)); - test_cases.emplace_back(new test_attn(80, 32, 512, 1)); - test_cases.emplace_back(new test_attn(128, 32, 512, 8)); - test_cases.emplace_back(new test_attn(128, 32, 512, 7)); - test_cases.emplace_back(new test_attn(128, 32, 512, 1)); - - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); + for (int hs : { 64, 80, 96, 112, 128, 256, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, 2048, 4096, }) { + for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 5fcb9c1c5af108056c8ad51fc1995de9d7707d2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 19:46:22 +0200 Subject: [PATCH 028/121] metal : faster inner loop for C == 32 --- ggml-metal.metal | 59 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 
be059d78f505a..db4c7cfde0037 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2048,8 +2048,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t T4 = T/4; // shared memory size per query in (half4) threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[Q8][D8]; @@ -2164,34 +2164,59 @@ kernel void kernel_flash_attn_ext_f16( half smax = -INFINITY; // online softmax - for (int64_t j = 0; j < Q; ++j) { - const half m = M[j]; + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - for (int64_t p = tiisg; p < C; p += NW) { + const half m = M[j]; const half s = ss[j*T + p]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms; + S[j] = S[j]*ms + simd_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; - for (int64_t p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } - S[j] = S[j] + simd_sum(vs); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + S[j] = S[j]*ms; + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*T + C + j] = ms; + } + + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } } From d073e4f93337560e552f0d3de4b6b07bf13ef3f5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:45:32 +0200 Subject: [PATCH 029/121] metal : fix array initialization --- ggml-metal.metal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index db4c7cfde0037..41f6169de8abd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2084,8 +2084,8 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { 0.0h }; - half M[Q] = { -INFINITY }; + half S[Q] = { [0 ... Q-1] = 0.0h }; + half M[Q] = { [0 ... 
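The new C == 32 branch exploits the fact that the block width now matches the simdgroup width: each lane owns exactly one score column, so the strided "for (p = tiisg; p < C; p += NW)" loops collapse to a single element and one simd_max/simd_sum pass produces the new maximum, the rescale factor and the row sum together. A host-side emulation of that single-pass update (illustrative; the 32-iteration loops stand in for the lane-parallel simdgroup reductions):

#include <algorithm>
#include <cmath>

// one row of the online softmax, block width fixed to 32
static void softmax_row_c32(float * row, float & M, float & S) {
    float m_new = M;
    for (int lane = 0; lane < 32; ++lane) m_new = std::max(m_new, row[lane]); // simd_max

    const float ms = M == -INFINITY ? 0.0f : std::exp(M - m_new);

    float sum = 0.0f;
    for (int lane = 0; lane < 32; ++lane) {
        const float vs = row[lane] == -INFINITY ? 0.0f : std::exp(row[lane] - m_new);
        row[lane] = vs;   // the P matrix entries
        sum += vs;        // simd_sum
    }

    M = m_new;
    S = S*ms + sum;
}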
Q-1] = -INFINITY }; // assume K and V are same shape const int64_t ne22 = ne12; From 78df5527e4e9eafb181200384fbed80c8116042e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:46:49 +0200 Subject: [PATCH 030/121] tests : ifdef --- tests/test-backend-ops.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f57e8ab1a853e..07182c6d8aa63 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,6 +1726,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); +#if 0 for (int hs : { 64, 80, 96, 112, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { @@ -1736,6 +1737,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } +#else + for (int hs : { 128, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, 512 }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } +#endif #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 8ad92dc1ec9aa6549c68900daa7ab93b57fa3ae5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 31 Jan 2024 19:17:16 +0200 Subject: [PATCH 031/121] ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext --- ggml-cuda.cu | 20 +++++++++---------- ggml-metal.m | 6 ++++++ ggml-metal.metal | 40 ++++++++++++++++++-------------------- ggml.c | 13 +++++++++---- ggml.h | 12 +++++++----- llama.cpp | 40 ++++++++++++++++++++++---------------- tests/test-backend-ops.cpp | 10 +++++----- 7 files changed, 79 insertions(+), 62 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e565957421795..c57a031e4060c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int } template -static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; @@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds if (need_check && col_data + 0 >= ncols_data) { val.x = -INFINITY; } else { - val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); + val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f); } if (need_check && col_data + WARP_SIZE >= ncols_data) { val.y = -INFINITY; } else { - val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); + val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f); } if (!need_check || col_smem < (vals_smem ? 
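The array-initialization fix above is subtle but important: in C, and equally under Metal's C++-derived aggregate-initialization rules, "half M[Q] = { -INFINITY };" designates only element 0 and zero-initializes the remaining Q-1 elements, which silently corrupts a running maximum that must start at -infinity everywhere. The GNU/Clang range-designator extension fills the whole array; in C terms:

#include <math.h>

static const float M_bad[8]  = { -INFINITY };             // elements 1..7 end up 0.0f
static const float M_good[8] = { [0 ... 7] = -INFINITY }; // range designator: all eight are -inf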
ncols_smem : ncols_data)) { vals[col_smem] = val; @@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds } template -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (y ? y[iy] : 0.0f); + const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); } @@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con } } -static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max( #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX if (use_f16_soft_max) { - soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } else { - soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f32_cuda(src0_dd, src1 ? 
(const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } (void) dst; diff --git a/ggml-metal.m b/ggml-metal.m index 15e5568f960f1..e00069624551f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); + int nth = 32; // SIMD width id pipeline = nil; @@ -2213,6 +2215,10 @@ static bool ggml_metal_graph_compute( id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index b2e40715d4f2d..04c1aaf9cdfb9 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -349,9 +349,9 @@ kernel void kernel_sum_rows( } kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -366,9 +366,9 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float lmax = -INFINITY; @@ -435,14 +435,14 @@ kernel void kernel_soft_max( } kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -452,15 +452,15 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const half4 * pmask = src1 != src0 ? 
(device const half4 *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; // parallel max float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -486,7 +486,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16( } } - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - // pointer to the mask - device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); + device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); + simdgroup_half8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask for (int64_t j = 0; j < Q8; ++j) { - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); diff --git a/ggml.c b/ggml.c index 466a8cdec3c9d..59a4c05a12ffe 100644 --- a/ggml.c +++ b/ggml.c @@ -5085,6 +5085,7 @@ static struct ggml_tensor * ggml_soft_max_impl( bool inplace) { GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); @@ -5854,6 +5855,8 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); } @@ -11552,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp[i]); + } } #ifndef NDEBUG @@ -13760,7 +13765,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(V16, 0, D*sizeof(ggml_fp16_t)); - const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? 
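Switching the mask to F16 is what lets the fused path consume it directly: the Metal kernel now simdgroup_loads 8x8 mask tiles as half with no conversion pass (and mscale becomes a simdgroup_half8x8, so the multiply-accumulate stays in half throughout), the soft_max kernels read half/half4 and widen per element, and the mask's footprint in memory and bandwidth is halved. The CPU fallback pays one GGML_FP16_TO_FP32 per element, and masked-out positions survive the round trip since half represents -infinity exactly.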
(ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; // k indices const int ik3 = iq3 / rk3; @@ -13774,7 +13779,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? mp[ic] : 0.0f; + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; } diff --git a/ggml.h b/ggml.h index a83ff8035f9ea..74ce1abd4d500 100644 --- a/ggml.h +++ b/ggml.h @@ -1646,11 +1646,13 @@ extern "C" { struct ggml_tensor * v, bool masked); - // q: [n_embd, n_batch, n_head, 1] - // k: [n_embd, n_kv, n_head_kv, 1] - // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch, 1, 1] - // res: [n_embd, n_head, n_batch, 1] !! permuted !! +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 1f8ecc19b4e0c..fe25839669efc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4721,7 +4721,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -4905,7 +4905,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5026,7 +5026,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5148,7 +5148,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -5245,7 +5245,7 @@ struct 
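The padding contract documented in ggml.h above is what removes bounds checks from the GPU kernels: the mask must have at least GGML_PAD(n_batch, GGML_KQ_MASK_PAD) rows so that 8-row (Metal) and 16-row (CUDA wmma) tile loads can never run off the end. GGML_PAD is ggml's round-up-to-a-multiple macro:

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

// e.g. GGML_PAD(7, 32) == 32, GGML_PAD(32, 32) == 32, GGML_PAD(33, 32) == 64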
llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); if (do_rope_shift) { @@ -5448,7 +5448,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5538,7 +5538,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -5631,7 +5631,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5731,7 +5731,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5854,7 +5854,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5968,7 +5968,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, 
n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6089,7 +6089,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6211,7 +6211,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6318,7 +6318,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -6416,7 +6416,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6524,7 +6524,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -10250,7 +10250,10 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - cparams.n_batch = params.n_batch; + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. 
ggml_flash_attn_ext) + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -10430,6 +10433,9 @@ struct llama_context * llama_new_context_with_model( ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); + // zero-out the input buffer to prevent NaNs in padded tensors + ggml_backend_buffer_clear(ctx->buf_input, 0); + LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(ctx->buf_input), ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0f31c00f9672c..b1b30b91c9c6b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1101,7 +1101,7 @@ struct test_soft_max : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * b = nullptr; - if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } + if (mask) { b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale); return out; } @@ -1472,7 +1472,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } @@ -1506,7 +1506,7 @@ struct test_attn : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1); struct ggml_tensor * cur; @@ -1793,7 +1793,7 @@ struct test_llama : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -1915,7 +1915,7 @@ struct test_falcon : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); From 910b15bb4006409fe24b41da171cc562cdb1f3a4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 
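Two context-creation guards back this contract up: n_batch is clamped to at least GGML_KQ_MASK_PAD so the padded KQ_mask view always fits inside the allocated input tensor, and the whole input buffer is zero-cleared once so the padding rows contain 0 rather than uninitialized memory, which could otherwise widen to NaN and poison the softmax even for rows that are ultimately discarded.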
16:41:02 +0200 Subject: [PATCH 032/121] ggml : fix ggml_soft_max mask requirement --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 59a4c05a12ffe..ebd9c6b341080 100644 --- a/ggml.c +++ b/ggml.c @@ -5089,7 +5089,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } bool is_node = false; From 2e460137490a4e002a60a60aed052e90179bb65b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 16:47:20 +0200 Subject: [PATCH 033/121] cuda : fix soft_max to use correct mask size --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c57a031e4060c..15fc6154f7508 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -9064,7 +9064,7 @@ static void ggml_cuda_op_soft_max( const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded! float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float)); From 5a19a9f6d0899becbc71a19454a27c0225edddf7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 19:47:11 +0200 Subject: [PATCH 034/121] cuda : add flash_attn kernel (wip) --- ggml-cuda.cu | 735 ++++++++++++++++++++++++++++++++++++++++++++++++++- llama.cpp | 3 +- 2 files changed, 735 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 15fc6154f7508..60d228a61660f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -108,6 +108,7 @@ #include #include #include +#include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED @@ -655,6 +656,19 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } +static __device__ __forceinline__ half warp_reduce_sum(half x) { +#if __CUDA_ARCH__ >= CC_VOLTA +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x); + } + return x; +#else + (void) x; + NO_DEVICE_CODE; +#endif +} + static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -676,6 +690,18 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +static __device__ __forceinline__ half warp_reduce_max(half x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + (void) x; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -989,6 +1015,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr if (lane_id == 0) { s_sum[warp_id] = tmp; } + __syncthreads(); tmp = s_sum[lane_id]; 
tmp = warp_reduce_sum(tmp); @@ -6249,6 +6276,528 @@ static __global__ void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } +#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256 + +template +static __global__ void flash_attn_f32( + const float* __restrict__ q, + const float* __restrict__ k, + const float* __restrict__ v, + float* __restrict__ kqv, + float kq_scale, + int head_dim, int seq_len, int num_heads) { + const int head = blockIdx.x / seq_len; + const int head_size = head_dim * seq_len; + const int s = blockIdx.x % seq_len; + + extern __shared__ char flash_attn_shmem_f32[]; + float* S = (float*)flash_attn_shmem_f32; + float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float)); + + // QK^T + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + const int key_offset = is * head_dim + head * head_size; + const int query_offset = s * head_dim + head * head_size; + + float tmp = 0.0f; + for(int d = 0; d < head_dim; d++) { + tmp += k[key_offset + d] * q[query_offset + d]; + } + S[is] = tmp * kq_scale; + } + __syncthreads(); + + float max_val = -INFINITY; + // get the max + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + max_val = fmaxf(max_val , S[is]); + } + + max_val = warp_reduce_max(max_val); + + { // get max from all threads + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = max_val; + } + __syncthreads(); + max_val = warp_data[lane_id]; + max_val = warp_reduce_max(max_val); + } + + // softmax(QK^T) + float sum = 0.0f; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + float tmp = expf(S[is] - max_val); + sum += tmp; + S[is] = tmp; + } + __syncthreads(); + + sum = warp_reduce_sum(sum); + { // softmax sum partials + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = sum; + } + __syncthreads(); + sum = warp_data[lane_id]; + sum = warp_reduce_sum(sum); + } + + float inv_sum = 1.0f / sum; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + S[is] *= inv_sum; + } + __syncthreads(); + + // softmax(QK^T)V + #pragma unroll + for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) { + const int d = threadIdx.x + d0; + if(d >= head_dim) { + break; + } + const int dst_index = d + s * head_dim + head * head_size; + const int value_offset = d * seq_len + head * head_size; + + float temp = 0.0f; + #pragma unroll + for(int ic = 0; ic < k_seq_len;ic++) { + if(ic >= seq_len) { + break; + } + temp += v[value_offset + ic] * S[ic]; + } + kqv[dst_index] = temp; + } +} + +#if __CUDA_ARCH__ >= CC_VOLTA +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; +#endif + +// based on metal version +template // D head size, Q queries per block, C cache items per block +static __global__ void flash_attn_ext_f16( + const char* __restrict__ q, + const char* __restrict__ k, + const char* __restrict__ v, + const char* __restrict__ mask, + float* __restrict__ dst, + float scale, + int ne00, + int ne01, + int ne02, + int ne03, + int ne10, + int 
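flash_attn_f32 is the straightforward baseline alongside the wmma kernel below: one thread block per (head, query) pair (num_heads*seq_len blocks in total), the full score row S resident in shared memory, and a two-pass softmax (block-wide max, then exp-and-sum) where each pass reduces warp partials through the warp_data scratch. Per the launcher, shared memory is seq_len*sizeof(float) + WARP_SIZE*sizeof(float) bytes, e.g. 4096*4 + 32*4 = 16512 bytes at seq_len = 4096, so the achievable sequence length is bounded by the shared-memory limit rather than by registers.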
ne11, + int ne12, + int ne13, + int ne31, + int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { +#if __CUDA_ARCH__ >= CC_VOLTA + const int warp_id = threadIdx.y; + const int lane_id = threadIdx.x; + + const int num_warps = blockDim.y; // number of warps + const int iq3 = blockIdx.z; + const int iq2 = blockIdx.y; + const int iq1 = blockIdx.x * Q; + + const int D2 = D/2; + const int D16 = D/16; + const int Q16 = Q/16; + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) + + const int T = D + num_warps*SH; // shared memory size per query in (half) + const int T2 = T/2; // shared memory size per query in (half2) + + extern __shared__ half __flash_attn_f16_shmem[]; + // pq + half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data + half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + + half16x16_acc zr; + half16x16_acc lo[Q16][D16]; + + // load heads from Q to shared memory + for (int64_t j = warp_id; j < Q; j += num_warps) { + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = lane_id; i < D2; i += NW) { + if (iq1 + j < ne01) { + sq2[j*T2 + i] = __float22half2_rn(q2[i]); + } else { + sq2[j*T2 + i] = make_half2(0.0, 0.0); + } + } + } + + nvcuda::wmma::fill_fragment(zr, 0.0); + + // zero out lo + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + } + } + + // zero out shared memory SH + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = lane_id; i < SH; i += NW) { + ss[j*T + i] = 0.0; + } + } + + __syncthreads(); + + { + half S[Q]; + half M[Q]; + + for(int i = 0; i < Q; i++) { + S[i] = __float2half(0.0f); + M[i] = __float2half(-INFINITY); + } + + // assume K and V are same shape + const int ne22 = ne12; + const int ne23 = ne13; + + const int nb21 = nb11; + const int nb22 = nb12; + const int nb23 = nb13; + + // broadcast + const int rk2 = ne02/ne12; + const int rk3 = ne03/ne13; + + const int rv2 = ne02/ne22; + const int rv3 = ne03/ne23; + + // k indices + const int ik2 = iq2 / rk2; + const int ik3 = iq3 / rk3; + + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half16x16_a mq[Q16][D16]; + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + } + } + + // pointer to the mask + const half * mp = mask ? 
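The shared-memory layout mirrors the Metal kernel: the first D halves of each row hold the query tile (sq, also viewed as half2 through sq2), followed by one SH = C + Q wide scratch strip per warp (ss). That strip is reused three ways over the kernel's lifetime: score tiles at offset 0, the QxQ diagonal rescale matrices at offset C, and, during the cross-warp reduction, slots 0 and 1 of each row are recycled to spill the per-row S and M.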
(const half *) (mask + iq1*nb31) : nullptr; + + // prepare diagonal scale matrix + half16x16_b mscale; + for (int i = 0; i < 16; ++i) { + ss[i*T + i] = __float2half(scale); + } + nvcuda::wmma::load_matrix_sync(mscale, ss, T); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + // Q*K^T + { + for (int cc = 0; cc < C/16; ++cc) { + half16x16_acc mqk[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::fill_fragment(mqk[j], 0); + } + + const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t i = 0; i < D16; ++i) { + half16x16_bT mk; // transposed key + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); + } + } + + // mqk = mqk*scale + mask + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mqka; + half16x16_acc mm; + if(mp) { + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + } + + // convert accumulator to matrix_a + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + + nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + } + } + } + + // used to detect blocks full of -INF + half smax = __float2half(-INFINITY); + + // online softmax + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = lane_id; + + const half m = M[j]; + const half s = ss[j*T + p]; + + smax = warp_reduce_max(__hmax(smax, s)); + M[j] = warp_reduce_max(__hmax(M[j], s)); + + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + + S[j] = S[j]*ms + warp_reduce_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int64_t p = lane_id; p < C; p += NW) { + const half s = ss[j*T + p]; + + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); + } + + smax = warp_reduce_max(smax); + M[j] = warp_reduce_max(M[j]); + + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } + + // local sum + half ls = 0.0f; + + for (int64_t p = lane_id; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = __hisinf(s) ? 
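The CUDA port swaps Metal's 8x8 simdgroup matrices for 16x16x16 wmma fragments, hence Q16/D16 in place of Q8/D8. One wrinkle visible above: wmma fragment layouts are opaque, so an accumulator can only be fed back in as a matrix_a operand via a store_matrix_sync/load_matrix_sync round trip through shared memory. A minimal single-tile multiply with the same fragment types (illustrative; assumes sm_70+ and a launch with one 32-thread warp):

#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

// C = A*B for one 16x16 tile of halves, wmma's smallest unit of work
__global__ void tile_mma(const half * A, const half * B, half * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half>                  c;

    wmma::fill_fragment(c, __float2half(0.0f));
    wmma::load_matrix_sync(a, A, 16); // 16 = leading dimension in elements
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);
    wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
}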
__float2half(0.0f) : hexp(s - M[j]); + + ls += vs; + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + + S[j] = S[j]*ms + warp_reduce_sum(ls); + } + } + + // skip -INF blocks + if (__hisinf(smax)) { + continue; + } + + // O = diag(ms)*O + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mm; + half16x16_b lob; + + nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + + for (int64_t i = 0; i < D16; ++i) { + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + } + + // restore zeros + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + } + + // O = O + (Q*K^T)*V + { + for (int cc = 0; cc < C/16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + half16x16_b mk[D16]; + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + } + + half16x16_a mv[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + } + + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (int64_t j = 0; j < Q; ++j) { + if (lane_id == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (int64_t sg = 1; sg < num_warps; ++sg) { + half S = __float2half(0.0f); + half M = __float2half(-INFINITY); + + __syncthreads(); + + // each simdgroup stores its output to shared memory, reusing sq + if (warp_id == sg) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + __syncthreads(); + + // the first simdgroup accumulates the results from the other simdgroups + if (warp_id == 0) { + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; + + M = __hmax(M0, M1); + + const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) ? 
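// note: the sequential warp reduction above merges two partial softmax
// states per step: given (S0, M0) and (S1, M1) for the same query row,
// the combined state follows the scalar sketch below, and the same
// ms0/ms1 factors are then applied to the output tiles as diag(ms0)*O_0 +
// diag(ms1)*O_1:
//
//   const half M   = __hmax(M0, M1);
//   const half ms0 = hexp(M0 - M); // rescale factor for the first state
//   const half ms1 = hexp(M1 - M); // rescale factor for the second state
//   const half S   = S0*ms0 + S1*ms1;
//   // O = ms0*O0 + ms1*O1, applied element-wise to the accumulator tiles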
__float2half(0.0f) : hexp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (lane_id == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a ms0; + half16x16_a ms1; + half16x16_b t; + half16x16_acc t2; + + nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); + + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(t2, 0.0); + nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(t2, ms1, t, t2); + + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + + nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); + } + } + } + } + + // store result to shared memory (reuse sq) + if (warp_id == 0) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + // final rescale with 1/S and store to global memory + if (warp_id == 0) { + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = lane_id; i < D; i += NW) { + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); + } + } + } +#else + NO_DEVICE_CODE; +#endif +} + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -7682,6 +8231,13 @@ static void im2col_cuda(const float* x, T* dst, im2col_kernel<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } +static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) { + int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float); + int num_blocks = num_heads * seq_len; + flash_attn_f32<<>>( + q, k, v, dst, kq_scale, d_head, seq_len, num_heads); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 256 @@ -8659,7 +9215,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec( src1_dfloat = src1_dfloat_a.alloc(ne00); ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, ne00, 1, sizeof(float), 0, 0, - ne00, 1, sizeof(half), 0, 0, stream); + ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream); } #else const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion @@ -10284,6 +10840,170 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } +inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F32); + GGML_ASSERT(V->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = 
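// note: the f32 fallback launcher above apparently runs one block per
// (head, query) row and sizes dynamic shared memory as one float per KV
// position plus one per warp lane; a worked example assuming seq_len = 4096:
//
//   const int sram = 4096*sizeof(float) + 32*sizeof(float); // 16512 bytes, ~16 KiB
//   const int num_blocks = num_heads * 4096;                // one row each
//
// which stays well under the default 48 KiB dynamic shared-memory limit.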
(ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + const int64_t d_head = Q->ne[0]; + const int64_t sequence_length = Q->ne[1]; + const int64_t num_heads = Q->ne[2]; + + GGML_ASSERT(Q->ne[0] == d_head); + GGML_ASSERT(K->ne[0] == d_head); + GGML_ASSERT(V->ne[1] == d_head); + + GGML_ASSERT(Q->ne[1] == sequence_length); + GGML_ASSERT(K->ne[1] == sequence_length); + GGML_ASSERT(V->ne[0] == sequence_length); + + GGML_ASSERT(Q->ne[2] == num_heads); + GGML_ASSERT(K->ne[2] == num_heads); + GGML_ASSERT(V->ne[2] == num_heads); + + float KQ_scale = 1.0f / sqrtf((float)d_head); + + flash_attn_f32_cuda( + (float *) src0_extra->data_device[g_main_device], // Query + (float *) src1_extra->data_device[g_main_device], // Key + (float *) src2_extra->data_device[g_main_device], // Value + (float *) dst_extra->data_device[g_main_device], // dst + KQ_scale, d_head, sequence_length, num_heads, main_stream); +} + + +inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + +#define NQPB 16 +#define NCPW 128 + + const int nqpb = NQPB; // queries per block + const int ncpw = NCPW; // cache values per warp (does not work for other values) + + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? + // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why + const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + + dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); + dim3 block_dim(32, nwarps, 1); + + const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + + switch (Q->ne[0]) + { + case 16: + flash_attn_ext_f16<16, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? 
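// note: the dynamic shared-memory formula above amounts to
// nqpb*(D + nwarps*(ncpw + nqpb)) halves; a worked example assuming
// D = 128, nqpb = 16, ncpw = 128 and the small-batch path picking
// nwarps = 8 (i.e. Q->ne[1] <= nqpb and K->ne[1] >= 1024):
//
//   shmem = 16*(128 + 8*(128 + 16))*2   // sizeof(float)/2 == sizeof(half)
//         = 16*(128 + 1152)*2
//         = 40960 bytes = 40 KiB
//
// which still fits the default 48 KiB per-block limit without any opt-in.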
((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; + } +} + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -10573,6 +11293,10 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_ARGSORT: func = ggml_cuda_argsort; break; + case GGML_OP_FLASH_ATTN: + break; + case GGML_OP_FLASH_ATTN_EXT: + break; default: return false; } @@ -10587,7 +11311,13 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return true; } - func(tensor->src[0], tensor->src[1], tensor); + if(tensor->op == GGML_OP_FLASH_ATTN) { + ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) { + ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } else { + func(tensor->src[0], tensor->src[1], tensor); + } return true; } @@ -11403,6 +12133,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/llama.cpp b/llama.cpp index fe25839669efc..2330efff57bd3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6881,7 +6881,8 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext) + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); //kv_self.n = llama_kv_cache_cell_max(kv_self); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); From 56e45a239e1d5a871009aa162b7ba99c93c40b62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 20:16:32 +0200 Subject: [PATCH 035/121] metal : optimize softmax for C > 32 --- ggml-metal.metal | 16 +++++++++++----- tests/test-backend-ops.cpp | 9 +++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 04c1aaf9cdfb9..3d5d762d16d99 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2217,29 +2217,35 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); + smax = max(smax, s); + M[j] = max(M[j], s); } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + smax = simd_max(smax); + M[j] = simd_max(M[j]); - S[j] = S[j]*ms; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } + // local sum + half ls = 0.0h; + for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = s == -INFINITY ? 
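// note: the Metal restructuring above trades one simd reduction per
// element for one reduction per pass: first reduce the row maximum, then
// accumulate a per-lane partial sum and reduce it once; the shape of the
// change, as a sketch using the kernel's own names:
//
//   for (p = tiisg; p < C; p += NW) M[j] = max(M[j], ss[j*T + p]);
//   M[j] = simd_max(M[j]);                        // one reduction
//   half ls = 0.0h;
//   for (p = tiisg; p < C; p += NW) ls += exp(ss[j*T + p] - M[j]);
//   S[j] = S[j]*exp(m - M[j]) + simd_sum(ls);     // one reduction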
0.0h : exp(s - M[j]); - S[j] = S[j] + simd_sum(vs); + ls += vs; // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + S[j] = S[j]*ms + simd_sum(ls); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b1b30b91c9c6b..2ab5354069853 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,9 +572,18 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; +#if 1 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } +#else + int n_nodes = gf->n_nodes; + for (int i = 1; i < n_runs; i++) { + for (int j = 0; j < n_nodes; j++) { + gf->nodes[gf->n_nodes++] = gf->nodes[j]; + } + } +#endif // calculate memory size_t mem = n_runs * op_size(out); From cda5a60a41c669f233a943e1182cd6625f61a924 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 20:53:29 +0200 Subject: [PATCH 036/121] metal : optimize softmax --- ggml-metal.m | 5 +++-- ggml-metal.metal | 34 +++++++++++++++++++--------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e00069624551f..2bbb6d17a36ad 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2285,8 +2285,9 @@ static bool ggml_metal_graph_compute( const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index 3d5d762d16d99..d9a536ae87109 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2188,6 +2188,8 @@ kernel void kernel_flash_attn_ext_f16( // online softmax if (C == 32) { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; @@ -2197,20 +2199,22 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - - S[j] = S[j]*ms + simd_sum(vs); + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } } else { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const half m = M[j]; @@ -2224,12 +2228,7 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(smax); M[j] = simd_max(M[j]); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; - } + ms[j] = m == -INFINITY ? 
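// note: in the rewrite above each thread keeps its per-row scale factors
// in a local ms[Q] array, so the QxQ diagonal used for rescaling the
// output is written in one divergence-free step,
//
//   if (tiisg < Q) {
//       ss[tiisg*T + C + tiisg] = ms[tiisg]; // all Q diagonal entries at once
//   }
//
// replacing the earlier guarded store `if (tiisg == j) ss[j*T + C + j] = ms;`
// that executed once per iteration of the j loop.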
0.0h : exp(m - M[j]); // local sum half ls = 0.0h; @@ -2245,7 +2244,12 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } - S[j] = S[j]*ms + simd_sum(ls); + S[j] = S[j]*ms[j] + simd_sum(ls); + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; } } From c6769b942229a9e634965e6215651b8d4cf02403 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 21:24:26 +0200 Subject: [PATCH 037/121] tests : minor fix --- tests/test-backend-ops.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2ab5354069853..727f2ea06a5d7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -578,6 +578,7 @@ struct test_case { } #else int n_nodes = gf->n_nodes; + n_runs = 1000; for (int i = 1; i < n_runs; i++) { for (int j = 0; j < n_nodes; j++) { gf->nodes[gf->n_nodes++] = gf->nodes[j]; From db1f3c482e256398330d44ad22b498ca2cd03161 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 22:08:37 +0200 Subject: [PATCH 038/121] cuda : avoid zeroing fragments --- ggml-cuda.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 098b55e073c12..7130209e702d2 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6443,11 +6443,11 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; + const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; - const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) @@ -6665,8 +6665,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::fill_fragment(lo[j][i], 0.0); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } // restore zeros @@ -6760,9 +6759,8 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(t2, 0.0); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, t2); + nvcuda::wmma::mma_sync(t2, ms1, t, zr); // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); From 12eaa22628740e388789081ccba93159c1b0b412 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Feb 2024 11:55:38 +0200 Subject: [PATCH 039/121] tests : update dims --- ggml-cuda.cu | 180 ++++++++++++++++++++++--------------- tests/test-backend-ops.cpp | 6 +- 2 files changed, 110 insertions(+), 76 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7130209e702d2..2c050c0c44edc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16( for (int64_t j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; - if(mp) { + + if (mp) { nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); } @@ -10927,78 +10928,111 @@ inline void 
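// note: nvcuda::wmma::mma_sync(d, a, b, c) computes d = a*b + c on 16x16
// tiles, so the persistent zero fragment zr can serve as the c operand;
// the "avoid zeroing fragments" change above replaces the per-tile pattern
//
//   nvcuda::wmma::fill_fragment(acc, 0.0);
//   nvcuda::wmma::mma_sync(acc, mm, lob, acc);
//
// with the single call
//
//   nvcuda::wmma::mma_sync(acc, mm, lob, zr);
//
// saving one fragment fill per accumulator tile.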
ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); - switch (Q->ne[0]) - { - case 16: - flash_attn_ext_f16<16, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - default: - break; + switch (Q->ne[0]) { + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 96: + flash_attn_ext_f16<96, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 112: + flash_attn_ext_f16<112, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_ext_f16<256, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 727f2ea06a5d7..9feb5e1fe550e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,7 +572,7 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#if 1 +#if 0 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } @@ -2209,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); -#if 0 - for (int hs : { 64, 80, 96, 112, 128, 256, }) { +#if 1 + for (int hs : { 64, 80, 128, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From b68a112204c58e2bed334273753211c15acc48e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Feb 2024 15:12:28 +0200 Subject: [PATCH 040/121] cuda : fix __hisinf() result check --- ggml-cuda.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c050c0c44edc..0136fbf28f2a5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(__hmax(smax, s)); M[j] = warp_reduce_max(__hmax(M[j], s)); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); S[j] = S[j]*ms + warp_reduce_sum(vs); @@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (lane_id == j) { @@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16( for (int64_t p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); ls += vs; @@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16( } // skip -INF blocks - if (__hisinf(smax)) { + if (__hisinf(smax) == -1) { continue; } @@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) == -1 ? 
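// note: the fix above relies on __hisinf() returning an int, not a bool:
// -1 for negative infinity, +1 for positive infinity, 0 otherwise; so
//
//   if (__hisinf(x)) { ... }        // fires for BOTH infinities
//   if (__hisinf(x) == -1) { ... }  // fires only for -INFINITY
//
// comparing against -1 makes the intended negative-infinity test explicit
// instead of also matching +INF.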
__float2half(0.0f) : hexp(M1 - M); S = S0*ms0 + S1*ms1; From b150abe83e6f0f8a0cf552d7fc0d8fe9f0f78569 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 13:17:47 +0200 Subject: [PATCH 041/121] cuda : avoid warp_reduce for smax --- ggml-cuda.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0136fbf28f2a5..c3f24242b350a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16( M[j] = __hmax(M[j], s); } - smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); @@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16( } } + smax = warp_reduce_max(smax); + // skip -INF blocks if (__hisinf(smax) == -1) { continue; From 7c34655b366e14d43f7fc9fa104a9ca7b8f60580 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 13:39:46 +0200 Subject: [PATCH 042/121] cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) --- ggml-cuda.cu | 70 ++++++++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c3f24242b350a..558ffb8ac7b56 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int64_t j = warp_id; j < Q; j += num_warps) { + for (int j = warp_id; j < Q; j += num_warps) { const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = lane_id; i < D2; i += NW) { + for (int i = lane_id; i < D2; i += NW) { if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(zr, 0.0); // zero out lo - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = lane_id; i < SH; i += NW) { + for (int j = 0; j < Q; ++j) { + for (int i = lane_id; i < SH; i += NW) { ss[j*T + i] = 0.0; } } @@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16( // load the queries from shared memory into local memory half16x16_a mq[Q16][D16]; - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); } } @@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); } const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + 
mask - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; @@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16( // online softmax if (C == 32) { - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const int p = lane_id; const half m = M[j]; const half s = ss[j*T + p]; @@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16( ss[j*T + p] = vs; } } else { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mm; half16x16_b lob; nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); @@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16( const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mk[D16]; - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); } half16x16_a mv[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); } - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); } } @@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { if (lane_id == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -6708,7 +6708,7 @@ static __global__ void flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < num_warps; ++sg) { + for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); half M = __float2half(-INFINITY); @@ -6716,8 +6716,8 @@ static __global__ void flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (warp_id == sg) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 
0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a ms0; half16x16_a ms1; half16x16_b t; @@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, zr); @@ -6776,8 +6776,8 @@ static __global__ void flash_attn_ext_f16( // store result to shared memory (reuse sq) if (warp_id == 0) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6785,10 +6785,10 @@ static __global__ void flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (warp_id == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = lane_id; i < D; i += NW) { + for (int i = lane_id; i < D; i += NW) { dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } From 1f8a5924823aecaa6ab1d5c2ac70ddde1d6c27d0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:01:32 +0200 Subject: [PATCH 043/121] cuda : make loops use the same loop values Thanks Johannes again for the tip --- ggml-cuda.cu | 43 +++++++++++++++++++++++++++++++------- tests/test-backend-ops.cpp | 2 +- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 558ffb8ac7b56..a3a6c6455017b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int j = warp_id; j < Q; j += num_warps) { + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int i = lane_id; i < D2; i += NW) { + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16( // zero out shared memory SH for (int j = 0; j < Q; ++j) { - for (int i = lane_id; i < SH; i += NW) { + for (int i0 = 0; i0 < SH; i0 += NW) { + const int i = i0 + lane_id; + if (i >= SH) { + break; + } + ss[j*T + i] = 0.0; } } @@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { + const int ic = ic0 + warp_id*C; + if (ic >= ne11) { + break; + } + // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { @@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = 
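// note: the pattern applied above rewrites strided lane loops from
//
//   for (int p = lane_id; p < C; p += NW) { ... }
//
// into the canonical form
//
//   for (int p0 = 0; p0 < C; p0 += NW) {
//       const int p = p0 + lane_id;
//       ...
//   }
//
// so every lane sees the same loop bounds (no dependence on lane_id),
// which together with plain 32-bit int counters keeps the index math
// cheap and makes the loops straightforward for nvcc to unroll.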
ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int i = lane_id; i < D; i += NW) { + for (int i0 = 0; i0 < D; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D) { + break; + } + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9feb5e1fe550e..e4076b49c180d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 64, 80, 128, }) { + for (int hs : { 128, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From 92472ea22ca3eed7f65114b0e6b7de1585930759 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:10:01 +0200 Subject: [PATCH 044/121] cuda : unroll some of the loops --- ggml-cuda.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a3a6c6455017b..deda4cc706fdc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,6 +6462,7 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory +#pragma unroll for (int j0 = 0; j0 < Q; j0 += num_warps) { const int j = j0 + warp_id; if (j >= Q) { @@ -6470,6 +6471,7 @@ static __global__ void flash_attn_ext_f16( const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); +#pragma unroll for (int i0 = 0; i0 < D2; i0 += NW) { const int i = i0 + lane_id; if (i >= D2) { From c51f27c0dbd70fe8eda6182d61371d6a2dea6fb9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:27:36 +0200 Subject: [PATCH 045/121] cuda : avoid __hisinf branches --- ggml-cuda.cu | 83 ++++++++++++++++++---------------------------------- 1 file changed, 29 insertions(+), 54 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index deda4cc706fdc..4d1fb008c3994 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16( half S[Q]; half M[Q]; - for(int i = 0; i < Q; i++) { + for (int i = 0; i < Q; ++i) { S[i] = __float2half(0.0f); - M[i] = __float2half(-INFINITY); + M[i] = CUDART_MIN_DENORM_FP16; } // assume K and V are same shape @@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16( half smax = __float2half(-INFINITY); // online softmax - if (C == 32) { - for (int j = 0; j < Q; ++j) { - const int p = lane_id; - - const half m = M[j]; - const half s = ss[j*T + p]; - - smax = warp_reduce_max(__hmax(smax, s)); - M[j] = warp_reduce_max(__hmax(M[j], s)); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) == -1 ? 
__float2half(0.0f) : hexp(s - M[j]); + for (int j = 0; j < Q; ++j) { + const half m = M[j]; - S[j] = S[j]*ms + warp_reduce_sum(vs); + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + const half s = ss[j*T + p]; - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); } - } else { - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; - const half s = ss[j*T + p]; + M[j] = warp_reduce_max(M[j]); - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); - } - - M[j] = warp_reduce_max(M[j]); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + const half ms = hexp(m - M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - - // local sum - half ls = 0.0f; + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; + // local sum + half ls = 0.0f; - const half s = ss[j*T + p]; + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); + const half s = ss[j*T + p]; - ls += vs; + const half vs = hexp(s - M[j]); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } + ls += vs; - S[j] = S[j]*ms + warp_reduce_sum(ls); + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; } + + S[j] = S[j]*ms + warp_reduce_sum(ls); } smax = warp_reduce_max(smax); @@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); - half M = __float2half(-INFINITY); + half M = CUDART_MIN_DENORM_FP16; __syncthreads(); @@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) == -1 ? 
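// note: initializing the running maxima to CUDART_MIN_DENORM_FP16 (the
// smallest positive fp16 denormal) instead of -INFINITY means m and M[j]
// can never be -inf, so hexp(m - M[j]) is always well defined and the
// per-element __hisinf() guards drop out, e.g.
//
//   const half ms = hexp(m - M[j]); // no __hisinf(m) == -1 branch needed
//
// when every score in a row is negative the shift saturates near zero,
// which only narrows the numerical range, not the result.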
__float2half(0.0f) : hexp(M1 - M); + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); S = S0*ms0 + S1*ms1; From b958151e3f66e17a9bc5131e446a50c5529b4b81 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:00:25 +0200 Subject: [PATCH 046/121] cuda : use half2 in softmax --- ggml-cuda.cu | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4d1fb008c3994..1fed9d23e2f47 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16( const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) + const int C2 = C/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 half16x16_acc zr; half16x16_acc lo[Q16][D16]; @@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16( } // used to detect blocks full of -INF - half smax = __float2half(-INFINITY); + half2 smax = make_half2(-INFINITY, -INFINITY); // online softmax for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); + smax = __hmax2(smax, s); + M[j] = __hmax(M[j], __hmax(s.x, s.y)); } M[j] = warp_reduce_max(M[j]); @@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16( } // local sum - half ls = 0.0f; + half2 ls = make_half2(0.0f, 0.0f); + half2 M2 = make_half2(M[j], M[j]); - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - const half vs = hexp(s - M[j]); + const half2 vs = h2exp(s - M2); ls += vs; // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss2[j*T2 + p] = vs; } - S[j] = S[j]*ms + warp_reduce_sum(ls); + ls = warp_reduce_sum(ls); + + S[j] = S[j]*ms + ls.x + ls.y; } smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax) == -1) { + if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { continue; } From a7b471569bdf4e09e97b2d02c27989b8cb801861 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:17:49 +0200 Subject: [PATCH 047/121] cuda : switch to 1 warp for bs > 16 --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1fed9d23e2f47..c98b551b31e8e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10933,7 +10933,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + const int nwarps = Q->ne[1] <= nqpb ? 
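// note: the half2 variants above (__hmax2, h2exp) process two packed fp16
// values per instruction, halving the softmax trip count from C to
// C2 = C/2; the accumulation shape, as a sketch:
//
//   half2 ls = make_half2(0.0f, 0.0f);
//   const half2 M2 = make_half2(M[j], M[j]);
//   for (int p0 = 0; p0 < C2; p0 += NW) {
//       ls += h2exp(ss2[j*T2 + p0 + lane_id] - M2); // two exponentials per call
//   }
//   ls = warp_reduce_sum(ls);
//   S[j] = S[j]*ms + ls.x + ls.y; // fold the packed pair at the end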
std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); From 3b1c4e76739031bee3028748e0cd288c148f77b4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:36:05 +0200 Subject: [PATCH 048/121] cuda : speed-up reduce part of the kernel --- ggml-cuda.cu | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c98b551b31e8e..67541a61ef716 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6715,9 +6715,6 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { - half S = __float2half(0.0f); - half M = CUDART_MIN_DENORM_FP16; - __syncthreads(); // each simdgroup stores its output to shared memory, reusing sq @@ -6733,27 +6730,25 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int j = 0; j < Q; ++j) { + for (int j = lane_id; j < Q; j += NW) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; const half M0 = ss[j*T + 1]; const half M1 = ss[j*T + sg*SH + 1]; - M = __hmax(M0, M1); + const half M = __hmax(M0, M1); const half ms0 = hexp(M0 - M); const half ms1 = hexp(M1 - M); - S = S0*ms0 + S1*ms1; + const half S = S0*ms0 + S1*ms1; - if (lane_id == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*T + 0] = S; + ss[j*T + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; - } + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 @@ -10931,6 +10926,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nqpb = NQPB; // queries per block const int ncpw = NCPW; // cache values per warp (does not work for other values) + GGML_ASSERT(NQPB <= 32); + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why const int nwarps = Q->ne[1] <= nqpb ? 
std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; From 5b263dd83a5f906eddd10bc044051d7571097043 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 16:12:20 +0200 Subject: [PATCH 049/121] cuda : unroll Q*K^T loop --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 67541a61ef716..dbd4822396f4e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6571,6 +6571,7 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { +#pragma unroll for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; for (int j = 0; j < Q16; ++j) { From e04ff391819e1875beed3e30d9e7592db45e0e62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 16:57:46 +0200 Subject: [PATCH 050/121] cuda : fix -INF block check --- ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dbd4822396f4e..e51ddc08f764f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { + if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { continue; } @@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } + } - // restore zeros + // restore zeros + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); } From cfd9732b2e45a442f4f7261ac0b50ec6e0862ab2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:31:55 +0200 Subject: [PATCH 051/121] cuda : simplify softmax --- ggml-cuda.cu | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e51ddc08f764f..25f810cbea86f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6512,11 +6512,10 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); { - half S[Q]; + half S = __float2half(0.0f); half M[Q]; for (int i = 0; i < Q; ++i) { - S[i] = __float2half(0.0f); M[i] = CUDART_MIN_DENORM_FP16; } @@ -6626,13 +6625,6 @@ static __global__ void flash_attn_ext_f16( M[j] = warp_reduce_max(M[j]); - const half ms = hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - // local sum half2 ls = make_half2(0.0f, 0.0f); half2 M2 = make_half2(M[j], M[j]); @@ -6652,7 +6644,14 @@ static __global__ void flash_attn_ext_f16( ls = warp_reduce_sum(ls); - S[j] = S[j]*ms + ls.x + ls.y; + const half ms = hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + + S = S*ms + ls.x + ls.y; + } } smax = warp_reduce_max(smax); @@ -6709,8 +6708,8 @@ static __global__ void flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int j = 0; j < Q; ++j) { - if (lane_id == 0) { - ss[j*T + 0] = S[j]; + if (lane_id == j) { + ss[j*T + 0] = S; ss[j*T + 1] = M[j]; } } From ef68fac2a8b51e2237234e3d7c6120cade457ce8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:36:58 +0200 Subject: [PATCH 052/121] cuda : fix matrix names --- ggml-cuda.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 25f810cbea86f..d9ab2bd093feb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6687,19 +6687,19 @@ static __global__ void flash_attn_ext_f16( for (int cc = 0; cc < 
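// note: with smax kept as a packed half2, "block is all -INF" holds only
// when BOTH halves are -inf, hence the `&&` in the fix above:
//
//   if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { continue; }
//
// the earlier `||` skipped a block as soon as one packed column hit -inf,
// discarding any finite scores still present in the other column.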
C/16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - half16x16_b mk[D16]; + half16x16_b mv[D16]; for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); } - half16x16_a mv[Q16]; + half16x16_a ms[Q16]; for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); } for (int j = 0; j < Q16; ++j) { for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); } } } From 1846e92a904ef17a55a3f7e7c2e837f35db2ce7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Feb 2024 09:57:58 +0200 Subject: [PATCH 053/121] cuda : minor --- ggml-cuda.cu | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d9ab2bd093feb..713a6a89acfdc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6399,10 +6399,10 @@ static __global__ void flash_attn_f32( } #if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; -typedef nvcuda::wmma::fragment half16x16_acc; +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; #endif // based on metal version @@ -6443,15 +6443,17 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; + const int C16 = C/16; + const int NW = WARP_SIZE; const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) const int C2 = C/2; + const int D2 = D/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq @@ -6571,7 +6573,7 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { #pragma unroll - for (int cc = 0; cc < C/16; ++cc) { + for (int cc = 0; cc < C16; ++cc) { half16x16_acc mqk[Q16]; for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); @@ -6684,7 +6686,7 @@ static __global__ void flash_attn_ext_f16( // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C/16; ++cc) { + for (int cc = 0; cc < C16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mv[D16]; @@ -6707,11 +6709,9 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int j = 0; j < Q; ++j) { - if (lane_id == j) { - ss[j*T + 0] = S; - ss[j*T + 1] = M[j]; - } + if (lane_id < Q) { + ss[lane_id*T + 0] = S; + ss[lane_id*T + 1] = M[lane_id]; } } @@ -10939,6 +10939,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + // increase shared memory limit to 96KB + //const size_t shmem_max = 96*1024; + //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + switch (Q->ne[0]) { case 64: flash_attn_ext_f16<64, NQPB, NCPW> @@ -11045,6 +11049,8 @@ inline void 
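// note: a <<<...>>> launch does not return an error directly; an invalid
// configuration (e.g. requesting more dynamic shared memory than allowed)
// surfaces through cudaGetLastError(), hence the CUDA_CHECK added after
// the switch above. If larger head sizes ever push shmem past the default
// 48 KiB, the opt-in sketched in the commented-out lines would be:
//
//   cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>,
//                        cudaFuncAttributeMaxDynamicSharedMemorySize,
//                        96*1024);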
ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * default: break; } + + CUDA_CHECK(cudaGetLastError()); } static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From f249c997a8b3f9b129fe825bebd609a362e9ab9c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 19 Feb 2024 13:10:24 +0200 Subject: [PATCH 054/121] llama : adapt to F16 KQ_pos --- ggml-cuda.cu | 2 +- ggml.c | 2 +- llama.cpp | 15 ++++++++++----- tests/test-backend-ops.cpp | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c8af51a66d59..5c6159a83f3cd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6232,7 +6232,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? __half2float(slope*pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); diff --git a/ggml.c b/ggml.c index efc570db698f2..9a2ae62647364 100644 --- a/ggml.c +++ b/ggml.c @@ -5192,7 +5192,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { diff --git a/llama.cpp b/llama.cpp index 2359ed10aebf0..5aa3a508d7080 100644 --- a/llama.cpp +++ b/llama.cpp @@ -102,7 +102,7 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 -#define LLAMA_FLASH_ATTN +//#define LLAMA_FLASH_ATTN // // logging @@ -4831,6 +4831,11 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * cur; #if defined(LLAMA_FLASH_ATTN) + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); + + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -5260,7 +5265,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); // shift the entire K-cache if needed @@ -5804,7 +5809,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); for (int il = 0; il < n_layer; ++il) { @@ -6043,7 +6048,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -6140,7 +6145,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), 
GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); for (int il = 0; il < n_layer; ++il) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 912223def6e06..278c57299ce88 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1505,7 +1505,7 @@ struct test_attn : public test_case { struct ggml_tensor * cur; cur = ggml_mul_mat (ctx, k, q); - cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f); cur = ggml_mul_mat (ctx, v, cur); cur = ggml_permute (ctx, cur, 0, 2, 1, 3); cur = ggml_cont_2d (ctx, cur, hs*nh, nb); From 6aefd11204199c9bd520b8991bab4085cb6fc977 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 3 Mar 2024 13:50:54 +0200 Subject: [PATCH 055/121] llama : adapt new models to F16 KQ_mask --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1a099adcba5dc..f2b224cafb2e3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7362,7 +7362,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -7489,7 +7489,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -7724,7 +7724,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { From 58c7f6167c0f6540f0da0386fc65d940e1a16ea5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Mar 2024 20:44:57 +0200 Subject: [PATCH 056/121] ggml : fix F16 store (ARM NEON) --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 09ed9da342885..5715b78ec7bbf 100644 --- a/ggml.c +++ b/ggml.c @@ -874,7 +874,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -900,7 +900,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO 
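 // note on the (__fp16 *) casts in the STORE fixes above and below: these are
 // assumed to be needed because vst1q_f16 expects an __fp16 pointer, while the
 // ggml_fp16_t storage type is not __fp16 on every ARM toolchain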
#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL From 3a468e6f9f0c7dff9ed78b0f7a5af069da420606 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Mar 2024 17:12:17 +0200 Subject: [PATCH 057/121] llama : fix type of KQ_mask and KQ_pos --- llama.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 77da94960e65c..b80080daf7506 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5810,20 +5810,20 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; + return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16); } struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F16, n_kv); + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16); } struct ggml_tensor * build_inp_mean() { From 09532120e0afc5b0c451fbb79eac558e4561660b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Mar 2024 17:49:42 +0200 Subject: [PATCH 058/121] ggml : fix CPU soft_max --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 434fb76c3b1b6..7ea1abfcd67c6 100644 --- a/ggml.c +++ b/ggml.c @@ -12302,7 +12302,7 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); From e425810bb6c8e9dd8ef9ff6f606313b6cfa1b607 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Mar 2024 12:21:41 +0200 Subject: [PATCH 059/121] tests : add hs=256 --- ggml.c | 2 +- tests/test-backend-ops.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 7ea1abfcd67c6..336c3c0ffaade 100644 --- a/ggml.c +++ b/ggml.c @@ -12272,7 +12272,7 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? 
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b8994cecff4b2..50c4b27294fea 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2261,7 +2261,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_leaky_relu());
 
 #if 1
-    for (int hs : { 128, 64, 80, }) {
+    for (int hs : { 128, 256, 64, 80, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, 2048, 4096, }) {
                 for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {

From 6be02b5969ed04fb3ce336702ae3949076bb2bf4 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 27 Mar 2024 10:31:52 +0200
Subject: [PATCH 060/121] cuda : fix build

---
 ggml-cuda.cu        |  7 +------
 ggml-cuda/fattn.cu  | 48 +++++++++++++++++++++++++++++++++++++++------
 ggml-cuda/fattn.cuh |  5 +----
 3 files changed, 44 insertions(+), 16 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 963588c4f2fad..31bfc43c1febe 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2384,17 +2384,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_argsort(ctx, dst);
             break;
         case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
             break;
         default:
             return false;
     }
 
-    if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
-        ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
-    } else {
-        func(ctx, tensor->src[0], tensor->src[1], tensor);
-    }
-
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));

diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 0f135a184b572..bcf27fd794aaf 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -1,5 +1,33 @@
 #include "fattn.cuh"
 
+#include <mma.h>
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    GGML_UNUSED(a);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 #if __CUDA_ARCH__ >= CC_VOLTA
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>                          half16x16_acc;
@@ -10,11 +38,11 @@ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>
 // based on metal version
 template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
 static __global__ void flash_attn_ext_f16(
-    const char* __restrict__ q,
-    const char* __restrict__ k,
-    const char* __restrict__ v,
+    const char * __restrict__ q,
+    const char * __restrict__ k,
+    const char * __restrict__ v,
     const char * __restrict__ mask,
     float * __restrict__ dst,
@@ -408,7 +436,15 @@ static __global__ void flash_attn_ext_f16(
 #endif
 }
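 // a minimal sketch of the dispatch convention assumed by this change: the
 // wrapper below now receives only the destination tensor and unpacks its
 // sources itself, so the call site in ggml-cuda.cu (hunk above) reduces to:
 //
 //     case GGML_OP_FLASH_ATTN_EXT:
 //         ggml_cuda_flash_attn_ext(ctx, dst); // Q, K, V, mask are dst->src[0..3]
 //         break;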
-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16); diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh index 1b764bc946f8f..ad3ca7a8d8e4d 100644 --- a/ggml-cuda/fattn.cuh +++ b/ggml-cuda/fattn.cuh @@ -1,6 +1,3 @@ #include "common.cuh" -void ggml_cuda_flash_attn_ext( - ggml_backend_cuda_context & ctx, - const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, - const ggml_tensor * mask, ggml_tensor * KQV); +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 57c03b78b6bca2049e43905e808666b04304d0fd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Mar 2024 19:29:06 +0200 Subject: [PATCH 061/121] metal : improve perf via smaller int registers --- ggml-metal.metal | 149 ++++++++++++++++++++++++----------------------- 1 file changed, 77 insertions(+), 72 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index dc6ad31417c51..27eeb3932ff1a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2061,11 +2061,11 @@ typedef void (flash_attn_ext_f16_t)( constant int64_t & ne3, constant float & scale, threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); // ref: https://arxiv.org/pdf/2307.08691.pdf template // head size, queries per threadgroup, cache items per threadgroup @@ -2099,25 +2099,25 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne3, constant float & scale, threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups - - const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*Q; - - const int64_t D4 = D/4; - const int64_t D8 = D/8; - const int64_t Q8 = Q/8; - const int64_t NW = N_SIMDWIDTH; - const int64_t SH = (C + Q); // shared memory per simdgroup in (half) - - const int64_t T = D + nsg*SH; // shared memory size per query in (half) - const int64_t T4 = T/4; // shared memory size per query in (half4) + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + const short Q8 = Q/8; + const short NW = 
N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + nsg*SH; // shared memory size per query in (half) + const short T4 = T/4; // shared memory size per query in (half4) threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 @@ -2127,10 +2127,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 lo[Q8][D8]; // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { + for (short j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { sq4[j*T4 + i] = (half4) q4[i]; } else { @@ -2140,15 +2140,15 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { lo[j][i] = make_filled_simdgroup_matrix(0.0h); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = tiisg; i < SH; i += NW) { + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { ss[j*T + i] = 0.0h; } } @@ -2160,33 +2160,33 @@ kernel void kernel_flash_attn_ext_f16( half M[Q] = { [0 ... Q-1] = -INFINITY }; // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; + const short ne22 = ne12; + const short ne23 = ne13; - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; // v indices - const int64_t iv2 = iq2 / rv2; - const int64_t iv3 = iq3 / rv3; + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory simdgroup_half8x8 mq[Q8][D8]; - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); } } @@ -2199,28 +2199,33 @@ kernel void kernel_flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + // Q*K^T { - for (int cc = 0; cc < C/8; ++cc) { + for (short cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk[Q8]; - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { mqk[j] = make_filled_simdgroup_matrix(0.h); } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { 
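                        // mqk[j] += mq[j][i] * mk -- one 8x8 simdgroup tile of Q*K^T;
                        // mk was loaded with transpose=true above, so this multiplies by a K^T block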
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); @@ -2237,8 +2242,8 @@ kernel void kernel_flash_attn_ext_f16( if (C == 32) { half ms[Q]; - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; + for (short j = 0; j < Q; ++j) { + const short p = tiisg; const half m = M[j]; const half s = ss[j*T + p]; @@ -2262,10 +2267,10 @@ kernel void kernel_flash_attn_ext_f16( } else { half ms[Q]; - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = tiisg; p < C; p += NW) { + for (short p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; smax = max(smax, s); @@ -2280,7 +2285,7 @@ kernel void kernel_flash_attn_ext_f16( // local sum half ls = 0.0h; - for (int64_t p = tiisg; p < C; p += NW) { + for (short p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); @@ -2306,25 +2311,25 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_multiply(lo[j][i], mm, lo[j][i]); } } // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C/8; ++cc) { + for (short cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_half8x8 mk; simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mv; simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); @@ -2336,7 +2341,7 @@ kernel void kernel_flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { if (tiisg == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -2345,7 +2350,7 @@ kernel void kernel_flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < nsg; ++sg) { + for (short sg = 1; sg < nsg; ++sg) { half S = { 0.0h }; half M = { -INFINITY }; @@ -2353,8 +2358,8 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); } } @@ -2364,7 +2369,7 @@ kernel void kernel_flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -2388,7 +2393,7 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 t; simdgroup_half8x8 ms0; simdgroup_half8x8 ms1; @@ -2396,7 +2401,7 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); simdgroup_load(ms1, ss + 8*j*T 
+ C + 8*j + sg*SH, T, 0, false); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); @@ -2408,8 +2413,8 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); } } @@ -2419,10 +2424,10 @@ kernel void kernel_flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < D4; i += NW) { dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } From 08e69c50081622d8146e749195939157be2a2207 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Mar 2024 19:40:11 +0200 Subject: [PATCH 062/121] cuda : adapt soft_max to F16 mask and pos --- ggml-cuda/softmax.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 9bda18e581c75..8f6dca4d0f9bf 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -1,7 +1,7 @@ #include "softmax.cuh" template -static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -43,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -114,7 +114,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f } } -static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -168,14 +168,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const float * src0_d = (const float *)src0->data; - const float * src1_d = src1 ? (const float *)src1->data : nullptr; + const half * src1_d = src1 ? 
(const half *)src1->data : nullptr;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
@@ -188,13 +188,13 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // positions tensor
-    float * src2_dd = nullptr;
+    half * src2_dd = nullptr;
 
     ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
-        src2_dd = (float *)src2->data;
+        src2_dd = (half *)src2->data;
     }
 
     soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);

From 75aa7b4b189a5a2f6518840e84ec489da71e0443 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 29 Mar 2024 23:02:39 +0100
Subject: [PATCH 063/121] CUDA: faster FlashAttention, kernel for bs == 1

---
 ggml-cuda/fattn.cu | 1357 +++++++++++++++++++++++++++++---------------
 1 file changed, 906 insertions(+), 451 deletions(-)

diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index bcf27fd794aaf..ccb3c924609a4 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -28,414 +28,416 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }
 
-#if __CUDA_ARCH__ >= CC_VOLTA
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>                          half16x16_acc;
-#endif
-
-// based on metal version
-template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
-static __global__ void flash_attn_ext_f16(
-    const char * __restrict__ q,
-    const char * __restrict__ k,
-    const char * __restrict__ v,
+template <int D> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_vec_ext_f16(
+    const char * __restrict__ Q,
+    const char * __restrict__ K,
+    const char * __restrict__ V,
     const char * __restrict__ mask,
     float * __restrict__ dst,
-    float scale,
-    int ne00,
-    int ne01,
-    int ne02,
-    int ne03,
-    int ne10,
-    int ne11,
-    int ne12,
-    int ne13,
-    int ne31,
-    int nb31,
-    int nb01,
-    int nb02,
-    int nb03,
-    int nb11,
-    int nb12,
-    int nb13,
-    int ne0,
-    int ne1,
-    int ne2,
-    int ne3) {
-#if __CUDA_ARCH__ >= CC_VOLTA
-    const int warp_id = threadIdx.y;
-    const int lane_id = threadIdx.x;
-
-    const int num_warps = blockDim.y; // number of warps
-    const int iq3 = blockIdx.z;
-    const int iq2 = blockIdx.y;
-    const int iq1 = blockIdx.x * Q;
-
-    const int D16 = D/16;
-    const int Q16 = Q/16;
-    const int C16 = C/16;
-
-    const int NW = WARP_SIZE;
-    const int SH = (C + Q); // shared memory per simdgroup in (half)
-
-    const int T  = D + num_warps*SH; // shared memory size per query in (half)
-    const int T2 = T/2;              // shared memory size per query in (half2)
-    const int C2 = C/2;
-    const int D2 = D/2;
-
-    extern __shared__ half __flash_attn_f16_shmem[];
-    // pq
-    half  * sq  = (half  *) (__flash_attn_f16_shmem +  0*D); // holds the query data
-    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +  0*D); // same as above but in half2
-    half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // 
scratch buffer for attention and diagonal matrix - half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 - - half16x16_acc zr; - half16x16_acc lo[Q16][D16]; - - // load heads from Q to shared memory + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*blockIdx.x); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + ne31*blockIdx.x; + + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); + + constexpr int nwarps = D/WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half KQ[D]; + KQ[tid] = 0.0f; + half2 * KQ2 = (half2 *) KQ; + + half kqmax = -INFINITY; + half kqsum = 0.0f; + + __shared__ half kqmax_shared[WARP_SIZE]; + __shared__ half kqsum_shared[WARP_SIZE]; + if (threadIdx.y == 0) { + kqmax_shared[threadIdx.x] = -INFINITY; + kqsum_shared[threadIdx.x] = 0.0f; + } + + __syncthreads(); + + // Convert Q to half2 and store in registers: + half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE]; #pragma unroll - for (int j0 = 0; j0 < Q; j0 += num_warps) { - const int j = j0 + warp_id; - if (j >= Q) { + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { break; } - const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y); + } + + half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + // Calculate KQ tile and keep track of new maximum KQ values: + half kqmax_new = kqmax; #pragma unroll - for (int i0 = 0; i0 < D2; i0 += NW) { - const int i = i0 + lane_id; - if (i >= D2) { + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) { break; } - if (iq1 + j < ne01) { - sq2[j*T2 + i] = __float22half2_rn(q2[i]); - } else { - sq2[j*T2 + i] = make_half2(0.0, 0.0); + half2 sum2 = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) { + break; + } + + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; } - } - } - nvcuda::wmma::fill_fragment(zr, 0.0); + sum2 = warp_reduce_sum(sum2); + half sum = __low2half(sum2) + __high2half(sum2); + sum += mask ? 
maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            kqmax_new = __hmax(kqmax_new, sum);
+            if (threadIdx.x == 0) {
+                KQ[i_KQ] = sum;
+            }
+        }
 
-    // zero out lo
-    for (int j = 0; j < Q16; ++j) {
-        for (int i = 0; i < D16; ++i) {
-            nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+        kqmax_new = warp_reduce_max(kqmax_new);
+        if (threadIdx.x == 0) {
+            kqmax_shared[threadIdx.y] = kqmax_new;
         }
-    }
+        __syncthreads();
+        kqmax_new = kqmax_shared[threadIdx.x];
+        kqmax_new = warp_reduce_max(kqmax_new);
+
+        const half KQ_max_scale = hexp(kqmax - kqmax_new);
+        kqmax = kqmax_new;
+
+        const half val = hexp(KQ[tid] - kqmax);
+        kqsum = kqsum*KQ_max_scale + val;
+        KQ[tid] = val;
+
+        VKQ *= __half2half2(KQ_max_scale);
+
+        __syncthreads();
 
-    // zero out shared memory SH
-    for (int j = 0; j < Q; ++j) {
-        for (int i0 = 0; i0 < SH; i0 += NW) {
-            const int i = i0 + lane_id;
-            if (i >= SH) {
+#pragma unroll
+        for (int k0 = 0; k0 < D; k0 += 2) {
+            if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
                 break;
             }
-            ss[j*T + i] = 0.0;
+            half2 V_k;
+            reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
+            reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
+            VKQ += V_k*KQ2[k0/2];
         }
     }
 
+    kqsum = warp_reduce_sum(kqsum);
+    if (threadIdx.x == 0) {
+        kqsum_shared[threadIdx.y] = kqsum;
+    }
     __syncthreads();
+    kqsum = kqsum_shared[threadIdx.x];
+    kqsum = warp_reduce_sum(kqsum);
 
-    {
-        half S = __float2half(0.0f);
-        half M[Q];
+    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum;
+}
 
-        for (int i = 0; i < Q; ++i) {
-            M[i] = CUDART_MIN_DENORM_FP16;
+template <int D, int ncols> // D == head size
+__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1)
+static __global__ void flash_attn_ext_f16(
+    const char * __restrict__ Q,
+    const char * __restrict__ K,
+    const char * __restrict__ V,
+    const char * __restrict__ mask,
+    float * __restrict__ dst,
+    const float scale,
+    const int ne00,
+    const int ne01,
+    const int ne02,
+    const int ne03,
+    const int ne10,
+    const int ne11,
+    const int ne12,
+    const int ne13,
+    const int ne31,
+    const int nb31,
+    const int nb01,
+    const int nb02,
+    const int nb03,
+    const int nb11,
+    const int nb12,
+    const int nb13,
+    const int ne0,
+    const int ne1,
+    const int ne2,
+    const int ne3) {
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c;
+
+    constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m;
+    constexpr int nthreads = nwarps*WARP_SIZE;
+    static_assert(nthreads % D == 0, "nthreads not divisible by D.");
+    constexpr int tc_vals_per_iter = nwarps*frag_m;
+    static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter.");
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < nthreads);
+    constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
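+    // e.g. ne02 == 32 query heads over ne12 == 8 KV heads gives gqa_ratio == 4:
+    // query heads 0..3 (blockIdx.y) all read the K/V data of KV head 0 below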
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2; + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + frag_b Q_b[D/16][ncols/frag_n]; + + __shared__ half KQ[ncols*D_padded]; // Buffer for temporarily holding tiles of KQ. + half2 * KQ2 = (half2 *) KQ; + + half2 KQ_rowsum[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max[(ncols + nwarps - 1) / nwarps] = {{-INFINITY, -INFINITY}}; + half2 KQ_max_scale[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int i0 = 0; i0 < ncols*D_padded/2; i0 += nthreads) { + const int i = i0 + tid; + if (i0 + nthreads > ncols*D_padded/2 && i >= ncols*D_padded/2) { + break; } - // assume K and V are same shape - const int ne22 = ne12; - const int ne23 = ne13; - - const int nb21 = nb11; - const int nb22 = nb12; - const int nb23 = nb13; - - // broadcast - const int rk2 = ne02/ne12; - const int rk3 = ne03/ne13; - - const int rv2 = ne02/ne22; - const int rv3 = ne03/ne23; + VKQ2[i] = make_half2(0.0f, 0.0f); + } - // k indices - const int ik2 = iq2 / rk2; - const int ik3 = iq3 / rk3; + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nthreads/D) { + const int j = j0 + tid/D; + const int i = tid % D; + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + __syncthreads(); - // load the queries from shared memory into local memory - half16x16_a mq[Q16][D16]; - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); - } + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); } + } - // pointer to the mask - const half * mp = mask ? 
(const half *) (mask + iq1*nb31) : nullptr; + __syncthreads(); - // prepare diagonal scale matrix - half16x16_b mscale; - for (int i = 0; i < 16; ++i) { - ss[i*T + i] = __float2half(scale); - } - nvcuda::wmma::load_matrix_sync(mscale, ss, T); + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11; - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { - const int ic = ic0 + warp_id*C; - if (ic >= ne11) { - break; + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + frag_c KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } - - // Q*K^T - { + if (has_valid_data) { #pragma unroll - for (int cc = 0; cc < C16; ++cc) { - half16x16_acc mqk[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::fill_fragment(mqk[j], 0); - } - - const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - - for (int i = 0; i < D16; ++i) { - half16x16_bT mk; // transposed key - nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); - } - } - - // mqk = mqk*scale + mask - for (int j = 0; j < Q16; ++j) { - half16x16_a mqka; - half16x16_acc mm; - - if (mp) { - nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); - } - - // convert accumulator to matrix_a - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); - - nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); + } + } - // used to detect blocks full of -INF - half2 smax = make_half2(-INFINITY, -INFINITY); - - // online softmax - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - - for (int p0 = 0; p0 < C2; p0 += NW) { - const int p = p0 + lane_id; + __syncthreads(); - const half2 s = ss2[j*T2 + p]; + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. 
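+    // online softmax, per column j: keep a running maximum m_j and row sum s_j;
+    // when a new tile raises the maximum to m_j', rescale the old state by
+    // exp(m_j - m_j'), i.e. s_j' = exp(m_j - m_j')*s_j + sum_k exp(KQ_jk - m_j')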
+#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } - smax = __hmax2(smax, s); - M[j] = __hmax(M[j], __hmax(s.x, s.y)); + half2 KQ_max_new = KQ_max[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + if (k0 + WARP_SIZE > D/2 && k >= D/2) { + break; } + KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(D_padded/2) + k]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + KQ_max[j0/nwarps] = KQ_max_new; - M[j] = warp_reduce_max(M[j]); - - // local sum - half2 ls = make_half2(0.0f, 0.0f); - half2 M2 = make_half2(M[j], M[j]); - - for (int p0 = 0; p0 < C2; p0 += NW) { - const int p = p0 + lane_id; - - const half2 s = ss2[j*T2 + p]; - - const half2 vs = h2exp(s - M2); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss2[j*T2 + p] = vs; + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + if (k0 + WARP_SIZE > D/2 && k >= D/2) { + break; } - - ls = warp_reduce_sum(ls); - - const half ms = hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - - S = S*ms + ls.x + ls.y; + if (256 % D != 0 && k_VKQ_0 + 2*k >= ne11) { + break; } - } - smax = warp_reduce_max(smax); - - // skip -INF blocks - if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { - continue; + half2 val = KQ2[j*(D_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + val = h2exp(val - KQ_max[j0/nwarps]); + KQ_rowsum_add += val; + KQ2[j*(D_padded/2) + k] = val; } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - // O = diag(ms)*O - for (int j = 0; j < Q16; ++j) { - half16x16_a mm; - half16x16_b lob; - - nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add; + } - for (int i = 0; i < D16; ++i) { - // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + __syncthreads(); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); - } + frag_b KQ_b[D/16][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 16) { + nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*D_padded + k0, D_padded); } + } - // restore zeros - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n]; +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + #pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f); } - // O = O + (Q*K^T)*V - { - for (int cc = 0; cc < C16; ++cc) { - const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - - half16x16_b mv[D16]; - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); - } - - half16x16_a ms[Q16]; - for (int j = 0; j < Q16; ++j) { - 
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); - } + #pragma unroll + for (int k0 = 0; k0 < D; k0 += 16) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); - } - } + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV); + #pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]); } } } - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - if (lane_id < Q) { - ss[lane_id*T + 0] = S; - ss[lane_id*T + 1] = M[lane_id]; - } - } - - // reduce the warps sequentially - for (int sg = 1; sg < num_warps; ++sg) { __syncthreads(); - // each simdgroup stores its output to shared memory, reusing sq - if (warp_id == sg) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - } +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, + VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); } } __syncthreads(); - // the first simdgroup accumulates the results from the other simdgroups - if (warp_id == 0) { - for (int j = lane_id; j < Q; j += NW) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; - - const half M = __hmax(M0, M1); - - const half ms0 = hexp(M0 - M); - const half ms1 = hexp(M1 - M); - - const half S = S0*ms0 + S1*ms1; - - ss[j*T + 0] = S; - ss[j*T + 1] = M; - - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int j = 0; j < Q16; ++j) { - half16x16_a ms0; - half16x16_a ms1; - half16x16_b t; - half16x16_acc t2; - - nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, zr); - - // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); - - nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; } + VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + KQ2[j*(D_padded/2) + i]; } } - } - // store result to shared memory (reuse sq) - if (warp_id == 0) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - } - } + __syncthreads(); } - // final rescale with 1/S and store to global memory - if (warp_id == 0) { - for (int j = 0; j < Q && iq1 + j < ne01; ++j) { - const half S = ss[j*T + 0]; - - for (int i0 = 0; 
i0 < D; i0 += NW) { - const int i = i0 + lane_id; - if (i >= D) { - break; - } - - dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if ((j0 + nwarps > ncols && j >= ncols) || ncols*blockIdx.x + j >= ne01) { + return; + } + const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; } + dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } } -#else - NO_DEVICE_CODE; -#endif } + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -461,133 +463,586 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); -#define NQPB 16 -#define NCPW 128 - - const int nqpb = NQPB; // queries per block - const int ncpw = NCPW; // cache values per warp (does not work for other values) - - GGML_ASSERT(NQPB <= 32); - - const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? - // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; - - dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); - dim3 block_dim(32, nwarps, 1); - - const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) { + const int nwarps = Q->ne[0] / WARP_SIZE; + const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const int shmem = 0; + switch (Q->ne[0]) { + case 64: + flash_attn_vec_ext_f16<64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 80: + // flash_attn_vec_ext_f16<80> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 96: + flash_attn_vec_ext_f16<96> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 112: + // flash_attn_vec_ext_f16<112> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 128: + flash_attn_vec_ext_f16<128> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_vec_ext_f16<256> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + GGML_ASSERT(false); + break; + } + CUDA_CHECK(cudaGetLastError()); + return; + } - // increase shared memory limit to 96KB - //const size_t shmem_max = 96*1024; - //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + int cols_per_block; + if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; + } + const int frag_m = cols_per_block == 8 ? 32 : 16; + const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; + const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const size_t shmem = 0; switch (Q->ne[0]) { - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 96: - flash_attn_ext_f16<96, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 112: - flash_attn_ext_f16<112, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 256: - flash_attn_ext_f16<256, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + case 64: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<64, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<64, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<64, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? 
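// Each (head size, cols_per_block) pair is a distinct template instantiation, hence
// this two-level switch of near-identical launch bodies; a later patch in this
// series collapses them into the FATTN_SWITCH_CASE macro.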
((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<64, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 80: switch (cols_per_block) { + // case 8: + // fused_attn_vec_ext_f16<80, 8> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 16: + flash_attn_ext_f16<80, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<80, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<80, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 96: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<96, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? 
((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<96, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<96, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<96, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 112: switch (cols_per_block) { + // case 8: + // fused_attn_vec_ext_f16<112, 8> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 16: + flash_attn_ext_f16<112, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<112, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<112, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 128: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<128, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<128, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<128, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<128, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 256: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<256, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<256, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<256, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 64: + // flash_attn_ext_f16<256, 64> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; default: + GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); } From d59ac670bf92f18ba9db44f37fef93b002528ac8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 30 Mar 2024 09:19:19 +0100 Subject: [PATCH 064/121] 16 cols for Phi-2 --- ggml-cuda/fattn.cu | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index ccb3c924609a4..d34924c3173e6 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -579,15 +579,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - int cols_per_block; - if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; + int cols_per_block = 16; + if (Q->ne[0] % 32 == 0) { + if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; + } } const int frag_m = cols_per_block == 8 ? 32 : 16; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? 
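// One warp covers frag_m tensor-core rows, so this makes nwarps*frag_m span the
// full head size D per iteration, or only D/2 when D > 128 with ncols > 8, which
// keeps the per-block register and shared-memory footprint bounded.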
Q->ne[0] : Q->ne[0]/2) / frag_m; From 81da919864831948f292aeb0a5bd11eb5868bdb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 30 Mar 2024 10:34:09 +0100 Subject: [PATCH 065/121] no vec for hs, no hs==256 ncols==32 for Volta --- ggml-cuda/common.cuh | 1 + ggml-cuda/fattn.cu | 72 ++++++++++++++++++++++---------------------- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 33c8ed1da8d83..c245dd6ac009a 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -141,6 +141,7 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index d34924c3173e6..43b9a9f4a3d11 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) { + if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) { const int nwarps = Q->ne[0] / WARP_SIZE; const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const int shmem = 0; switch (Q->ne[0]) { - case 64: - flash_attn_vec_ext_f16<64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // case 64: + // flash_attn_vec_ext_f16<64> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; // case 80: // flash_attn_vec_ext_f16<80> // <<>> ( @@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] // ); // break; - case 96: - flash_attn_vec_ext_f16<96> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // case 96: + // flash_attn_vec_ext_f16<96> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? 
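// With the Q->ne[0] >= 128 gate added above, the batch-size-1 vector kernel now
// only serves head sizes 128 and 256; the smaller head sizes stay commented out
// here rather than deleted.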
((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; // case 112: // flash_attn_vec_ext_f16<112> // <<>> ( @@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst if (Q->ne[0] % 32 == 0) { if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { cols_per_block = 64; - } else if (Q->ne[1] >= 64) { + } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { cols_per_block = 16; From 269374ed818dde2b267307d62be7ba59385aebfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 31 Mar 2024 16:01:27 +0200 Subject: [PATCH 066/121] adjust kernel selection logic --- ggml-cuda/fattn.cu | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 43b9a9f4a3d11..f2c46008633d9 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -579,17 +579,15 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - int cols_per_block = 16; - if (Q->ne[0] % 32 == 0) { - if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } + int cols_per_block; + if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; From cca6d027a323b071d951f702ab3ede0d1937bb6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 31 Mar 2024 18:39:02 +0200 Subject: [PATCH 067/121] 4 warps, 256 stride for all D --- ggml-cuda/fattn.cu | 633 +++++++++++---------------------------------- 1 file changed, 147 insertions(+), 486 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index f2c46008633d9..aa85244fc52ce 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "fattn.cuh" #include @@ -176,8 +177,10 @@ static __global__ void flash_attn_vec_ext_f16( dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; } -template // D == head size -__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1) +#define FATTN_KQ_STRIDE 256 + +template // D == head size, VKQ_stride == num VKQ rows calculated in parallel +__launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -206,6 +209,7 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { //In this kernel Q, K, V are matrices while i, j, k are matrix indices. 
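// Illustration (not part of the patch): in scalar form, the streaming-softmax
// recurrence this kernel implements per KQ tile is, for a new score x with value
// row v and hypothetical state (m = running max, s = running sum, acc):
//
//     const float m_new = fmaxf(m, x);
//     const float scale = expf(m - m_new);   // rescales the old state
//     const float p     = expf(x - m_new);
//     s   = s*scale + p;
//     acc = acc*scale + p*v;
//     m   = m_new;                           // final output row is acc/s
//
// KQ_max, KQ_max_scale and KQ_rowsum below are the vectorized half2 counterparts
// of m, scale and s, and VKQ accumulates acc.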
+ static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; @@ -215,14 +219,13 @@ static __global__ void flash_attn_ext_f16( typedef nvcuda::wmma::fragment frag_b; typedef nvcuda::wmma::fragment frag_c; - constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m; - constexpr int nthreads = nwarps*WARP_SIZE; - static_assert(nthreads % D == 0, "nthreads not divisible by D."); - constexpr int tc_vals_per_iter = nwarps*frag_m; - static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter."); - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < nthreads); - constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts. + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); @@ -235,31 +238,43 @@ static __global__ void flash_attn_ext_f16( frag_b Q_b[D/16][ncols/frag_n]; - __shared__ half KQ[ncols*D_padded]; // Buffer for temporarily holding tiles of KQ. + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max[(ncols + nwarps - 1) / nwarps] = {{-INFINITY, -INFINITY}}; - half2 KQ_max_scale[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}}; + half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}}; __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; #pragma unroll - for (int i0 = 0; i0 < ncols*D_padded/2; i0 += nthreads) { - const int i = i0 + tid; - if (i0 + nthreads > ncols*D_padded/2 && i >= ncols*D_padded/2) { - break; + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); } - - VKQ2[i] = make_half2(0.0f, 0.0f); } // Convert Q to half and apply scale, temporarily store in KQ: #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nthreads/D) { - const int j = j0 + tid/D; - const int i = tid % D; - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? 
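// Query rows past ne01 are zero-filled so the last, partially occupied tile still
// yields well-defined wmma fragments; those rows are skipped again at write-back.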
Q_f[j*stride_Q + i] * scale : 0.0f; + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } } __syncthreads(); @@ -276,31 +291,27 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { - const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11; - + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { frag_c KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } - if (has_valid_data) { #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { - frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); - } + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); } } @@ -311,18 +322,12 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if (j0 + nwarps > ncols && j >= ncols) { - break; - } half2 KQ_max_new = KQ_max[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - if (k0 + WARP_SIZE > D/2 && k >= D/2) { - break; - } - KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(D_padded/2) + k]); + KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); @@ -330,20 +335,14 @@ static __global__ void flash_attn_ext_f16( half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - if (k0 + WARP_SIZE > D/2 && k >= D/2) { - break; - } - if (256 % D != 0 && k_VKQ_0 + 2*k >= ne11) { - break; - } - half2 val = KQ2[j*(D_padded/2) + k]; + half2 val = KQ2[j*(kqs_padded/2) + k]; val += mask ? 
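// The mask (e.g. -INFINITY for causal positions) is added to the raw scores before
// exponentiation, so masked entries contribute exp(-inf) == 0 to both the row sum
// and the VKQ accumulator.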
mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); val = h2exp(val - KQ_max[j0/nwarps]); KQ_rowsum_add += val; - KQ2[j*(D_padded/2) + k] = val; + KQ2[j*(kqs_padded/2) + k] = val; } KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); @@ -353,47 +352,46 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - frag_b KQ_b[D/16][ncols/frag_n]; + frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { #pragma unroll - for (int k0 = 0; k0 < D; k0 += 16) { - nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*D_padded + k0, D_padded); + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) { + nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded); } } - frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n]; + frag_c VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { - #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f); + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); } - #pragma unroll - for (int k0 = 0; k0 < D; k0 += 16) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV); - #pragma unroll + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]); + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } __syncthreads(); + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { nvcuda::wmma::store_matrix_sync( - KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, - VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n], + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); } } @@ -403,16 +401,19 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if (j0 + nwarps > ncols && j >= ncols) { - break; - } #pragma unroll for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; if (i0 + WARP_SIZE > D/2 && i >= D/2) { break; } - VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + KQ2[j*(D_padded/2) + i]; + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add; } } @@ -422,7 +423,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if ((j0 + nwarps > ncols && j >= ncols) || ncols*blockIdx.x + j >= ne01) { + if (ncols*blockIdx.x + j 
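// Final write-back: each output row is the VKQ accumulator divided by its softmax
// row sum (low + high half of the half2 value), converted back to float.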
>= ne01) { return; } const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); @@ -437,6 +438,50 @@ static __global__ void flash_attn_ext_f16( } } +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ + case ncols: { \ + constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ + flash_attn_ext_f16 \ + <<>> ( \ + (const char *) Q->data, \ + (const char *) K->data, \ + (const char *) V->data, \ + mask ? ((const char *) mask->data) : nullptr, \ + (float *) KQV->data, \ + scale, \ + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ + K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ + Q->nb[1], Q->nb[2], Q->nb[3], \ + K->nb[1], K->nb[2], K->nb[3], \ + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ + ); \ + } \ + break; \ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -580,7 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } int cols_per_block; - if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { + if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { cols_per_block = 64; } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; @@ -590,451 +635,67 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; - const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; + const int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const size_t shmem = 0; switch (Q->ne[0]) { case 64: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<64, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? 
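// Everything removed from here to the end of the switch is the hand-expanded
// launch boilerplate; each case body is now generated by the
// FATTN_SWITCH_CASE(D, ncols, nwarps) macro defined above.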
mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<64, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<64, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<64, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(64, 8, nwarps); + FATTN_SWITCH_CASE(64, 16, nwarps); + FATTN_SWITCH_CASE(64, 32, nwarps); + FATTN_SWITCH_CASE(64, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 80: switch (cols_per_block) { - // case 8: - // fused_attn_vec_ext_f16<80, 8> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - case 16: - flash_attn_ext_f16<80, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<80, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<80, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // FATTN_SWITCH_CASE(80, 8, nwarps); + FATTN_SWITCH_CASE(80, 16, nwarps); + FATTN_SWITCH_CASE(80, 32, nwarps); + // FATTN_SWITCH_CASE(80, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 96: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<96, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<96, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<96, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<96, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(96, 8, nwarps); + FATTN_SWITCH_CASE(96, 16, nwarps); + FATTN_SWITCH_CASE(96, 32, nwarps); + FATTN_SWITCH_CASE(96, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 112: switch (cols_per_block) { - // case 8: - // fused_attn_vec_ext_f16<112, 8> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? 
((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - case 16: - flash_attn_ext_f16<112, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<112, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<112, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // FATTN_SWITCH_CASE(112, 8, nwarps); + FATTN_SWITCH_CASE(112, 16, nwarps); + FATTN_SWITCH_CASE(112, 32, nwarps); + // FATTN_SWITCH_CASE(112, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 128: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<128, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<128, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<128, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? 
((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<128, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(128, 8, nwarps); + FATTN_SWITCH_CASE(128, 16, nwarps); + FATTN_SWITCH_CASE(128, 32, nwarps); + // FATTN_SWITCH_CASE(128, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 256: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<256, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<256, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<256, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - // case 64: - // flash_attn_ext_f16<256, 64> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; + FATTN_SWITCH_CASE(256, 8, nwarps); + FATTN_SWITCH_CASE(256, 16, nwarps); + FATTN_SWITCH_CASE(256, 32, nwarps); + // FATTN_SWITCH_CASE(256, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); From 68d793bee816e44876b22232891ee7bab51ee5e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 1 Apr 2024 15:54:50 +0200 Subject: [PATCH 068/121] no ncols == 64 --- ggml-cuda/fattn.cu | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index aa85244fc52ce..19108044e762f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -625,9 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } int cols_per_block; - if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { + if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { cols_per_block = 16; @@ -645,7 +643,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(64, 8, nwarps); FATTN_SWITCH_CASE(64, 16, nwarps); FATTN_SWITCH_CASE(64, 32, nwarps); - FATTN_SWITCH_CASE(64, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -655,7 +652,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // FATTN_SWITCH_CASE(80, 8, nwarps); FATTN_SWITCH_CASE(80, 16, nwarps); FATTN_SWITCH_CASE(80, 32, nwarps); - // FATTN_SWITCH_CASE(80, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -665,7 +661,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(96, 8, nwarps); FATTN_SWITCH_CASE(96, 16, nwarps); FATTN_SWITCH_CASE(96, 32, nwarps); - FATTN_SWITCH_CASE(96, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -675,7 +670,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // FATTN_SWITCH_CASE(112, 8, nwarps); FATTN_SWITCH_CASE(112, 16, nwarps); FATTN_SWITCH_CASE(112, 32, nwarps); - // FATTN_SWITCH_CASE(112, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -685,7 +679,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(128, 8, nwarps); FATTN_SWITCH_CASE(128, 16, nwarps); FATTN_SWITCH_CASE(128, 32, nwarps); - // FATTN_SWITCH_CASE(128, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -695,7 +688,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(256, 8, nwarps); FATTN_SWITCH_CASE(256, 16, nwarps); FATTN_SWITCH_CASE(256, 32, nwarps); - // FATTN_SWITCH_CASE(256, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); From 3f777acf06a0c21780e2e6b1d5b0b8e9a2fbd922 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 1 Apr 2024 16:41:56 +0200 Subject: [PATCH 069/121] Multiple parallel blocks for batch size 1 --- ggml-cuda/fattn.cu | 417 ++++++++++++++++++++++++++++++--------------- 1 file changed, 283 insertions(+), 134 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 19108044e762f..4b51f1b747cf3 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -29,14 +29,17 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } -template // D == head size -__launch_bounds__(D, 1) +#define FATTN_KQ_STRIDE 256 + +template // D == head size +__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, + half2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -60,20 +63,25 @@ static __global__ void flash_attn_vec_ext_f16( const int ne3) { //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*blockIdx.x); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne31*blockIdx.x; + const half * maskh = (const half *) mask; + + if (parallel_blocks == 1) { + Q_f2 += blockIdx.x*nb01/sizeof(float2); + maskh += blockIdx.x*ne11; + } const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); - constexpr int nwarps = D/WARP_SIZE; + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < nwarps*WARP_SIZE); - __shared__ half KQ[D]; - KQ[tid] = 0.0f; + __shared__ half KQ[nwarps*WARP_SIZE]; + KQ[tid] = -INFINITY; half2 * KQ2 = (half2 *) KQ; half kqmax = -INFINITY; @@ -85,7 +93,6 @@ static __global__ void flash_attn_vec_ext_f16( kqmax_shared[threadIdx.x] = -INFINITY; kqsum_shared[threadIdx.x] = 0.0f; } - __syncthreads(); // Convert Q to half2 and store in registers: @@ -102,14 +109,15 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + const int k_start = parallel_blocks == 1 ? 
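// With parallel_blocks > 1 the KV sequence is split across blockIdx.x: block b
// starts at b*D and strides by parallel_blocks*D, writing an unnormalized partial
// result to dst plus per-block (max, sum) metadata to dst_meta for the combine
// kernel.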
0 : blockIdx.x*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new = kqmax; #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) { + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -153,19 +161,25 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); + if (tid < D) { #pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } + for (int k0 = 0; k0 < D; k0 += 2) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } - half2 V_k; - reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; - reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; - VKQ += V_k*KQ2[k0/2]; + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; + } } } + if (tid >= D) { + kqsum = 0.0f; + } + kqsum = warp_reduce_sum(kqsum); if (threadIdx.x == 0) { kqsum_shared[threadIdx.y] = kqsum; @@ -174,12 +188,22 @@ static __global__ void flash_attn_vec_ext_f16( kqsum = kqsum_shared[threadIdx.x]; kqsum = warp_reduce_sum(kqsum); - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; -} + if (tid >= D) { + return; + } -#define FATTN_KQ_STRIDE 256 + if (parallel_blocks == 1) { + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; + } else { + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); -template // D == head size, VKQ_stride == num VKQ rows calculated in parallel + if (tid == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); + } + } +} + +template // D == head size, VKQ_stride == num VKQ rows calculated in parallel __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -187,6 +211,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, + half2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -228,10 +253,15 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2; + const half2 * mask2 = (half2 *) mask; + + if (parallel_blocks == 1) { + Q_f += blockIdx.x * ncols*nb01/sizeof(float); + mask2 += blockIdx.x * ncols*ne11/2; + } const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -273,7 +303,11 @@ static __global__ void flash_attn_ext_f16( if (i0 + WARP_SIZE > D && i >= D) { break; } - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + if (parallel_blocks == 1) { + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? 
Q_f[j*stride_Q + i] * scale : 0.0f; + } else { + KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } } } @@ -291,7 +325,8 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) { + const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -420,22 +455,75 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } + if (parallel_blocks == 1) { #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - if (ncols*blockIdx.x + j >= ne01) { - return; - } - const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (ncols*blockIdx.x + j >= ne01) { + return; + } + const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } - dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } + return; + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { + const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; + if (i0 + nwarps*WARP_SIZE > D && i >= D) { + return; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + } + + if (threadIdx.y == 0 && threadIdx.x == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( + __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + } +} + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const half2 * __restrict__ VKQ_meta, + float * __restrict__ dst) { + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half2 meta[parallel_blocks]; + if (tid < parallel_blocks) { + meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid]; } + + __syncthreads(); + + half kqmax = __low2half(meta[0]); +#pragma unroll + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = __hmax(kqmax, __low2half(meta[l])); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_denominator += KQ_max_scale * __high2float(meta[l]); + } + + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; } constexpr int get_max_power_of_2(int x) { @@ -462,26 +550,26 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ - case ncols: { \ - constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 
32 : 16; \ - flash_attn_ext_f16 \ - <<>> ( \ - (const char *) Q->data, \ - (const char *) K->data, \ - (const char *) V->data, \ - mask ? ((const char *) mask->data) : nullptr, \ - (float *) KQV->data, \ - scale, \ - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ - K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ - Q->nb[1], Q->nb[2], Q->nb[3], \ - K->nb[1], K->nb[2], K->nb[3], \ - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ - ); \ - } \ - break; \ +#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ + case ncols: { \ + constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ + flash_attn_ext_f16 \ + <<>> ( \ + (const char *) Q->data, \ + (const char *) K->data, \ + (const char *) V->data, \ + mask ? ((const char *) mask->data) : nullptr, \ + (float *) KQV->data, nullptr, \ + scale, \ + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ + K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ + Q->nb[1], Q->nb[2], Q->nb[3], \ + K->nb[1], K->nb[2], K->nb[3], \ + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ + ); \ + } \ + break; \ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -508,88 +596,135 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) { - const int nwarps = Q->ne[0] / WARP_SIZE; - const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); + if (Q->ne[1] == 1) { + constexpr int parallel_blocks = 4; + + ggml_cuda_pool_alloc dst_tmp(ctx.pool()); + ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); + + const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const int shmem = 0; + + // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: + constexpr int nwarps_tc = 4; + constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); + + const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); + const dim3 block_dim_combine(Q->ne[0], 1, 1); + const int shmem_combine = 0; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + switch (Q->ne[0]) { - // case 64: - // flash_attn_vec_ext_f16<64> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 80: - // flash_attn_vec_ext_f16<80> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 96: - // flash_attn_vec_ext_f16<96> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 112: - // flash_attn_vec_ext_f16<112> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; + case 64: + flash_attn_vec_ext_f16<64, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<64, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 80: + flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<80, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 96: + flash_attn_vec_ext_f16<96, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<96, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 112: + flash_attn_vec_ext_f16<112, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<112, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; case 128: - flash_attn_vec_ext_f16<128> + flash_attn_vec_ext_f16<128, parallel_blocks> <<>> ( (const char *) Q->data, // Query (const char *) K->data, // Key (const char *) V->data, // Value mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -598,15 +733,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<128, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); break; case 256: - flash_attn_vec_ext_f16<256> + flash_attn_vec_ext_f16<256, parallel_blocks> <<>> ( (const char *) Q->data, // Query (const char *) K->data, // Key (const char *) V->data, // Value mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -615,6 +757,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<256, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); break; default: GGML_ASSERT(false); @@ -633,7 +782,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 
32 : 16; - const int nwarps = 4; + constexpr int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const size_t shmem = 0; From e1ecd3b1290adf0086b7b5a45fe97d22afe6a963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 10:27:34 +0200 Subject: [PATCH 070/121] fix compile warnings --- ggml-cuda/fattn.cu | 48 ++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 4b51f1b747cf3..d1d018dc7a7e8 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -16,18 +16,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); - } - return x; -#else - GGML_UNUSED(x); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -} +// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +// #pragma unroll +// for (int mask = 16; mask > 0; mask >>= 1) { +// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); +// } +// return x; +// #else +// GGML_UNUSED(x); +// NO_DEVICE_CODE; +// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +// } #define FATTN_KQ_STRIDE 256 @@ -472,21 +472,20 @@ static __global__ void flash_attn_ext_f16( dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } } - return; - } - + } else { #pragma unroll - for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { - const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; - if (i0 + nwarps*WARP_SIZE > D && i >= D) { - return; + for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { + const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; + if (i0 + nwarps*WARP_SIZE > D && i >= D) { + return; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; - } - if (threadIdx.y == 0 && threadIdx.x == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( - __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + if (threadIdx.y == 0 && threadIdx.x == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( + __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + } } } @@ -781,7 +780,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } else { cols_per_block = 8; } - const int frag_m = cols_per_block == 8 ? 
32 : 16; constexpr int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); From bb0d51accd7e99c151088c11f1b6774a753dc05c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 11:13:46 +0200 Subject: [PATCH 071/121] fix excessive KQ_b loads --- ggml-cuda/fattn.cu | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index d1d018dc7a7e8..dcd2129a2cfc9 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -387,12 +387,16 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n]; + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) { - nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded); + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*kqs_padded + k, + kqs_padded); } } @@ -412,7 +416,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } From c63dfdf765c48a0f78e162d0c02c7d69cbbc3083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 11:58:59 +0200 Subject: [PATCH 072/121] fix cmake build --- ggml-cuda/common.cuh | 26 ++++++++++++-------------- ggml-cuda/fattn.cu | 38 ++++++++++++-------------------------- 2 files changed, 24 insertions(+), 40 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index c245dd6ac009a..510ca6281471e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -271,7 +271,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -#ifdef GGML_CUDA_F16 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll @@ -284,7 +283,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { NO_DEVICE_CODE; #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -294,18 +292,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -//#else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//} +static __device__ __forceinline__ half2 
warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} #if defined(GGML_USE_HIPBLAS) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index dcd2129a2cfc9..1d29346c7453b 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -3,32 +3,6 @@ #include -static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); - } - return a; -#else - GGML_UNUSED(a); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -} - -// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -// #pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -// #else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -// } - #define FATTN_KQ_STRIDE 256 template // D == head size @@ -61,6 +35,7 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); @@ -201,6 +176,9 @@ static __global__ void flash_attn_vec_ext_f16( dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); } } +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } template // D == head size, VKQ_stride == num VKQ rows calculated in parallel @@ -233,6 +211,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA //In this kernel Q, K, V are matrices while i, j, k are matrix indices. 
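// Sketch of the guard idiom (the same pattern in all three kernels of this
// commit; a summary, not new code): device code is generated only where the
// required fp16 intrinsics exist -- CC_PASCAL for the vector kernel,
// CC_VOLTA for this wmma-based kernel -- and older architectures compile an
// empty stub instead:
//
//   #if <arch supports the needed intrinsics>
//       ... kernel body ...
//   #else
//       NO_DEVICE_CODE;   // keeps the build green; traps if ever launched
//   #endif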
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); @@ -491,6 +470,9 @@ static __global__ void flash_attn_ext_f16( __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); } } +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA } template // D == head size @@ -499,6 +481,7 @@ static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const half2 * __restrict__ VKQ_meta, float * __restrict__ dst) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL const int tid = threadIdx.x; __builtin_assume(tid < D); @@ -527,6 +510,9 @@ static __global__ void flash_attn_combine_results( } dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } constexpr int get_max_power_of_2(int x) { From ee19a4ab7eba53347bde5fa3b3d38e38f7427f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 17:26:22 +0200 Subject: [PATCH 073/121] fix KV cache padding, NaN from INFINITY (#6438) --- ggml-cuda/fattn.cu | 15 +++++++++------ llama.cpp | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 1d29346c7453b..91ef5551e025a 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -4,6 +4,7 @@ #include #define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. template // D == head size __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) @@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16( KQ[tid] = -INFINITY; half2 * KQ2 = (half2 *) KQ; - half kqmax = -INFINITY; + half kqmax = -HALF_MAX_HALF; half kqsum = 0.0f; __shared__ half kqmax_shared[WARP_SIZE]; __shared__ half kqsum_shared[WARP_SIZE]; if (threadIdx.y == 0) { - kqmax_shared[threadIdx.x] = -INFINITY; + kqmax_shared[threadIdx.x] = -HALF_MAX_HALF; kqsum_shared[threadIdx.x] = 0.0f; } __syncthreads(); @@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16( if (tid < D) { #pragma unroll for (int k0 = 0; k0 < D; k0 += 2) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { break; } @@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16( __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}}; - half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}}; + half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}}; + half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}}; __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. 
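// Why -HALF_MAX_HALF instead of -INFINITY (illustration only, assuming the
// online-softmax update these kernels use): the accumulators are rescaled by
// hexp(max_old - max_new), and before any KQ value has been seen both maxima
// still hold their initial value, so
//
//   half s_bad  = hexp(__float2half(-INFINITY) - __float2half(-INFINITY));
//   // inf - inf -> NaN, which then poisons every later VKQ value
//
//   half s_good = hexp(-HALF_MAX_HALF - (-HALF_MAX_HALF));
//   // 0 -> hexp(0) == 1, the accumulators are left untouched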
half2 * VKQ2 = (half2 *) VKQ; @@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + ggml_cuda_set_device(ctx.device); const cudaStream_t main_stream = ctx.stream(); diff --git a/llama.cpp b/llama.cpp index 9ea9886fe3b89..b50588e4467b0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9973,7 +9973,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); + kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } } @@ -13909,7 +13909,7 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; From 599ce84a71512b72bf4fd6a248e7725f646eb1a8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 17 Apr 2024 12:00:35 +0300 Subject: [PATCH 074/121] llama : flash_attn cparam + fix defrag --- common/common.cpp | 6 + common/common.h | 1 + llama.cpp | 345 +++++++++++++++++++++++++--------------------- llama.h | 1 + 4 files changed, 194 insertions(+), 159 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index cf69535e2d1f5..fbff8cf13effc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -900,6 +900,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; + return true; + } if (arg == "--color") { params.use_color = true; return true; @@ -1836,6 +1840,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; + cparams.flash_attn = params.flash_attn; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -2673,6 +2678,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? 
"true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index cca44268e6df5..78a1a34021b60 100644 --- a/common/common.h +++ b/common/common.h @@ -148,6 +148,7 @@ struct gpt_params { bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly + bool flash_attn = false; // flash attention bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/llama.cpp b/llama.cpp index 1081899fbaa3f..0b5ee9ef569c6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -107,8 +107,6 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 60 -#define LLAMA_FLASH_ATTN - // // logging // @@ -1899,6 +1897,7 @@ struct llama_cparams { bool embeddings; bool causal_attn; bool offload_kqv; + bool flash_attn; enum llama_pooling_type pooling_type; @@ -5938,15 +5937,17 @@ static struct ggml_tensor * llm_build_inp_embd( static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, const llm_build_cb & cb, int64_t il) { + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -5959,28 +5960,26 @@ static void llm_build_kv_store( // important: storing RoPE-ed version of K in the KV cache! ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); -#if defined(LLAMA_FLASH_ATTN) - // NOTE: the V cache is not transposed when using FLASH attention !! - struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, - (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); - cb(v_cache_view, "v_cache_view", il); - - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using FLASH attention !! 
+ struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); - GGML_UNUSED(n_ctx); -#else - // compute the transposed [n_tokens, n_embd] V matrix - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + } else { + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); + cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv.v_l[il]), - (kv_head)*ggml_element_size(kv.v_l[il])); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*ggml_element_size(kv.v_l[il]), + (kv_head)*ggml_element_size(kv.v_l[il])); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); -#endif + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); + } } static struct ggml_tensor * llm_build_norm( @@ -6111,6 +6110,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6118,12 +6118,12 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t n_kv, float kq_scale, const llm_build_cb & cb, int il) { + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -6143,97 +6143,100 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * cur; -#if defined(LLAMA_FLASH_ATTN) - GGML_UNUSED(model); - GGML_UNUSED(n_ctx); + if (cparams.flash_attn) { + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); - GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - // split cached v into n_head heads (not transposed) - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv.v_l[il]->type, n_embd_head_k), - 0); - cb(v, "v", il); + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); - ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - 
//printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); - cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); -#else - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + if (model.arch == LLM_ARCH_PHI2) { + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + } + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below - //try from phi2 - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx, kq, 30); - } + kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx, kq, 30); + } #if defined(GGML_USE_KOMPUTE) #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to ggml_alibi(). 
Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + if (hparams.f_max_alibi_bias > 0.0f) { + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else #endif - { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - } + { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + } - GGML_ASSERT(kv.size == n_ctx); + GGML_ASSERT(kv.size == n_ctx); - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); -#endif + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } ggml_build_forward_expand(graph, cur); @@ -6253,6 +6256,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6262,7 +6266,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -6276,12 +6279,12 @@ static struct ggml_tensor * llm_build_kv( ggml_build_forward_expand(graph, k_cur); ggml_build_forward_expand(graph, v_cur); - llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6323,6 +6326,8 @@ struct llm_build_context { const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_orig_ctx; + 
const bool flash_attn; + const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -6369,6 +6374,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), + flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -6483,15 +6489,31 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); - ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, i)); + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); - ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, id)); + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, id)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); @@ -6640,9 +6662,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6846,9 +6868,9 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6953,9 +6975,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7073,9 +7095,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, 
model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7198,9 +7220,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -7343,9 +7365,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7447,9 +7469,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7651,9 +7673,9 @@ struct llm_build_context { ); cb(Vcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7747,9 +7769,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8040,9 +8062,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8171,14 +8193,15 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, 
kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -8320,9 +8343,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8438,9 +8461,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8557,9 +8580,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8671,9 +8694,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8817,9 +8840,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -8919,9 +8942,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, 
KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -9022,9 +9045,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9129,9 +9152,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9245,9 +9268,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9362,9 +9385,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9492,9 +9515,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9613,9 +9636,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9732,9 +9755,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10022,9 +10045,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); 
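// All remaining attention call sites in this patch get the same mechanical
// rewrite (abbreviated sketch; full argument lists as in the surrounding
// diff): cparams is inserted after hparams and the explicit n_ctx argument
// is dropped, since llm_build_kv/llm_build_kqv now read n_ctx from cparams.
//
//   before: llm_build_kv(ctx0, model, hparams,          kv_self, gf, ..., n_ctx, n_tokens, ...);
//   after:  llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, ...,        n_tokens, ...);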
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -11016,7 +11039,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // each move requires 6*n_layer tensors (see build_defrag) // - source view, destination view, copy operation // - x2 for keys and values - const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -14626,6 +14651,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embeddings =*/ false, /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -14795,6 +14821,7 @@ struct llama_context * llama_new_context_with_model( cparams.defrag_thold = params.defrag_thold; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; diff --git a/llama.h b/llama.h index b5da686f7b7e5..77c288eb25543 100644 --- a/llama.h +++ b/llama.h @@ -270,6 +270,7 @@ extern "C" { bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // whether to use flash attention // Abort callback // if it returns true, execution of llama_decode() will be aborted From 405385726ef7432b65b9e63dd7a63c18765eb376 Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Wed, 17 Apr 2024 14:05:02 +0200 Subject: [PATCH 075/121] server: support flash_attn param --- examples/server/server.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 634e653ada284..f1754b60b7fe8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2722,6 +2722,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, params.embedding = true; } else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; + } else if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; } else if (arg == "-np" || arg == "--parallel") { if (++i >= argc) { invalid_param = true; From 5668c79ea092b7bff95e1fce96e3de717c31349d Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Wed, 17 Apr 2024 23:26:29 +0200 Subject: [PATCH 076/121] server: bench: enable flash_attn param --- examples/server/bench/bench.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 6ca637bddc3a1..86c5de101445c 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -268,6 +268,7 @@ def start_server_background(args): server_args.extend(['--defrag-thold', "0.1"]) 
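
Stepping back to the defragmentation budget in the llama.cpp hunk above: each KV cell move costs 6*n_layer graph nodes, and the temporary fix additionally reserves 2*n_layer nodes of headroom so the defrag graph can no longer exceed LLAMA_MAX_NODES. A worked instance of the arithmetic, assuming the default LLAMA_MAX_NODES of 8192 (an assumption; check the define) and an 80-layer model:

    // per move: 3 tensors (source view, destination view, copy op), x2 for K
    // and V, per layer -> 6*n_layer nodes, so the budget caps moves per pass
    const uint32_t n_layer = 80;
    const uint32_t moves_before =  8192              / (6*n_layer); // == 17
    const uint32_t moves_after  = (8192 - 2*n_layer) / (6*n_layer); // == 16,
    // leaving 2*n_layer nodes spare for the graph's non-move tensors
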
server_args.append('--cont-batching') server_args.append('--metrics') + server_args.append('--flash-attn') server_args.extend(['--log-format', "text"]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") From 34f93bbb39965dd40fe8ad717902d6d109b64afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 9 Apr 2024 11:39:16 +0200 Subject: [PATCH 077/121] CUDA: refactor host code, dyn. par. blocks --- ggml-cuda.cu | 1 + ggml-cuda/common.cuh | 6 + ggml-cuda/fattn.cu | 542 +++++++++++++++++++------------------------ 3 files changed, 248 insertions(+), 301 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 11adbabd655e1..2cf6c8d98bd89 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -141,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b0149b7be22b3..989780dbce88c 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -390,6 +390,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { } #endif // defined(GGML_USE_HIPBLAS) +#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL +#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA + // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -403,6 +408,7 @@ struct ggml_cuda_device_info { struct cuda_device_info { int cc; // compute capability + int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 91ef5551e025a..5f1345a7fe94f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -36,18 +36,17 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
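
With this refactor the host code (further down in this patch) launches parallel_blocks*ne01 blocks along grid x, so the first thing the vector kernel does is unpack a block's two roles from blockIdx.x. A worked example of the new indexing, with illustrative numbers that are not taken from the patch:

    // grid: x = parallel_blocks*ne01, y = ne02 (Q heads), z = ne03
    // with parallel_blocks == 4 and blockIdx.x == 10:
    //   ic = 10 / 4 == 2   -> this block computes Q/QKV column 2
    //   ip = 10 % 4 == 2   -> it owns KV slice 2 of that column
    // slice ip starts at k_VKQ_0 = ip*D and strides by parallel_blocks*D, so
    // the four slices of one column cover disjoint, interleaved KV tiles.
    //
    // grouped-query attention: with ne02 == 32 Q heads and ne12 == 8 KV heads,
    // gqa_ratio == 4, and Q heads 0..3 (blockIdx.y/4 == 0) all read K/V head 0
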
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask; - - if (parallel_blocks == 1) { - Q_f2 += blockIdx.x*nb01/sizeof(float2); - maskh += blockIdx.x*ne11; - } + const half * maskh = (const half *) mask + ne11*ic; const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); @@ -85,7 +84,7 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D; + const int k_start = parallel_blocks == 1 ? 0 : ip*D; for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new = kqmax; @@ -168,18 +167,19 @@ static __global__ void flash_attn_vec_ext_f16( return; } + half dst_val = (__low2half(VKQ) + __high2half(VKQ)); if (parallel_blocks == 1) { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; - } else { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); + dst_val /= kqsum; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; - if (tid == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); - } + if (parallel_blocks == 1 || tid != 0) { + return; } + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } template // D == head size, VKQ_stride == num VKQ rows calculated in parallel @@ -212,8 +212,12 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#if FP16_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; @@ -233,15 +237,10 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
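
The unnormalized accumulator plus the (kqmax, kqsum) pair written to dst_meta above is what makes the KV split recombinable: the combine kernel can merge the parallel_blocks partial results without ever seeing the full softmax. A scalar sketch of that merge per output element, using kqmax[l]/kqsum[l]/part[l] as stand-in names for the stored metadata and partial sums (not identifiers from the patch):

    // slice l produced part[l] = sum_k exp(s_k - kqmax[l]) * v_k (unnormalized)
    // and kqsum[l] = sum_k exp(s_k - kqmax[l]) over its own KV tile
    float m = kqmax[0];
    for (int l = 1; l < parallel_blocks; ++l) m = fmaxf(m, kqmax[l]);
    float numer = 0.0f, denom = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float scale = expf(kqmax[l] - m); // rebase slice onto global max
        numer += scale*part[l];
        denom += scale*kqsum[l];
    }
    const float out = numer/denom; // equals softmax over the whole KV range
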
- const float * Q_f = (const float *) (Q + nb02* blockIdx.y); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (half2 *) mask; - - if (parallel_blocks == 1) { - Q_f += blockIdx.x * ncols*nb01/sizeof(float); - mask2 += blockIdx.x * ncols*ne11/2; - } + const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -283,11 +282,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + WARP_SIZE > D && i >= D) { break; } - if (parallel_blocks == 1) { - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } else { - KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -305,8 +300,7 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -439,41 +433,39 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } - if (parallel_blocks == 1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - if (ncols*blockIdx.x + j >= ne01) { - return; - } - const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; - } + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; } - } else { + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { - const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; - if (i0 + nwarps*WARP_SIZE > D && i >= D) { - return; + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + half dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; } - if (threadIdx.y == 0 && threadIdx.x == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( - __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; } + + half2 dst_meta_val = KQ_max[j0/nwarps]; + reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#endif // 
FP16_MMA_AVAILABLE } template // D == head size @@ -482,7 +474,10 @@ static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const half2 * __restrict__ VKQ_meta, float * __restrict__ dst) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; const int tid = threadIdx.x; __builtin_assume(tid < D); @@ -513,7 +508,7 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } constexpr int get_max_power_of_2(int x) { @@ -540,26 +535,124 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ - case ncols: { \ - constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ - flash_attn_ext_f16 \ - <<>> ( \ - (const char *) Q->data, \ - (const char *) K->data, \ - (const char *) V->data, \ - mask ? ((const char *) mask->data) : nullptr, \ - (float *) KQV->data, nullptr, \ - scale, \ - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ - K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ - Q->nb[1], Q->nb[2], Q->nb[3], \ - K->nb[1], K->nb[2], K->nb[3], \ - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ - ); \ - } \ - break; \ +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16_impl( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream +) { + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + + if (4*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); +} void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -583,259 +676,106 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - const cudaStream_t main_stream = ctx.stream(); - - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); - - if (Q->ne[1] == 1) { + if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { constexpr int parallel_blocks = 4; - - ggml_cuda_pool_alloc dst_tmp(ctx.pool()); - ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); - - const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; - const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const int shmem = 0; - - // Performance of the vector 
kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: - constexpr int nwarps_tc = 4; - constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); - - const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); - const dim3 block_dim_combine(Q->ne[0], 1, 1); - const int shmem_combine = 0; - - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } - switch (Q->ne[0]) { case 64: - flash_attn_vec_ext_f16<64, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<64, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 80: - flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<80, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 96: - flash_attn_vec_ext_f16<96, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<96, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 112: - flash_attn_vec_ext_f16<112, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<112, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 128: - flash_attn_vec_ext_f16<128, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<128, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 256: - flash_attn_vec_ext_f16<256, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<256, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); return; } - int cols_per_block; - if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } - constexpr int nwarps = 4; - const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const size_t shmem = 0; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - switch (Q->ne[0]) { - case 64: switch (cols_per_block) { - FATTN_SWITCH_CASE(64, 8, nwarps); - FATTN_SWITCH_CASE(64, 16, nwarps); - FATTN_SWITCH_CASE(64, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 80: switch (cols_per_block) { - // FATTN_SWITCH_CASE(80, 8, nwarps); - FATTN_SWITCH_CASE(80, 16, nwarps); - FATTN_SWITCH_CASE(80, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + 
case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 96: switch (cols_per_block) { - FATTN_SWITCH_CASE(96, 8, nwarps); - FATTN_SWITCH_CASE(96, 16, nwarps); - FATTN_SWITCH_CASE(96, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 112: switch (cols_per_block) { - // FATTN_SWITCH_CASE(112, 8, nwarps); - FATTN_SWITCH_CASE(112, 16, nwarps); - FATTN_SWITCH_CASE(112, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 128: switch (cols_per_block) { - FATTN_SWITCH_CASE(128, 8, nwarps); - FATTN_SWITCH_CASE(128, 16, nwarps); - FATTN_SWITCH_CASE(128, 32, nwarps); default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; - } break; - case 256: switch (cols_per_block) { - FATTN_SWITCH_CASE(256, 8, nwarps); - FATTN_SWITCH_CASE(256, 16, nwarps); - FATTN_SWITCH_CASE(256, 32, nwarps); + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; - } break; + } + return; + } + + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); + return; } From 6a3b84236de279f0fe012cfca0c168472526b696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 13 Apr 2024 22:05:43 +0200 Subject: [PATCH 078/121] fix flash_attn_vec_f16 race condition --- ggml-cuda/fattn.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 5f1345a7fe94f..36479b2170979 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16( VKQ += V_k*KQ2[k0/2]; } } + + __syncthreads(); } if (tid >= D) { @@ -547,7 +549,7 @@ template void launch_fattn_vec_f16( dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); } - constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -561,7 +563,7 @@ template void launch_fattn_vec_f16( (const char *) K->data, (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -572,7 +574,7 @@ template void launch_fattn_vec_f16( ); CUDA_CHECK(cudaGetLastError()); - if ((parallel_blocks) == 1) { + if (parallel_blocks == 1) { return; } From ef9e1593f33df5dbc8f89b927a8a7bd9dfc9e6dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 15 Apr 2024 16:05:07 +0200 Subject: [PATCH 079/121] flush softmax exp below threshold to 0 --- ggml-cuda/fattn.cu | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 36479b2170979..f6289822e0ea0 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -3,8 +3,9 @@ #include -#define FATTN_KQ_STRIDE 256 -#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. template // D == head size __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) @@ -338,10 +339,16 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); + half2 val = KQ2[j*(kqs_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, val); + KQ2[j*(kqs_padded/2) + k] = val; } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; + KQ_max_scale[j0/nwarps] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; KQ_max[j0/nwarps] = KQ_max_new; half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); @@ -350,8 +357,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? 
mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - val = h2exp(val - KQ_max[j0/nwarps]); + const half2 diff = val - KQ_max[j0/nwarps]; + val = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &val) &= ftz_mask; KQ_rowsum_add += val; KQ2[j*(kqs_padded/2) + k] = val; } @@ -501,7 +510,10 @@ static __global__ void flash_attn_combine_results( float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); + const half diff = __low2half(meta[l]) - kqmax; + float KQ_max_scale = hexp(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; VKQ_denominator += KQ_max_scale * __high2float(meta[l]); From a5b0e2dea018cfac5ee478aac0d780eef391b30b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 15:58:21 +0200 Subject: [PATCH 080/121] store temp KQ in registers --- ggml-cuda/fattn.cu | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index f6289822e0ea0..b889cdb3b9b01 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + half2 KQ_max_new = KQ_max[j0/nwarps]; #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, val); - KQ2[j*(kqs_padded/2) + k] = val; + + KQ2_tmp[k0/WARP_SIZE] += mask ? 
mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; @@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - const half2 diff = val - KQ_max[j0/nwarps]; - val = h2exp(diff); + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &val) &= ftz_mask; - KQ_rowsum_add += val; - KQ2[j*(kqs_padded/2) + k] = val; + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; } KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); From 0bc67dd1c81c15f04096985f9a85d81b431767b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 16:22:29 +0200 Subject: [PATCH 081/121] Calculate KQ as FP32 if KQV has GGML_PREC_F32 --- ggml-cuda/fattn.cu | 286 +++++++++++++++++++++++++++++++++------------ 1 file changed, 213 insertions(+), 73 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index b889cdb3b9b01..dda344531335c 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,6 +1,7 @@ #include "common.cuh" #include "fattn.cuh" +#include #include #define FATTN_KQ_STRIDE 256 @@ -185,7 +186,8 @@ static __global__ void flash_attn_vec_ext_f16( #endif // FP16_AVAILABLE } -template // D == head size, VKQ_stride == num VKQ rows calculated in parallel +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -229,7 +231,8 @@ static __global__ void flash_attn_ext_f16( typedef nvcuda::wmma::fragment frag_a_K; typedef nvcuda::wmma::fragment frag_a_V; typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -238,12 +241,14 @@ static __global__ void flash_attn_ext_f16( // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: constexpr int D_padded = D + 8; constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
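
A note on the kqar factor introduced here, since the indexing that depends on it is scattered through the rest of this patch: KQ stays declared as a half array, but is sized and strided by kqar = sizeof(KQ_acc_t)/sizeof(half) so the same shared-memory allocation can hold float KQ tiles when the kernel is instantiated with KQ_acc_t == float. A condensed sketch of the aliasing, reusing the kernel's own names:

    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); // 1 for half, 2 for float
    __shared__ half KQ[ncols*kqs_padded*kqar];          // sized for the wider type
    float * KQ_f = (float *) KQ;                        // float view of the bytes
    half2 * KQ2  = (half2 *) KQ;                        // half2 view of the bytes
    // row j of the KQ tile always starts at half-element offset j*(kqar*kqs_padded):
    //   half  accumulation (kqar == 1): KQ2[j*(kqs_padded/2) + k]
    //   float accumulation (kqar == 2): KQ_f[j*kqs_padded + k]
    // either way the wmma loads can use kqar*kqs_padded as their leading dimension
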
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -251,14 +256,29 @@ static __global__ void flash_attn_ext_f16( frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: - constexpr int mem_KQ = ncols*kqs_padded; + constexpr int mem_KQ = ncols*kqs_padded*kqar; constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}}; - half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}}; - half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}}; + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; @@ -307,7 +327,7 @@ static __global__ void flash_attn_ext_f16( // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { - frag_c KQ_c[ncols/frag_n]; + frag_c_KQ KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); @@ -323,7 +343,7 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); } } @@ -335,45 +355,90 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; - } + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + } - half2 KQ_max_new = KQ_max[j0/nwarps]; + float KQ_max_new = KQ_max_f[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? 
mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); - } - KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; - KQ_max_scale[j0/nwarps] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; - KQ_max[j0/nwarps] = KQ_max_new; + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + KQ2_tmp[k0/WARP_SIZE] += mask ? 
mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + *((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add; + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } } __syncthreads(); @@ -386,12 +451,12 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; nvcuda::wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*kqs_padded + k, - kqs_padded); + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); } } - frag_c VKQ_c[D/VKQ_stride][ncols/frag_n]; + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll @@ -431,6 +496,14 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + #pragma unroll for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; @@ -443,7 +516,7 @@ static __global__ void flash_attn_ext_f16( for (int l = 0; l < VKQ_ratio; ++l) { VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; } - VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add; + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; } } @@ -458,14 +531,20 @@ static __global__ void flash_attn_ext_f16( } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + #pragma unroll for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; if (i0 + WARP_SIZE > D && i >= D) { break; } - half dst_val = VKQ[j_VKQ*D_padded + i]; + float dst_val = VKQ[j_VKQ*D_padded + i]; if (parallel_blocks == 1) { dst_val /= 
KQ_rowsum_j; } @@ -476,7 +555,12 @@ static __global__ void flash_attn_ext_f16( continue; } - half2 dst_meta_val = KQ_max[j0/nwarps]; + half2 dst_meta_val; + if (std::is_same::value) { + reinterpret_cast(dst_meta_val.x) = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val = KQ_max_h2[j0/nwarps]; + } reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } @@ -606,7 +690,7 @@ template void launch_fattn_vec_f16( CUDA_CHECK(cudaGetLastError()); } -template void launch_fattn_f16_impl( +template void launch_fattn_f16_impl( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { @@ -626,7 +710,7 @@ template void launc float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - flash_attn_ext_f16 + flash_attn_ext_f16 <<>> ( (const char *) Q->data, (const char *) K->data, @@ -657,21 +741,21 @@ template void launc CUDA_CHECK(cudaGetLastError()); } -template void launch_fattn_f16( +template void launch_fattn_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream ) { const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; if (4*blocks_num_pb1 < 2*nsm) { - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); return; } if (2*blocks_num_pb1 < 2*nsm) { - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); return; } - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -696,15 +780,73 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int32_t precision = KQV->op_params[1]; + + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + } else { + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, 
nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + // case 256: + // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + // break; + default: + GGML_ASSERT(false); + break; + } + } + return; + } + + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { constexpr int parallel_blocks = 4; switch (Q->ne[0]) { case 64: launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; - case 96: - launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; case 128: launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; @@ -718,23 +860,21 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { constexpr int cols_per_block = 8; constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); @@ -748,22 +888,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 80: - launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 112: - launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); 
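                // Recap of the dispatch ladder after this patch, for orientation
                // across the surrounding switch statements (a summary added in
                // editing, not part of the original diff):
                //  1. KQV->op_params[1] != GGML_PREC_DEFAULT -> FP32 KQ
                //     accumulators (KQ_acc_t == float); 16 cols/block when
                //     ne[1] <= 32 or ne[0] > 128, else 32, with head size 256
                //     at 32 cols left commented out.
                //  2. a single Q column with ne[0] % (2*WARP_SIZE) == 0 -> the
                //     vector kernel with parallel_blocks == 4 (head size 96 was
                //     dropped from this path above).
                //  3. otherwise the FP16 tensor-core kernel at 8/16/32
                //     cols/block, where launch_fattn_f16 picks parallel_blocks
                //     in {4, 2, 1} by comparing the resulting block count
                //     against the device's nsm multiprocessors.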
break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); @@ -776,22 +916,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 80: - launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 112: - launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); From 2f538b9547ec2c2c67be0d41ed96d33c141354fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 17 Apr 2024 16:29:28 +0200 Subject: [PATCH 082/121] Add __hgt2_mask implementation for CUDA 11 --- ggml-cuda/common.cuh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 989780dbce88c..ac6de643d668e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -306,6 +306,13 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +#if CUDART_VERSION < 12000 +static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) { + const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); + const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); + return mask_low | mask_high; +} +#endif // CUDART_VERSION < 12000 #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 From 87968de9a99d9820d20aeb5a15211f4ab5efde83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 17 Apr 2024 17:31:03 +0200 Subject: [PATCH 083/121] fix KQ FP32 precision fpr parallel_blocks > 1 --- ggml-cuda/fattn.cu | 48 +++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index dda344531335c..4cf2907e8d10c 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, - float * __restrict__ dst, - half2 * __restrict__ dst_meta, + float * 
__restrict__ dst, + float2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16( if (parallel_blocks == 1 || tid != 0) { return; } - dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum); #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE @@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, - float * __restrict__ dst, - half2 * __restrict__ dst_meta, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16( continue; } - half2 dst_meta_val; + float2 dst_meta_val; if (std::is_same::value) { - reinterpret_cast(dst_meta_val.x) = KQ_max_f[j0/nwarps]; + dst_meta_val.x = KQ_max_f[j0/nwarps]; } else { - dst_meta_val = KQ_max_h2[j0/nwarps]; + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); } - reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta_val.y = KQ_rowsum_j; dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else @@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const half2 * __restrict__ VKQ_meta, + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, float * __restrict__ dst) { #if FP16_AVAILABLE VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; @@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results( const int tid = threadIdx.x; __builtin_assume(tid < D); - __shared__ half2 meta[parallel_blocks]; - if (tid < parallel_blocks) { - meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid]; + __shared__ float2 meta[parallel_blocks]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; } __syncthreads(); - half kqmax = __low2half(meta[0]); + float kqmax = meta[0].x; #pragma unroll for (int l = 1; l < parallel_blocks; ++l) { - kqmax = __hmax(kqmax, __low2half(meta[l])); + kqmax = max(kqmax, meta[l].x); } float VKQ_numerator = 0.0f; float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - const half diff = __low2half(meta[l]) - kqmax; - float KQ_max_scale = hexp(diff); - const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + const float diff = meta[l].x - kqmax; + const float KQ_max_scale = expf(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; - VKQ_denominator += KQ_max_scale * __high2float(meta[l]); + VKQ_denominator += KQ_max_scale * meta[l].y; } dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; @@ -643,8 +643,8 @@ template void launch_fattn_vec_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { - ggml_cuda_pool_alloc dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); if 
(parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -694,8 +694,8 @@ template dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); From 260cdb2d082d1658d1c6b693c4cbf77754873886 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 14:28:19 +0300 Subject: [PATCH 084/121] llama-bench : add -fa,--flash-attn arg --- examples/llama-bench/llama-bench.cpp | 30 +++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8b532c8b6a98a..95c3095dd04da 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -174,6 +174,7 @@ struct cmd_params { std::vector split_mode; std::vector main_gpu; std::vector no_kv_offload; + std::vector flash_attn; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = { /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, /* no_kv_offload */ {false}, + /* flash_attn */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); + printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); @@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end()); + } else if (arg == "-fa" || arg == "--flash-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = split(argv[i], split_delim); + params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } + if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -498,6 +509,7 @@ struct cmd_params_instance { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -532,6 +544,7 @@ struct 
cmd_params_instance { cparams.type_k = type_k; cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; + cparams.flash_attn = flash_attn; cparams.embeddings = embeddings; return cparams; @@ -554,6 +567,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tk : params.type_k) for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) + for (const auto & fa : params.flash_attn) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -572,6 +586,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -596,6 +611,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -633,6 +649,7 @@ struct test { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -657,6 +674,7 @@ struct test { split_mode = inst.split_mode; main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; + flash_attn = inst.flash_attn; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -731,7 +749,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", + "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -753,7 +771,7 @@ struct test { } if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -787,7 +805,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -955,6 +973,9 @@ struct markdown_printer : public printer { if (field == "no_kv_offload") { return "nkvo"; } + if (field == "flash_attn") { + return "fa"; + } if (field == "use_mmap") { return "mmap"; } @@ -1001,6 +1022,9 @@ struct markdown_printer : public printer { if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) { fields.emplace_back("no_kv_offload"); } + if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { + fields.emplace_back("flash_attn"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } From 105332cc17b8bd8f3989606b59489ed85eefe04f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: 
Thu, 18 Apr 2024 14:33:07 +0300 Subject: [PATCH 085/121] metal : add BS=1 kernel for flash attention (#6508) * metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel --- ggml-metal.m | 119 ++++++++++++++------ ggml-metal.metal | 274 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 361 insertions(+), 32 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e8613dbee04d2..407f94eb224cf 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -183,6 +183,8 @@ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -621,12 +623,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2563,19 +2567,32 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ASSERT(false && "add template specialization for this size"); - } + if (ne01 > 1 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } else { + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } } // TODO: extend if necessary @@ -2609,24 +2626,62 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + // half8x8 kernel + if (ne01 > 1 || (ne00%128 != 0)) { + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); - // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? 
MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + // require power of 2 + //{ + // int64_t nsgm = 1; + // while (nsgm < nsg) { + // nsgm *= 2; + // } + // GGML_ASSERT(nsg == nsgm); + //} + + const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index be47db86ebbd3..de6072e93470a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2494,6 +2494,280 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; +#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. 
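The HALF_MAX_HALF comment above can be verified in isolation: when the running softmax maximum M is initialized to -INFINITY, the first rescale computes exp(m - M) = exp(-inf - (-inf)), and in IEEE arithmetic (-inf) - (-inf) is NaN, which then poisons every accumulator it touches. A finite sentinel near the bottom of the half range keeps that rescale at a harmless exp(0) = 1. A standalone C++ check (plain host code, nothing Metal-specific):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float neg_inf  = -INFINITY;
        const float sentinel = -65504.0f/2; // negated HALF_MAX_HALF: a finite "minus infinity"

        // -inf - (-inf) is NaN, so the first exp(m - M) rescale prints nan:
        printf("exp(-inf - -inf)         = %f\n", std::exp(neg_inf - neg_inf));

        // with a finite sentinel the same rescale is exp(0) = 1:
        printf("exp(sentinel - sentinel) = %f\n", std::exp(sentinel - sentinel));
        return 0;
    }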
+ +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short D8 = D/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + 1); // shared memory per simdgroup in (half) + + const short T = D + nsg*SH; // shared memory size per query in (half) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S = { 0.0h }; + half M = { -HALF_MAX_HALF }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); 
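In the Q*K^T loop that follows, each simdgroup lane accumulates a partial half4 dot product, and the partials are folded together with simd_shuffle_down at offsets 16, 8, 4, 2, 1 so that lane 0 ends up holding the sum over all 32 lanes. A scalar C++ model of that butterfly reduction (an illustration of the pattern only, not the Metal code):

    #include <cstdio>

    int main() {
        const int NW = 32;          // simdgroup width
        float lane[NW];
        for (int t = 0; t < NW; ++t) {
            lane[t] = (float) t;    // stand-in for per-lane partial dot products
        }

        // simd_shuffle_down(x, off) reads the value held by lane t+off; all
        // lanes update simultaneously, which the ascending-t loop reproduces
        // because lane[t+off] has not been overwritten yet within a round.
        for (int off = NW/2; off > 0; off >>= 1) {
            for (int t = 0; t + off < NW; ++t) {
                // lanes past NW-off would read undefined data on hardware;
                // only lane 0's result is consumed, so they are skipped here
                lane[t] += lane[t + off];
            }
        }

        printf("lane 0 holds %g (expected %g)\n", lane[0], 31.0f*32.0f/2.0f);
        return 0;
    }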
+ + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + half4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += mq[i] * mk; + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask + if (tiisg == 0) { + half4 mm = mp4[ic/4 + cc]; + mqk = mqk*scale + mm; + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const half m = M; + const half s = ss[p]; + + M = simd_max(max(M, s)); + + const half ms = exp(m - M); + const half vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const half S0 = ss[ 0]; + const half S1 = ss[r*SH + 0]; + + const half M0 = ss[ 1]; + const half M1 = ss[r*SH + 1]; + + const half M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + const half S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const half S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half 
* dst, From c16a7c26882669a0d2ed7ef592cde3b0227248f2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 20:08:52 +0300 Subject: [PATCH 086/121] metal : use F32 attention accumulators --- ggml-metal.m | 15 +---- ggml-metal.metal | 156 +++++++++++++++++++++++------------------------ ggml.c | 3 +- 3 files changed, 81 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 407f94eb224cf..f4a831b52a9f9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); @@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); int64_t nsg = 1; @@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute( } nsg /= 2; - // require power of 2 - //{ - // int64_t nsgm = 1; - // while (nsgm < nsg) { - // nsgm *= 2; - // } - // GGML_ASSERT(nsg == nsgm); - //} - - const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index de6072e93470a..36b87b2f0750a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2169,12 +2169,13 @@ kernel void kernel_flash_attn_ext_f16( const short NW = N_SIMDWIDTH; const short SH = (C + Q); // shared memory per simdgroup in (half) - const short T = D + nsg*SH; // shared memory size per query in (half) + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) const short T4 = T/4; // shared memory size per query in (half4) - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[Q8][D8]; @@ -2202,15 +2203,15 @@ kernel void kernel_flash_attn_ext_f16( // zero out shared memory SH for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*T + i] = 0.0h; 
+ ss[j*TF + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { [0 ... Q-1] = 0.0h }; - half M[Q] = { [0 ... Q-1] = -INFINITY }; + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; // assume K and V are same shape const short ne22 = ne12; @@ -2248,7 +2249,7 @@ kernel void kernel_flash_attn_ext_f16( device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); + simdgroup_float8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2261,9 +2262,9 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk[Q8]; + simdgroup_float8x8 mqk[Q8]; for (short j = 0; j < Q8; ++j) { - mqk[j] = make_filled_simdgroup_matrix(0.h); + mqk[j] = make_filled_simdgroup_matrix(0.h); } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2283,48 +2284,48 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false); } } } // used to detect blocks full of -INF - half smax = -INFINITY; + float smax = -INFINITY; // online softmax if (C == 32) { - half ms[Q]; + float ms[Q]; for (short j = 0; j < Q; ++j) { const short p = tiisg; - const half m = M[j]; - const half s = ss[j*T + p]; + const float m = M[j]; + const float s = ss[j*TF + p]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss[j*TF + p] = vs; } // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; + ss[tiisg*TF + C + tiisg] = ms[tiisg]; } } else { - half ms[Q]; + float ms[Q]; for (short j = 0; j < Q; ++j) { - const half m = M[j]; + const float m = M[j]; for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + const float s = ss[j*TF + p]; smax = max(smax, s); M[j] = max(M[j], s); @@ -2333,20 +2334,20 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(smax); M[j] = simd_max(M[j]); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + ms[j] = exp(m - M[j]); // local sum - half ls = 0.0h; + float ls = 0.0h; for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + const float s = ss[j*TF + p]; - const half vs = s == -INFINITY ? 
0.0h : exp(s - M[j]); + const float vs = exp(s - M[j]); ls += vs; // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss[j*TF + p] = vs; } S[j] = S[j]*ms[j] + simd_sum(ls); @@ -2354,7 +2355,7 @@ kernel void kernel_flash_attn_ext_f16( // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; + ss[tiisg*TF + C + tiisg] = ms[tiisg]; } } @@ -2365,8 +2366,8 @@ kernel void kernel_flash_attn_ext_f16( // O = diag(ms)*O for (short j = 0; j < Q8; ++j) { - simdgroup_half8x8 mm; - simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false); for (short i = 0; i < D8; ++i) { simdgroup_multiply(lo[j][i], mm, lo[j][i]); @@ -2383,8 +2384,8 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); for (short j = 0; j < Q8; ++j) { - simdgroup_half8x8 mv; - simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false); simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); } @@ -2396,16 +2397,16 @@ kernel void kernel_flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (short j = 0; j < Q; ++j) { if (tiisg == 0) { - ss[j*T + 0] = S[j]; - ss[j*T + 1] = M[j]; + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; } } } // reduce the warps sequentially for (short sg = 1; sg < nsg; ++sg) { - half S = { 0.0h }; - half M = { -INFINITY }; + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2423,36 +2424,36 @@ kernel void kernel_flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; M = max(M0, M1); - const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); - const half ms1 = M1 == -INFINITY ? 
0.0h : exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); S = S0*ms0 + S1*ms1; if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; } } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 t; - simdgroup_half8x8 ms0; - simdgroup_half8x8 ms1; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; - simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); - simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); + simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false); + simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false); for (short i = 0; i < D8; ++i) { simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); @@ -2478,7 +2479,7 @@ kernel void kernel_flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const half S = ss[j*T + 0]; + const float S = ss[j*TF + 0]; for (short i = tiisg; i < D4; i += NW) { dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; @@ -2494,8 +2495,6 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; -#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. - template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( device const char * q, @@ -2539,18 +2538,16 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iq1 = tgpig[0]; const short D4 = D/4; - const short D8 = D/8; const short NW = N_SIMDWIDTH; const short SH = (C + 1); // shared memory per simdgroup in (half) - const short T = D + nsg*SH; // shared memory size per query in (half) - const short T4 = T/4; // shared memory size per query in (half4) + const short T = D + 2*nsg*SH; // shared memory size per query in (half) - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 - threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) half4 lo[D4/NW]; @@ -2579,8 +2576,8 @@ kernel void kernel_flash_attn_ext_vec_f16( 
threadgroup_barrier(mem_flags::mem_threadgroup); { - half S = { 0.0h }; - half M = { -HALF_MAX_HALF }; + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; // assume K and V are same shape const short ne22 = ne12; @@ -2628,7 +2625,7 @@ kernel void kernel_flash_attn_ext_vec_f16( { #pragma unroll for (short cc = 0; cc < C/4; ++cc) { - half4 mqk = { 0.0h }; + float4 mqk = { 0.0h }; device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2642,7 +2639,7 @@ kernel void kernel_flash_attn_ext_vec_f16( mk[2] = pk4[i + 2*(nb11/8)]; mk[3] = pk4[i + 3*(nb11/8)]; - mqk += mq[i] * mk; + mqk += (float4) (mq[i] * mk); } // reduce the results from the threads in the simdgroup @@ -2654,7 +2651,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask if (tiisg == 0) { - half4 mm = mp4[ic/4 + cc]; + float4 mm = (float4) mp4[ic/4 + cc]; mqk = mqk*scale + mm; ss4[cc] = mqk; @@ -2666,13 +2663,13 @@ kernel void kernel_flash_attn_ext_vec_f16( { const short p = tiisg; - const half m = M; - const half s = ss[p]; + const float m = M; + const float s = ss[p]; M = simd_max(max(M, s)); - const half ms = exp(m - M); - const half vs = exp(s - M); + const float ms = exp(m - M); + const float vs = exp(s - M); S = S*ms + simd_sum(vs); @@ -2696,6 +2693,7 @@ kernel void kernel_flash_attn_ext_vec_f16( #pragma unroll for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; @@ -2724,18 +2722,18 @@ kernel void kernel_flash_attn_ext_vec_f16( // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { - const half S0 = ss[ 0]; - const half S1 = ss[r*SH + 0]; + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; - const half M0 = ss[ 1]; - const half M1 = ss[r*SH + 1]; + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; - const half M = max(M0, M1); + const float M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); - const half S = S0*ms0 + S1*ms1; + const float S = S0*ms0 + S1*ms1; if (tiisg == 0) { ss[0] = S; @@ -2756,7 +2754,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { - const half S = ss[0]; + const float S = ss[0]; for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; diff --git a/ggml.c b/ggml.c index f50cb948daeab..f2bbfa6f2273b 100644 --- a/ggml.c +++ b/ggml.c @@ -14882,12 +14882,13 @@ static void ggml_compute_forward_flash_attn_ext( struct ggml_tensor * dst) { switch (dst->op_params[1]) { case GGML_PREC_DEFAULT: + case GGML_PREC_F32: { + // uses F32 accumulators ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; default: { - // TODO: implement F32 precision GGML_ASSERT(false); } break; } From 9ca869876eddfb0c51d3b80440a11abdd6a3fe18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 21:41:32 +0300 Subject: [PATCH 087/121] batched-bench : add fattn arg --- examples/batched-bench/batched-bench.cpp | 28 ++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 1e34de620a41b..2924d8116f44f 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -32,7 +32,7 @@ int main(int argc, char 
** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] \n" , argv[0]); + printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] \n" , argv[0]); printf(" , and PL are comma-separated lists of numbers without spaces\n\n"); printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]); return 1 ; @@ -41,6 +41,7 @@ int main(int argc, char ** argv) { int n_kv_max = 2048; int n_batch = 2048; int n_ubatch = 512; + bool flash_attn = false; int is_pp_shared = 0; int n_gpu_layers = 0; @@ -66,23 +67,27 @@ int main(int argc, char ** argv) { } if (argc >= 6) { - is_pp_shared = std::atoi(argv[5]); + flash_attn = std::atoi(argv[5]); } if (argc >= 7) { - n_gpu_layers = std::atoi(argv[6]); + is_pp_shared = std::atoi(argv[6]); } if (argc >= 8) { - n_pp = parse_list(argv[7]); + n_gpu_layers = std::atoi(argv[7]); } if (argc >= 9) { - n_tg = parse_list(argv[8]); + n_pp = parse_list(argv[8]); } if (argc >= 10) { - n_pl = parse_list(argv[9]); + n_tg = parse_list(argv[9]); + } + + if (argc >= 11) { + n_pl = parse_list(argv[10]); } // init LLM @@ -108,10 +113,11 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = n_batch; - ctx_params.n_ubatch = n_ubatch; + ctx_params.seed = 1234; + ctx_params.n_ctx = n_kv_max; + ctx_params.n_batch = n_batch; + ctx_params.n_ubatch = n_ubatch; + ctx_params.flash_attn = flash_attn; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; @@ -169,7 +175,7 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); LOG_TEE("\n"); LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s"); From 74d57f95136c8391756c8144ad12a517901bd2e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 13:49:57 +0300 Subject: [PATCH 088/121] llama : simplify llama_build_kv_store ggml-ci --- llama.cpp | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4034d25181aaa..d828dd786e679 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5963,29 +5963,27 @@ static void llm_build_kv_store( (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! + // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using FLASH attention !! 
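The simplification above rests on the two V-cache layouts: with flash attention enabled, V rows are appended to the cache as-is (token-major), while the default path stores V transposed so that each embedding channel is contiguous across the whole context for the subsequent mat-mul. A toy index calculation under assumed sizes (not the ggml view API) showing where one element lands in each layout:

    #include <cstdio>

    int main() {
        // illustrative sizes, not taken from a real model
        const int n_ctx        = 8;  // KV cache slots
        const int n_embd_v_gqa = 4;  // V row size

        const int tok = 5;           // kv_head: the slot the new token is written to
        const int ch  = 2;           // one embedding channel of that token

        // flash-attention view: row-major [n_ctx][n_embd_v_gqa]
        const int off_rowmajor   = tok*n_embd_v_gqa + ch;

        // default view: transposed [n_embd_v_gqa][n_ctx]
        const int off_transposed = ch*n_ctx + tok;

        printf("token %d, channel %d -> row-major %d, transposed %d\n",
               tok, ch, off_rowmajor, off_transposed);
        return 0;
    }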
- struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, - (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); - cb(v_cache_view, "v_cache_view", il); + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); - } else { - // compute the transposed [n_tokens, n_embd] V matrix - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); - cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = nullptr; - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + if (cparams.flash_attn) { + v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)); + } else { + // note: the V cache is transposed when not using flash attention + v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); + v_cur = ggml_transpose(ctx, v_cur); } + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); } static struct ggml_tensor * llm_build_norm( @@ -6169,11 +6167,6 @@ static struct ggml_tensor * llm_build_kqv( if (model.arch == LLM_ARCH_PHI2) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); } else { @@ -14879,6 +14872,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); From e32b281743c98411d1eaf22823d3eea2023c3502 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 14:04:56 +0300 Subject: [PATCH 089/121] llama : adapt build_olmo to changes --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index add75818b2bdc..a7ce50dd30efa 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10287,9 +10287,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, nullptr, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { From 703c6e6528d184eaf6ea0bed21cd350e095c63f4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 
2024 14:20:41 +0300 Subject: [PATCH 090/121] ggml : fix arm fp16 store on windows --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 4c749ecdafa5b..76ca79e660aa0 100644 --- a/ggml.c +++ b/ggml.c @@ -963,7 +963,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -989,7 +989,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL From 97eaece7d6537b97a01d30a71a7ce4511fa97e25 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 15:30:27 +0300 Subject: [PATCH 091/121] metal : clean-up --- ggml-metal.m | 353 ++++++++++++++++++++++++++------------------------- 1 file changed, 178 insertions(+), 175 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 44a3d05f354e5..68eb49d5b1baa 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -475,175 +475,175 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, 
mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
}

[metal_library release];
@@ -2560,7 +2560,9 @@ static enum ggml_status ggml_metal_graph_compute(

id<MTLComputePipelineState> pipeline = nil;

- if (ne01 > 1 || (ne00%128 != 0)) {
+ bool use_vec_kernel = false;
+
+ if (ne01 >= 4 || (ne00%128 != 0)) {
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
@@ -2576,6 +2578,8 @@ static enum ggml_status ggml_metal_graph_compute(
}
}
} else {
+ use_vec_kernel = true;
+
switch (ne00) {
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
@@ -2588,7 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute(
}
}

- // TODO: extend if necessary
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2619,8 +2622,8 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];

- // half8x8 kernel
- if (ne01 > 1 || (ne00%128 != 0)) {
+ if (!use_vec_kernel) {
+ // half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
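An aside on the kernel-selection hunk above: the new use_vec_kernel flag records which of the two flash-attention kernels was chosen so the dispatch code later in the function does not re-evaluate the condition. Below is a minimal standalone C sketch of the same heuristic; the helper name fa_use_vec_kernel and the main driver are invented for illustration and are not part of the patch.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// batched/prompt processing (ne01 >= 4 queries), or a head size the vector
// kernel does not handle (only multiples of 128 are compiled), selects the
// simdgroup-matrix "half8x8" kernel; single-token decoding with head size
// 128 or 256 selects the vector kernel
static bool fa_use_vec_kernel(int64_t ne01 /* n queries */, int64_t ne00 /* head size */) {
    return !(ne01 >= 4 || (ne00 % 128 != 0));
}

int main(void) {
    printf("%d\n", fa_use_vec_kernel(1, 128)); // 1: single query -> vec kernel
    printf("%d\n", fa_use_vec_kernel(8, 128)); // 0: batched      -> half8x8 kernel
    printf("%d\n", fa_use_vec_kernel(1,  80)); // 0: head size 80 -> half8x8 only
    return 0;
}

The matrix kernel operates on 8x8 simdgroup tiles, so it only pays off once several queries are processed together, which is presumably why the threshold moves from more than one query to at least four.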
@@ -2635,7 +2638,7 @@ static enum ggml_status ggml_metal_graph_compute(

//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
@@ -2926,7 +2929,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
return NULL;
}

- ggml_backend_metal_log_allocated_size(device, size_aligned);
+ //ggml_backend_metal_log_allocated_size(device, size_aligned);

return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
}

From 1a88565b4489381923aa0c9a6741badfb6766b23 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 19 Apr 2024 15:52:49 +0300
Subject: [PATCH 092/121] metal : clean-up kernel code

---
 ggml-metal.metal | 142 ++++++++++++++--------------------------------
 1 file changed, 43 insertions(+), 99 deletions(-)

diff --git a/ggml-metal.metal b/ggml-metal.metal
index 1ed5632b4e0ec..32cbef9dca103 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2121,7 +2121,7 @@ typedef void (flash_attn_ext_f16_t)(
ushort sgitg[[simdgroup_index_in_threadgroup]]);

// ref: https://arxiv.org/pdf/2307.08691.pdf
-template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_f16(
device const char * q,
device const char * k,
@@ -2178,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix

// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- simdgroup_half8x8 lo[Q8][D8];
+ simdgroup_half8x8 lo[D8];

// load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) {
@@ -2194,10 +2194,8 @@ kernel void kernel_flash_attn_ext_f16(
}

// zero out lo
- for (short j = 0; j < Q8; ++j) {
- for (short i = 0; i < D8; ++i) {
- lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
- }
+ for (short i = 0; i < D8; ++i) {
+ lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
}

// zero out shared memory SH
@@ -2229,20 +2227,18 @@ kernel void kernel_flash_attn_ext_f16(
const short rv3 = ne03/ne23;

// k indices
- const short ik2 = iq2 / rk2;
- const short ik3 = iq3 / rk3;
+ const short ik2 = iq2/rk2;
+ const short ik3 = iq3/rk3;

// v indices
- const short iv2 = iq2 / rv2;
- const short iv3 = iq3 / rv3;
+ const short iv2 = iq2/rv2;
+ const short iv3 = iq3/rv3;

// load the queries from shared memory into local memory
- simdgroup_half8x8 mq[Q8][D8];
+ simdgroup_half8x8 mq[D8];

- for (short j = 0; j < Q8; ++j) {
- for (short i = 0; i < D8; ++i) {
- simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
- }
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_load(mq[i], sq + i*8, T);
}

// pointer to the mask
@@ -2262,10 +2258,7 @@ kernel void kernel_flash_attn_ext_f16(
// Q*K^T
{
for (short cc = 0; cc < C/8; ++cc) {
- simdgroup_float8x8 mqk[Q8];
- for (short j = 0; j < Q8; ++j) {
- mqk[j] = make_filled_simdgroup_matrix<float, 8>(0.h);
- }
+ simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);

device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));

@@ -2273,19 +2266,15 @@
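// each simdgroup now owns a single 8-row block of queries, so the Q8-indexed
// accumulator array collapses to one simdgroup_float8x8; the K tile below is
// loaded with the transpose flag, making simdgroup_multiply_accumulate(mqk,
// mq[i], mk, mqk) accumulate one 8x8 tile of Q*K^T per step along the head
// dimension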
simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - for (short j = 0; j < Q8; ++j) { - simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); - } + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } // mqk = mqk*scale + mask - for (short j = 0; j < Q8; ++j) { - simdgroup_half8x8 mm; - simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); - simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); - simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false); - } + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); } } @@ -2293,7 +2282,7 @@ kernel void kernel_flash_attn_ext_f16( float smax = -INFINITY; // online softmax - if (C == 32) { + { float ms[Q]; for (short j = 0; j < Q; ++j) { @@ -2314,45 +2303,6 @@ kernel void kernel_flash_attn_ext_f16( ss[j*TF + p] = vs; } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*TF + C + tiisg] = ms[tiisg]; - } - } else { - float ms[Q]; - - for (short j = 0; j < Q; ++j) { - const float m = M[j]; - - for (short p = tiisg; p < C; p += NW) { - const float s = ss[j*TF + p]; - - smax = max(smax, s); - M[j] = max(M[j], s); - } - - smax = simd_max(smax); - M[j] = simd_max(M[j]); - - ms[j] = exp(m - M[j]); - - // local sum - float ls = 0.0h; - - for (short p = tiisg; p < C; p += NW) { - const float s = ss[j*TF + p]; - - const float vs = exp(s - M[j]); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; - } - - S[j] = S[j]*ms[j] + simd_sum(ls); - } - // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { ss[tiisg*TF + C + tiisg] = ms[tiisg]; @@ -2365,12 +2315,12 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - for (short j = 0; j < Q8; ++j) { + { simdgroup_float8x8 mm; - simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false); + simdgroup_load(mm, ss + C, TF, 0, false); for (short i = 0; i < D8; ++i) { - simdgroup_multiply(lo[j][i], mm, lo[j][i]); + simdgroup_multiply(lo[i], mm, lo[i]); } } @@ -2383,12 +2333,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 mk; simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - for (short j = 0; j < Q8; ++j) { - simdgroup_float8x8 mv; - simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false); + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); - simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); - } + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); } } } @@ -2412,10 +2360,8 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (short j = 0; j < Q8; ++j) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - } + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } @@ -2447,19 +2393,19 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short j = 0; j < Q8; ++j) { + { simdgroup_half8x8 t; simdgroup_float8x8 ms0; simdgroup_float8x8 ms1; - simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false); - simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false); + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); for (short i = 0; i < D8; ++i) { - simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); + 
simdgroup_load (t, sq + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); - simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); } } } @@ -2467,10 +2413,8 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (short j = 0; j < Q8; ++j) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - } + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } @@ -2488,14 +2432,14 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; -template // head size, queries per threadgroup, cache items per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( device const char * q, device const char * k, @@ -2539,7 +2483,7 @@ kernel void kernel_flash_attn_ext_vec_f16( const short D4 = D/4; const short NW = N_SIMDWIDTH; - const short SH = (C + 1); // shared memory per simdgroup in (half) + const short SH = (C + Q); // shared memory per simdgroup in (half) const short T = D + 2*nsg*SH; // shared memory size per query in (half) @@ -2763,8 +2707,8 @@ kernel void kernel_flash_attn_ext_vec_f16( } } -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; kernel void kernel_cpy_f16_f16( device const half * src0, From bc346166f96b13cf62291fe8a0212c98e561645c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 17:24:52 +0300 Subject: [PATCH 093/121] metal : minor --- ggml-metal.m | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml-metal.m 
b/ggml-metal.m
index 68eb49d5b1baa..aa22a24f01c38 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -451,7 +451,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
}

/*
- GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
(int) kernel->pipeline.threadExecutionWidth); \
*/
@@ -461,7 +461,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
[metal_function release]; \
- GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
(int) kernel->pipeline.threadExecutionWidth); \
if (error) { \
@@ -470,7 +470,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
return NULL; \
} \
} else { \
- GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
+ GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
}

// simd_sum and simd_max requires MTLGPUFamilyApple7

From 52945429eb43be1c8b89a77719513cd088e29586 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 19 Apr 2024 17:38:28 +0300
Subject: [PATCH 094/121] tests : remove benchmarks

ggml-ci
---
 tests/test-backend-ops.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 1cc5e53bd2442..2317b8b7e1fab 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -15,6 +15,9 @@
#include <thread>
#include <vector>

+// TODO: remove before merging
+//#define TMP_ATTN_BENCH
+
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
// static RNG initialization (revisit if n_threads stops being constant)
static const size_t n_threads = std::thread::hardware_concurrency();
@@ -571,7 +574,7 @@ struct test_case {

// duplicate the op
size_t target_size = ggml_backend_is_cpu(backend) ?
1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#if 0 +#ifndef TMP_ATTN_BENCH for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } @@ -1513,8 +1516,8 @@ struct test_flash_attn_ext : public test_case { } }; +#ifdef TMP_ATTN_BENCH // ATTN -// TODO: this is temporary until the FA branch is merged struct test_attn : public test_case { const int64_t hs; // head size const int64_t nh; // num heads @@ -1555,6 +1558,7 @@ struct test_attn : public test_case { return cur; } }; +#endif enum llm_norm_type { LLM_NORM, @@ -2220,7 +2224,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); -#if 1 +#ifdef TMP_ATTN_BENCH for (int hs : { 128, 256, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { @@ -2232,11 +2236,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } #else - for (int hs : { 128, }) { + for (int hs : { 64, 80, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, 512 }) { - test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + for (int nb : { 1, 2, 4, 8, }) { test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); } } From 3badef1fe143764c1298a45cd0986a737eba0a8d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 17:45:08 +0300 Subject: [PATCH 095/121] ggml : fix avx512 const correctness ggml-ci --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 76ca79e660aa0..1d88e0da246e9 100644 --- a/ggml.c +++ b/ggml.c @@ -1058,7 +1058,7 @@ do { \ // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F // so F16C guard isn't required -#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x))) +#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) From 871fcb6e101dc3fdc92fce273a7c932d16b72d8a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Apr 2024 18:03:56 +0300 Subject: [PATCH 096/121] ggml : fix soft_max with bias on CPU ggml-ci --- ggml.c | 4 ++-- tests/test-backend-ops.cpp | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml.c b/ggml.c index 1d88e0da246e9..41557ab6766c5 100644 --- a/ggml.c +++ b/ggml.c @@ -12410,7 +12410,7 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - float * pos = src2 ? (float *) src2->data : src0->data; + ggml_fp16_t * pos = src2 ? (ggml_fp16_t *) src2->data : src0->data; for (int i1 = ir0; i1 < ir1; i1++) { float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); @@ -12433,7 +12433,7 @@ static void ggml_compute_forward_soft_max_f32( const float slope = h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*pos[i]; + wp[i] = wp[i] + slope*ggml_fp16_to_fp32(pos[i]); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2317b8b7e1fab..ce39dadbb61e3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1103,6 +1103,12 @@ struct test_soft_max : public test_case { return VARS_TO_STR5(type, ne, mask, scale, max_bias); } + // the 1024 test with bias occasionally fails: + // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL + virtual double max_nmse_err() override { + return 1e-6; + } + test_soft_max(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, bool mask = false, @@ -2180,7 +2186,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) { for (int64_t ne1 : {16, 1024}) { - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias)); } } From a39217d4285b44c1b916c949ef6581e82f3c3ef3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 12:50:10 +0300 Subject: [PATCH 097/121] common : print --flash-attn in help --- common/common.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/common.cpp b/common/common.cpp index fbff8cf13effc..a29c451aa94fc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1482,6 +1482,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); + printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); printf(" --image IMAGE_FILE path to an image file. 
use with multimodal models\n"); if (llama_supports_mlock()) { From cb76d747d166dd9bbd028d666b1e8e53fe10efa0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 12:50:26 +0300 Subject: [PATCH 098/121] ggml : fix num dimensions in ggml_flash_attn_ext --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 41557ab6766c5..b1c76e6789749 100644 --- a/ggml.c +++ b/ggml.c @@ -6321,7 +6321,7 @@ struct ggml_tensor * ggml_flash_attn_ext( // permute(0, 2, 1, 3) int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); float params[] = { scale }; ggml_set_op_params(result, params, sizeof(params)); From c11d05fec03a9190cabe8f1afb58a381811d5e21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 12:50:41 +0300 Subject: [PATCH 099/121] llama : force disable flash attention for incompatible models --- llama.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index a7ce50dd30efa..a4b00e7ff3ddf 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1823,7 +1823,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool need_kq_pos = false; + bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -6311,6 +6311,8 @@ static struct ggml_tensor * llm_build_kqv( GGML_UNUSED(model); GGML_UNUSED(n_ctx); + // note: if this assert triggers, then some check has failed earlier + // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); // split cached v into n_head heads (not transposed) @@ -15114,6 +15116,16 @@ struct llama_context * llama_new_context_with_model( } } + if (cparams.flash_attn && hparams.need_kq_pos) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); + cparams.flash_attn = false; + } + + if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); + cparams.flash_attn = false; + } + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } From f725ca90fb77f32e52ee2c204708560c952fdf78 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 13:46:23 +0300 Subject: [PATCH 100/121] ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci --- ggml-cuda/softmax.cu | 46 +++++++++++++++++++++++++++++--------- ggml-metal.m | 29 ++++++++++++++++++------ ggml-metal.metal | 18 +++++++++++---- ggml.c | 38 +++++++++++++++++++++++-------- llama.cpp | 4 ++-- tests/test-backend-ops.cpp | 4 ++-- 6 files changed, 105 insertions(+), 34 deletions(-) diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 8f6dca4d0f9bf..c0557db78df8d 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -1,7 +1,17 @@ #include "softmax.cuh" -template -static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +template +static __device__ __forceinline__ float t2f32(T val) { + return (float) val; +} + +template 
<>
+__device__ float __forceinline__ t2f32<half>(half val) {
+ return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

const int tid = threadIdx.x;
@@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;

- const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
+ const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);

vals[col] = val;
max_val = max(max_val, val);
@@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
}
}

-static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const half * mask, const half * p
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
const float * src0_d = (const float *)src0->data;
- const half * src1_d = src1 ?
(const void *)src1->data : nullptr; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // positions tensor - half * src2_dd = nullptr; + void * src2_d = nullptr; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { - src2_dd = (half *)src2->data; + src2_d = (void *)src2->data; } - soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + + if (use_f16) { + const half * src1_dd = (const half *)src1_d; + const half * src2_dd = (const half *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } else { + const float * src1_dd = (const float *)src1_d; + const float * src2_dd = (const float *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } } diff --git a/ggml-metal.m b/ggml-metal.m index aa22a24f01c38..1903791f1f9e3 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -46,8 +46,10 @@ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -492,8 +494,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); @@ -1346,22 +1350,33 @@ static enum ggml_status ggml_metal_graph_compute( } break; case 
GGML_OP_SOFT_MAX:
            {
-                GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
+                GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+                GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
                 int nth = 32; // SIMD width
 
                 id<MTLComputePipelineState> pipeline = nil;
 
+                const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
                 if (ne00%4 == 0) {
                     while (nth < ne00/4 && nth < 256) {
                         nth *= 2;
                     }
-                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
+                    if (use_f16) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+                    }
                 } else {
                     while (nth < ne00 && nth < 1024) {
                         nth *= 2;
                     }
-                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
+                    if (use_f16) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+                    }
                 }
 
                 float scale;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 32cbef9dca103..3d4276ae02b9e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -352,6 +352,7 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
+template<typename T>
 kernel void kernel_soft_max(
         device const char * src0,
         device const char * src1,
@@ -376,8 +377,8 @@ kernel void kernel_soft_max(
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
-    device const half * ppos  = src2 != src0 ? (device const half *) src2            : nullptr;
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+    device const T * ppos  = src2 != src0 ? (device const T *) src2            : nullptr;
     device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     float slope = 0.0f;
@@ -456,6 +457,7 @@ kernel void kernel_soft_max(
     }
 }
 
+template<typename T>
 kernel void kernel_soft_max_4(
         device const char * src0,
         device const char * src1,
@@ -480,8 +482,8 @@ kernel void kernel_soft_max_4(
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-    device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
-    device const half4 * ppos  = src2 != src0 ? (device const half4 *) src2              : nullptr;
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+    device const T * ppos  = src2 != src0 ? 
(device const T *) src2 : nullptr;
     device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
 
     float slope = 0.0f;
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
     }
 }
 
+typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device float * dst,
diff --git a/ggml.c b/ggml.c
index b1c76e6789749..bc19f35bff84f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5473,7 +5473,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
     GGML_ASSERT(ggml_is_contiguous(a));
 
     if (mask) {
-        GGML_ASSERT(mask->type == GGML_TYPE_F16);
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(ggml_is_matrix(mask));
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
@@ -5481,10 +5481,14 @@
 
     if (pos) {
         GGML_ASSERT(ggml_is_vector(pos));
-        GGML_ASSERT(pos->type == GGML_TYPE_F16);
+        GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
         GGML_ASSERT(pos->ne[0] == a->ne[0]);
     }
 
+    if (pos && mask) {
+        GGML_ASSERT(pos->type == mask->type);
+    }
+
     if (max_bias > 0.0f) {
         GGML_ASSERT(pos);
     }
@@ -12410,20 +12414,30 @@ static void ggml_compute_forward_soft_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-    ggml_fp16_t * pos = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+    ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+    float       * pos_f32 = src2 ? (float       *) src2->data : src0->data;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
 
         // broadcast the mask across rows
-        ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
 
         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
-        if (mp) {
-            for (int i = 0; i < nc; ++i) {
-                wp[i] += GGML_FP16_TO_FP32(mp[i]);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += mp_f32[i];
+                }
             }
         }
 
@@ -12432,8 +12446,14 @@ static void ggml_compute_forward_soft_max_f32(
             const uint32_t h = (i1/ne01)%ne02; // head
             const float slope = h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*ggml_fp16_to_fp32(pos[i]); + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*pos_f32[i]; + } } } diff --git a/llama.cpp b/llama.cpp index a4b00e7ff3ddf..26802d96a6752 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6710,14 +6710,14 @@ struct llm_build_context { } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16); + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } struct ggml_tensor * build_inp_KQ_pos() { lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16); + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; } struct ggml_tensor * build_inp_mean() { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ce39dadbb61e3..d044a6ea02481 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1120,11 +1120,11 @@ struct test_soft_max : public test_case { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * mask = nullptr; if (this->mask) { - mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); + mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]); } ggml_tensor * pos = nullptr; if (max_bias > 0.0f) { - pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]); + pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias); return out; From 5408d55506186b8e56f8a6a748688847ef0ebb7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 19:12:06 +0300 Subject: [PATCH 101/121] cuda : uint -> uint32_t --- ggml-cuda/common.cuh | 6 +++--- ggml-cuda/fattn.cu | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index ac6de643d668e..e82d63e4a0abc 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -307,9 +307,9 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { } #if CUDART_VERSION < 12000 -static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) { - const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); - const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); +static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { + const uint32_t mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); + const uint32_t mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); return mask_low | mask_high; } #endif // CUDART_VERSION < 12000 diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 4cf2907e8d10c..2077da53dc68f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -418,8 +418,8 @@ static __global__ void flash_attn_ext_f16( KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; KQ_max_scale_h2[j0/nwarps] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, 
SOFTMAX_FTZ_THRESHOLD));
+            *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
             KQ_max_h2[j0/nwarps] = KQ_max_new;
 
             half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
@@ -429,8 +429,8 @@ static __global__ void flash_attn_ext_f16(
             const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
             KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
-            const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-            *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+            const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+            *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
             KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
             KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
         }
@@ -602,8 +602,8 @@ static __global__ void flash_attn_combine_results(
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
         const float KQ_max_scale = expf(diff);
-        const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
-        *((uint *) &KQ_max_scale) &= ftz_mask;
+        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
         VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;

From c70bfd7bcb5b218bea00cddee0dfca0a7d4e4c7f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 22 Apr 2024 20:31:23 +0300
Subject: [PATCH 102/121] cuda : "constexpr dim3" -> "const dim3"

ggml-ci
---
 ggml-cuda/fattn.cu | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 2077da53dc68f..aaaea2f0701d8 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -652,7 +652,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
     }
 
     constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
     const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
     const int shmem = 0;
@@ -680,9 +680,9 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
         return;
     }
 
-    constexpr dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int shmem_combine = 0;
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int shmem_combine = 0;
 
     flash_attn_combine_results<D, parallel_blocks>
         <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
@@ -703,7 +703,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks> void laun
     }
 
     constexpr int nwarps = 4;
-    constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
     const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const int shmem = 0;
@@ -731,9 +731,9 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks> void laun
         return;
     }
 
-    constexpr dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int shmem_combine = 0;
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int shmem_combine = 0;
 
     flash_attn_combine_results<D, parallel_blocks>
         <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>

From c129369702655f0bafa06426fe8179c2c28d63ea Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 22 Apr 2024 21:42:43 +0300
Subject: [PATCH 103/121] cuda : try to fix __hgt2_mask

ggml-ci
---
 ggml-cuda/common.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index e82d63e4a0abc..ca0d85ae9ab70 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -308,8 +308,8 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 
 #if CUDART_VERSION < 12000
 static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
-    const uint32_t mask_low  = 0x0000FFFF * ( __low2half(a) > __low2half(b));
-    const uint32_t mask_high = 0xFFFF0000 * 
(__high2half(a) > __high2half(b));
+    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
     return mask_low | mask_high;
 }
 #endif // CUDART_VERSION < 12000

From 3864eea4cbf4d224d8ce798e7c537838048f540a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 23 Apr 2024 10:01:49 +0300
Subject: [PATCH 104/121] ggml : add TODO's for F16/F32 mask/pos support in other backends
---
 ggml-kompute.cpp | 7 +++++++
 ggml-sycl.cpp    | 6 +++++-
 ggml-vulkan.cpp  | 5 +++++
 3 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp
index 407062e6fd476..9a469821d8042 100644
--- a/ggml-kompute.cpp
+++ b/ggml-kompute.cpp
@@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
     for (int i = node_start; i < node_end; ++i) {
         struct ggml_tensor * src0 = gf->nodes[i]->src[0];
         struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+        struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
         struct ggml_tensor * dst = gf->nodes[i];
         GGML_ASSERT(dst->data != nullptr);
@@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
             {
                 float scale;
                 memcpy(&scale, dst->op_params, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+                GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+                GGML_ASSERT(src2 == nullptr);
+
                 ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
             } break;
         case GGML_OP_DIAG_MASK_INF:
diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp
index a9b310243f04f..f8ed55eb82259 100644
--- a/ggml-sycl.cpp
+++ b/ggml-sycl.cpp
@@ -14738,7 +14738,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    const ggml_tensor * src2 = dst->src[2];
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -14754,7 +14759,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     float * src2_dd = nullptr;
     sycl_pool_alloc<float> src2_f;
 
-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 1736ab7361c27..f712cdd5a900e 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_SOFT_MAX:
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
+        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
+
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_soft_max_f32;
         }

From 78d363b0d4b776a0ead4246d6f3392df04a7044b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 23 
Apr 2024 17:15:13 +0300 Subject: [PATCH 105/121] llama : replace bool need_kq_pos with use_alibi --- llama.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index 26802d96a6752..83e0c2ef11831 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1823,7 +1823,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models + bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -4104,7 +4104,7 @@ static void llm_load_hparams( model.ftype = ml.ftype; if (hparams.f_max_alibi_bias > 0.0f) { - hparams.need_kq_pos = true; + hparams.use_alibi = true; } hparams.rope_type = llama_rope_type(&model); @@ -6269,7 +6269,6 @@ static struct ggml_tensor * llm_build_moe_ffn( return moe_out; } -// if max_alibi_bias > 0 then apply ALiBi static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, const llama_model & model, @@ -6359,7 +6358,7 @@ static struct ggml_tensor * llm_build_kqv( #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { + if (hparams.use_alibi) { kq = ggml_scale(ctx, kq, kq_scale); cb(kq, "kq_scaled", il); @@ -10714,7 +10713,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (hparams.need_kq_pos) { + // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch + // this allows to process multiple sequences in parallel with ALiBi-based models + if (hparams.use_alibi) { const int64_t n_kv = kv_self.n; GGML_ASSERT(lctx.inp_KQ_pos); @@ -15116,7 +15117,7 @@ struct llama_context * llama_new_context_with_model( } } - if (cparams.flash_attn && hparams.need_kq_pos) { + if (cparams.flash_attn && hparams.use_alibi) { LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); cparams.flash_attn = false; } From 19e8982f51c5bea11735f688092627f1461d8759 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Apr 2024 17:24:28 +0300 Subject: [PATCH 106/121] llama : prep ALiBi support for BERT models ggml-ci --- ggml.c | 1 + llama.cpp | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index bc19f35bff84f..469a0e0d9afd6 100644 --- a/ggml.c +++ b/ggml.c @@ -5476,6 +5476,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); + GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); } diff --git a/llama.cpp b/llama.cpp index 83e0c2ef11831..4b38f5870adb0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6712,8 +6712,14 @@ struct llm_build_context { return flash_attn ? 
ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } - struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); + struct ggml_tensor * build_inp_KQ_pos(bool causal = true) { + if (causal) { + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); + } else { + // TODO: this will be needed for ALiBi-based BERT models + // https://github.com/ggerganov/llama.cpp/pull/6826 + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens); + } cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; From 56657e52e5f2b0fdcd414f99837b2e67efcf824c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Apr 2024 17:30:37 +0300 Subject: [PATCH 107/121] llama : fix n_batch requirements ggml-ci --- llama.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4b38f5870adb0..fcd15501e416f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15064,10 +15064,6 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask - // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) - cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); - cparams.n_seq_max = std::max(1u, params.n_seq_max); cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; @@ -15086,12 +15082,20 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + if (cparams.n_batch < GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? 
hparams.n_yarn_orig_ctx :

From d228bf8552f5a6afa0f4c523c0da4bc00312b791 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 23 Apr 2024 17:32:11 +0300
Subject: [PATCH 108/121] cont
---
 llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index fcd15501e416f..a3624544cf326 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -15092,7 +15092,7 @@ struct llama_context * llama_new_context_with_model(
     // ref: https://github.com/ggerganov/llama.cpp/pull/5021
     if (cparams.n_batch < GGML_KQ_MASK_PAD) {
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch);
+        cparams.n_batch = GGML_KQ_MASK_PAD;
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

From 751591d52074d6be53feed6e19211932e4520159 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 23 Apr 2024 18:16:25 +0300
Subject: [PATCH 109/121] server : add help for --flash-attn arg
---
 examples/server/server.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index f1754b60b7fe8..2cf59fbe0d66b 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2357,6 +2357,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("  --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     printf("  -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
     printf("  -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)\n");
+    printf("  -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
     printf("  -spf FNAME, --system-prompt-file FNAME\n");
     printf("                set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
     printf("  -ctk TYPE, --cache-type-k TYPE\n");

From ce281b904c0ed97f7f7c685a2c7bf8dbaa6f8293 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 24 Apr 2024 16:48:10 +0300
Subject: [PATCH 110/121] llama : disable FA for AMD
---
 ggml-cuda/common.cuh | 4 ++--
 ggml-cuda/fattn.cu   | 3 +++
 llama.cpp            | 7 +++++++
 3 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index ca0d85ae9ab70..156eba6d1ef74 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -399,8 +399,8 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 
 #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
-#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
-    defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA
+
+#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 4cf2907e8d10c..df1e80068b334 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -2,7 +2,10 @@
 #include "fattn.cuh"
 
 #include <cstdint>
+
+#if FP16_MMA_AVAILABLE
 #include <mma.h>
+#endif
 
 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. 
of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
diff --git a/llama.cpp b/llama.cpp
index 11a1aa3a44c26..f00190a77bb79 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -15357,6 +15357,13 @@ struct llama_context * llama_new_context_with_model(
         cparams.flash_attn = false;
     }
 
+#ifdef GGML_USE_HIPBLAS
+    if (cparams.flash_attn) {
+        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
+        cparams.flash_attn = false;
+    }
+#endif
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }

From ff2c64a9f4190cd5da3c59f65d8417299c72a336 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 25 Apr 2024 15:51:46 +0300
Subject: [PATCH 111/121] tests : remove TMP_ATTN_BENCH

ggml-ci
---
 tests/test-backend-ops.cpp | 70 --------------------------------------
 1 file changed, 70 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index d044a6ea02481..b27c1291e4088 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -15,9 +15,6 @@
 #include <thread>
 #include <vector>
 
-// TODO: remove before merging
-//#define TMP_ATTN_BENCH
-
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
     // static RNG initialization (revisit if n_threads stops being constant)
     static const size_t n_threads = std::thread::hardware_concurrency();
@@ -574,19 +571,9 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
-#ifndef TMP_ATTN_BENCH
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
-#else
-        int n_nodes = gf->n_nodes;
-        n_runs = 1000;
-        for (int i = 1; i < n_runs; i++) {
-            for (int j = 0; j < n_nodes; j++) {
-                gf->nodes[gf->n_nodes++] = gf->nodes[j];
-            }
-        }
-#endif
 
         // calculate memory
         size_t mem = n_runs * op_size(out);
@@ -1522,50 +1509,6 @@ struct test_flash_attn_ext : public test_case {
     }
 };
 
-#ifdef TMP_ATTN_BENCH
-// ATTN
-struct test_attn : public test_case {
-    const int64_t hs; // head size
-    const int64_t nh; // num heads
-    const int64_t kv; // kv size
-    const int64_t nb; // batch size
-
-    std::string op_desc(ggml_tensor * t) override {
-        return "ATTN";
-
-        GGML_UNUSED(t);
-    }
-
-    std::string vars() override {
-        return VARS_TO_STR4(hs, nh, kv, nb);
-    }
-
-    double max_nmse_err() override {
-        return 5e-4;
-    }
-
-    test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : hs(hs), nh(nh), kv(kv), nb(nb) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
-        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed
-        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1);
-
-        struct ggml_tensor * cur;
-
-        cur = ggml_mul_mat (ctx, k, q);
-        cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f);
-        cur = ggml_mul_mat (ctx, v, cur);
-        cur = ggml_permute (ctx, cur, 0, 2, 1, 3);
-        cur = ggml_cont_2d (ctx, cur, hs*nh, nb);
-
-        return cur;
-    }
-};
-#endif
-
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -2230,18 +2173,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-#ifdef TMP_ATTN_BENCH
-    for (int hs 
: { 128, 256, 64, 80, }) { - for (int nh : { 32, }) { - for (int kv : { 512, 1024, 2048, 4096, }) { - for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { - test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); - } - } - } - } -#else for (int hs : { 64, 80, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, }) { @@ -2251,7 +2182,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } -#endif // these tests are disabled to save execution time, but they can be handy for debugging #if 0 From 1fd5bc3d5e4ebfad3499d59dfee60202a4b7bb72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 18:18:13 +0300 Subject: [PATCH 112/121] llama : support save/load state with FA enabled ggml-ci --- ci/run.sh | 3 ++- llama.cpp | 16 ++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index da05f0d48802a..56beaedea0408 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -517,7 +517,8 @@ function gg_run_open_llama_7b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -fa ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" diff --git a/llama.cpp b/llama.cpp index 718d5cccb6a9c..65ac6f6f26a8c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2036,8 +2036,8 @@ struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; + bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching // for a free KV slot. 
llama_decode_internal also uses it, so it @@ -2335,11 +2335,14 @@ struct llama_context { static bool llama_kv_cache_init( struct llama_kv_cache & cache, - const llama_model & model, + const llama_context * ctx, ggml_type type_k, ggml_type type_v, uint32_t kv_size, bool offload) { + const llama_model & model = ctx->model; + const llama_cparams & cparams = ctx->cparams; + const struct llama_hparams & hparams = model.hparams; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); @@ -2350,6 +2353,7 @@ static bool llama_kv_cache_init( // TODO: find a nicer way to add other recurrent model architectures cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.v_trans = !cparams.flash_attn; // TODO: support mixed reccurent Transformer architectues // NOTE: (!a || b) is a logical implication (a -> b) @@ -15550,7 +15554,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -16330,7 +16334,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -16486,7 +16490,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); From ac1c6d91de2d77a39f67ce55fd1ef6772d7e4a4a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 19:03:59 +0300 Subject: [PATCH 113/121] ci : add CUDA save-load-state tests ggml-ci --- ci/run.sh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index 56beaedea0408..fda1169b2c2f2 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -336,7 +336,8 @@ function gg_run_open_llama_3b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -517,8 +518,10 @@ function gg_run_open_llama_7b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/save-load-state --model -fa ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -ngl 10 ${model_q4_0} ) 2>&1 | tee -a 
$OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -fa -ngl 10 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -ngl 99 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model -fa -ngl 99 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" From c225609f1003612798698bea8b1a4d6e8d0c3da8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 19:37:27 +0300 Subject: [PATCH 114/121] llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci --- llama.cpp | 147 +++++++++++++++++++++++++++++++++++++----------------- llama.h | 2 +- 2 files changed, 103 insertions(+), 46 deletions(-) diff --git a/llama.cpp b/llama.cpp index eaf1d60b43bca..6db0b95710c0d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2566,6 +2566,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { } cache.head = 0; cache.used = 0; + + for (auto & buf : cache.bufs) { + ggml_backend_buffer_clear(buf, 0); + } } static bool llama_kv_cache_seq_rm( @@ -16483,6 +16487,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } + llama_kv_cache_clear(ctx); + if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; @@ -16516,8 +16522,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - llama_kv_cache_clear(ctx); - ctx->kv_self.head = kv_head; ctx->kv_self.used = kv_used; @@ -16777,28 +16781,48 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam } } - // For the values, they are transposed, so we also need the element size and get the element ranges from each row - const uint32_t kv_size = kv_self.size; - for (int il = 0; il < (int)n_layer; ++il) { - // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - data_ctx.write(&v_type_i, sizeof(v_type_i)); + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Write key type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); - // Write element size - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - data_ctx.write(&v_size_el, sizeof(v_size_el)); + // Write row size of key + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + data_ctx.write(&v_size_row, sizeof(v_size_row)); - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - tmp_buf.resize(range_size * v_size_el); - ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + tmp_buf.resize(range_size * v_size_row); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row); data_ctx.write(tmp_buf.data(), tmp_buf.size()); } } + } else { + // For the values, they are transposed, so we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t 
v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + data_ctx.write(&v_size_el, sizeof(v_size_el)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + tmp_buf.resize(range_size * v_size_el); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); + } + } + } } return data_ctx.get_size_written(); @@ -16923,41 +16947,74 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } } - // For each layer, read the values for each cell (transposed) - for (int il = 0; il < (int)n_layer; ++il) { - // Read type of value - int32_t v_type_i_ref; - memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); - inp += sizeof(v_type_i_ref); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return 0; - } + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of key + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } - // Read element size of value - size_t v_size_el_ref; - memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); - inp += sizeof(v_size_el_ref); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); - return 0; - } + // Read row size of key + size_t v_size_row_ref; + memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref)); + inp += sizeof(v_size_row_ref); + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); + return 0; + } - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; - ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); - inp += cell_count * v_size_el; + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row); + inp += cell_count * v_size_row; + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, 
sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + + // Read element size of value + size_t v_size_el_ref; + memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); + inp += sizeof(v_size_el_ref); + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); + return 0; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); + inp += cell_count * v_size_el; + } } } } const size_t nread = inp - src; + return nread; } diff --git a/llama.h b/llama.h index 792ef74d364ca..bedbc7c2c6e30 100644 --- a/llama.h +++ b/llama.h @@ -526,7 +526,7 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache + // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); From bab346ba69de6117bc37a8604ffe95ecaed84664 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 19:45:36 +0300 Subject: [PATCH 115/121] llama : fix copy-paste errors, add TODO --- llama.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index 6db0b95710c0d..ecdc4b7fcf4fc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16781,13 +16781,14 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam } } + // TODO: simplify, reduce copy-paste if (!kv_self.v_trans) { for (int il = 0; il < (int)n_layer; ++il) { - // Write key type + // Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; data_ctx.write(&v_type_i, sizeof(v_type_i)); - // Write row size of key + // Write row size of value const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); data_ctx.write(&v_size_row, sizeof(v_size_row)); @@ -16947,32 +16948,33 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } } + // TODO: simplify, reduce copy-paste if (!kv_self.v_trans) { for (int il = 0; il < (int)n_layer; ++il) { - // Read type of key + // Read type of value int32_t v_type_i_ref; memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); inp += sizeof(v_type_i_ref); const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; if (v_type_i != v_type_i_ref) { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return 0; } - // Read row size of key + // Read row size of value size_t v_size_row_ref; memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref)); inp += sizeof(v_size_row_ref); const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); if 
(v_size_row != v_size_row_ref) { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); return 0; } if (cell_count) { - // Read and set the keys for the whole cell range + // Read and set the values for the whole cell range ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row); inp += cell_count * v_size_row; } From 0fc5c5eb74accef6a5904e4933e2d3b08e3cd34a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 19:53:57 +0300 Subject: [PATCH 116/121] llama : disallow incompatible states --- llama.cpp | 6 ++++++ llama.h | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index ecdc4b7fcf4fc..8b258f988e514 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16323,11 +16323,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data const uint32_t kv_size = kv_self.size; const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head; const uint32_t kv_used = kv_self.used; + const uint32_t v_trans = kv_self.v_trans ? 1 : 0; data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); data_ctx->write(&kv_head, sizeof(kv_head)); data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_used, sizeof(kv_used)); + data_ctx->write(&v_trans, sizeof(v_trans)); if (kv_buf_size) { const size_t pre_kv_buf_size = data_ctx->get_size_written(); @@ -16473,11 +16475,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { uint32_t kv_head; uint32_t kv_size; uint32_t kv_used; + uint32_t v_trans; memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans); + + GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition if (kv_self.size != kv_size) { // the KV cache needs to be big enough to load all the KV cells from the saved state diff --git a/llama.h b/llama.h index bedbc7c2c6e30..9c7cdf99faa6d 100644 --- a/llama.h +++ b/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 5 +#define LLAMA_SESSION_VERSION 6 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 1 From 1e590ac3c97534ba0ff34388a30d2430a7684c10 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 20:06:23 +0300 Subject: [PATCH 117/121] llama : update llama_state_get_size after v_trans field --- llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama.cpp b/llama.cpp index 8b258f988e514..f41fba6c7f836 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16157,6 +16157,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); + const size_t s_v_trans = sizeof(uint32_t); const size_t s_kv = ctx->kv_self.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; @@ -16174,6 
+16175,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
         + s_kv_head
         + s_kv_size
         + s_kv_used
+        + s_v_trans
         + s_kv
         + s_kv_cells
     );

From 4f4c0249bf31e3b9c161670231d32e60293a3314 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 25 Apr 2024 20:29:25 +0300
Subject: [PATCH 118/121] metal : remove tmp log
---
 ggml-metal.m | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 1903791f1f9e3..249b8312cdf63 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -463,9 +463,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
         kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
         [metal_function release]; \
-        GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
-                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
-                (int) kernel->pipeline.threadExecutionWidth); \
         if (error) { \
             GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             [metal_library release]; \

From 9e3876061c0734b6ec55326ea9d73f379594d9bd Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 25 Apr 2024 20:33:36 +0300
Subject: [PATCH 119/121] llama : add static reminder for llama_state_get_size
---
 llama.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index f41fba6c7f836..3598921668767 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -16180,6 +16180,9 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
         + s_kv_cells
     );
 
+    // on session change it is very likely that the state size has changed - so we need to update this function
+    static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
+
     return s_total;
 }
 
From e180fcd3d53cee6ed5ac82473b45c468379a0687 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 30 Apr 2024 11:04:32 +0300
Subject: [PATCH 120/121] metal : fix max nsg

ggml-ci
---
 ggml-metal.m | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 249b8312cdf63..c6d580b8462d0 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -2643,13 +2643,25 @@ static enum ggml_status ggml_metal_graph_compute(
                 GGML_ASSERT(nqptg % 8  == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);
 
+                int64_t nsgmax = 2;
+
+                while (true) {
+                    const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                    if (smem > ctx->device.maxThreadgroupMemoryLength) {
+                        break;
+                    }
+                    nsgmax *= 2;
+                }
+                nsgmax /= 2;
+
                 // simdgroups per threadgroup (a.k.a. warps)
-                const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
+                const int64_t nsg = ne01 <= nqptg ? 
MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; From c240ae234c934fe7777aa4617d332b67bdb846c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Apr 2024 11:43:36 +0300 Subject: [PATCH 121/121] ci : fix arg order ggml-ci --- ci/run.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index d72e362898f15..bf21b6b31c52d 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -518,10 +518,10 @@ function gg_run_open_llama_7b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model -ngl 10 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/save-load-state --model -fa -ngl 10 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/save-load-state --model -ngl 99 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/save-load-state --model -fa -ngl 99 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -fa -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -fa -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1"