Skip to content

Commit

Permalink
metal : minor fixup in FA kernel (#10143)
Browse files Browse the repository at this point in the history
* metal : minor fixup in FA kernel

ggml-ci

* metal : use the unrolled loop variable

* metal : remove unused var
  • Loading branch information
ggerganov authored Nov 3, 2024
1 parent 1839f69 commit 08828a6
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3;

// load the queries from shared memory into local memory
float4 mq[D4];
float4 mq[D4/NW];

for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
mq[i] = (float4) sq4[i];
mq[ii/NW] = (float4) sq4[i];
}

// pointer to the mask
Expand Down Expand Up @@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
mk[2] = (float4) pk4[i + 2*(nb11/8)];
mk[3] = (float4) pk4[i + 3*(nb11/8)];

mqk += (float4) (mq[i] * mk);
mqk += (float4) (mq[ii/NW] * mk);
}

// reduce the results from the threads in the simdgroup
Expand Down Expand Up @@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
// O = diag(ms)*O
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
lo[i/NW] *= ms;
lo[ii/NW] *= ms;
}
}

Expand All @@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;

lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
}
}
}
Expand Down

0 comments on commit 08828a6

Please sign in to comment.