From 40e717263e72e3af1f83b2a92499615fd391b0a5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 Nov 2024 20:41:42 +0200 Subject: [PATCH 1/3] metal : minor fixup in FA kernel ggml-ci --- ggml/src/ggml-metal.metal | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index defde6246f129..6fc114e29f66c 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - float4 mq[D4]; + float4 mq[D4/NW]; for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; - mq[i] = (float4) sq4[i]; + mq[i/NW] = (float4) sq4[i]; } // pointer to the mask @@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16( mk[2] = (float4) pk4[i + 2*(nb11/8)]; mk[3] = (float4) pk4[i + 3*(nb11/8)]; - mqk += (float4) (mq[i] * mk); + mqk += (float4) (mq[i/NW] * mk); } // reduce the results from the threads in the simdgroup From fd7d5e870d7f09308a07f835f3b73c6a166a3ebb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 3 Nov 2024 10:02:53 +0200 Subject: [PATCH 2/3] metal : use the unrolled loop variable --- ggml/src/ggml-metal.metal | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 6fc114e29f66c..fed076f7f052f 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2780,7 +2780,7 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; - mq[i/NW] = (float4) sq4[i]; + mq[ii/NW] = (float4) sq4[i]; } // pointer to the mask @@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16( mk[2] = (float4) pk4[i + 2*(nb11/8)]; mk[3] = (float4) pk4[i + 3*(nb11/8)]; - mqk += (float4) (mq[i/NW] * mk); + mqk += (float4) (mq[ii/NW] * mk); } // reduce the results from the threads in the simdgroup @@ -2858,7 +2858,7 @@ kernel void kernel_flash_attn_ext_vec_f16( #pragma unroll for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; - lo[i/NW] *= ms; + lo[ii/NW] *= ms; } } @@ -2872,10 +2872,10 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; - lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; } } } From 909cfd498c122feec1bb428709fc24963e06b0ed Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 3 Nov 2024 11:24:18 +0200 Subject: [PATCH 3/3] metal : remove unused var --- ggml/src/ggml-metal.metal | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index fed076f7f052f..57eb34f13ac85 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2857,7 +2857,6 @@ kernel void kernel_flash_attn_ext_vec_f16( // O = diag(ms)*O #pragma unroll for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; lo[ii/NW] *= ms; } }