Skip to content

Commit

Permalink
mtl : mul_mat fixes (still wrong)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ggerganov committed May 30, 2023
1 parent 2a24994 commit 96d0052
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 33 deletions.
30 changes: 14 additions & 16 deletions examples/mtl/mtl.m
Original file line number Diff line number Diff line change
@@ -377,29 +377,27 @@ int llama_mtl_eval(
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);

const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
const int64_t nrows0 = gf->nodes[i]->src0->ne[1];

const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
const int64_t nrows1 = gf->nodes[i]->src1->ne[1];

const int64_t ncols = gf->nodes[i]->ne[0];
const int64_t nrows = gf->nodes[i]->ne[1];
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
const int64_t ne10 = gf->nodes[i]->src1->ne[0];
const int64_t ne11 = gf->nodes[i]->src1->ne[1];
const int64_t ne0 = gf->nodes[i]->ne[0];
const int64_t ne1 = gf->nodes[i]->ne[1];

[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ncols0 length:sizeof(ncols0) atIndex:3];
[encoder setBytes:&nrows0 length:sizeof(nrows0) atIndex:4];
[encoder setBytes:&ncols1 length:sizeof(ncols1) atIndex:5];
[encoder setBytes:&nrows1 length:sizeof(nrows1) atIndex:6];
[encoder setBytes:&ncols length:sizeof(ncols) atIndex:7];
[encoder setBytes:&nrows length:sizeof(nrows) atIndex:8];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:5];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8];

printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ncols0, nrows0, ncols1, nrows1, ncols, nrows);
printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1);

[encoder dispatchThreadgroups:MTLSizeMake(nrows0, nrows1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_GET_ROWS:
{
Expand Down
31 changes: 14 additions & 17 deletions examples/mtl/mtl.metal
Original file line number Diff line number Diff line change
@@ -144,19 +144,16 @@ kernel void kernel_mul_mat_q4_0(
sum[tpitg.x] = 0.0f;

for (int i = 0; i < nb; i += tptg.x) {
device const uint4 * x0p = (device const uint4 *) (x + i);
device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
device const float4 * y0p = (device const float4 *) (y + i*qk);

const uint4 x0 = *x0p;

const uint4 x0l = x0 & uint4(0x0F0F0F0F);
const uint4 x0h = x0 >> 4;
const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;

const int4 x0ls = as_type<int4>(x0l) - int4(8);
const int4 x0hs = as_type<int4>(x0h) - int4(8);

thread const uchar * x0lsb = (thread const uchar *) &x0ls;
thread const uchar * x0hsb = (thread const uchar *) &x0hs;
thread const char * x0lsb = (thread const char *) &x0l;
thread const char * x0hsb = (thread const char *) &x0h;

const float4 y00 = *(y0p + 0);
const float4 y01 = *(y0p + 1);
Expand All @@ -167,17 +164,17 @@ kernel void kernel_mul_mat_q4_0(
const float4 y06 = *(y0p + 6);
const float4 y07 = *(y0p + 7);

const float d = (x + i)->d;
const half d = (x + i)->d;

sum[tpitg.x] += (
x0lsb[ 0]*y00[0] + x0lsb[ 1]*y00[1] + x0lsb[ 2]*y00[2] + x0lsb[ 3]*y00[3] +
x0lsb[ 4]*y01[0] + x0lsb[ 5]*y01[1] + x0lsb[ 6]*y01[2] + x0lsb[ 7]*y01[3] +
x0lsb[ 8]*y02[0] + x0lsb[ 9]*y02[1] + x0lsb[10]*y02[2] + x0lsb[11]*y02[3] +
x0lsb[12]*y03[0] + x0lsb[13]*y03[1] + x0lsb[14]*y03[2] + x0lsb[15]*y03[3] +
x0hsb[ 0]*y04[0] + x0hsb[ 1]*y04[1] + x0hsb[ 2]*y04[2] + x0hsb[ 3]*y04[3] +
x0hsb[ 4]*y05[0] + x0hsb[ 5]*y05[1] + x0hsb[ 6]*y05[2] + x0hsb[ 7]*y05[3] +
x0hsb[ 8]*y06[0] + x0hsb[ 9]*y06[1] + x0hsb[10]*y06[2] + x0hsb[11]*y06[3] +
x0hsb[12]*y07[0] + x0hsb[13]*y07[1] + x0hsb[14]*y07[2] + x0hsb[15]*y07[3]
(x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
(x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
(x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
(x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
(x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
(x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
(x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
(x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
) * d;
}

Expand Down

0 comments on commit 96d0052

Please sign in to comment.