Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml : alternative Q4_3 format + implementation #1108

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 84 additions & 78 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,14 @@ typedef struct {
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");

// NOTE(review): this listing looks like a diff whose +/- markers were lost, so removed and
// added lines sit side by side. Presumably `QK4_3 16`, the single `d` field and the
// `2 * sizeof(ggml_fp16_t)` assert are the old layout, while `QK4_3 32`, `d0`/`d1` and the
// `3 * sizeof(ggml_fp16_t)` assert are the new one - confirm against the actual patch.
#define QK4_3 16
#define QK4_3 32
// Q4_3 quantization block: QK4_3 values stored as 4-bit quants plus fp16 scale(s) and an
// fp16 minimum. A value is reconstructed as quant * delta + m.
typedef struct {
ggml_fp16_t d; // delta: single scale for the whole block (layout with one delta)
ggml_fp16_t d0; // delta (scale) for the first QK4_3/2 values of the block
ggml_fp16_t d1; // delta (scale) for the second QK4_3/2 values of the block
ggml_fp16_t m; // min: offset added back on dequantization; one per block, shared by both halves
uint8_t qs[QK4_3 / 2]; // nibbles / quants: two 4-bit values per byte, low nibble first; with d0/d1 the first QK4_3/4 bytes hold the first half and the rest the second half
} block_q4_3;
// Size check for the layout with one delta + min (2 fp16 fields) - no padding expected.
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
// Size check for the layout with d0 + d1 + min (3 fp16 fields) - no padding expected.
static_assert(sizeof(block_q4_3) == 3 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");

#define QK8_0 32
typedef struct {
Expand Down Expand Up @@ -1244,32 +1245,46 @@ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * r
const int nb = k / QK4_3;

for (int i = 0; i < nb; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
float min0 = FLT_MAX;
float max0 = -FLT_MAX;
float min1 = FLT_MAX;
float max1 = -FLT_MAX;

for (int l = 0; l < QK4_3; l++) {
const float v = x[i*QK4_3 + l];
if (v < min) min = v;
if (v > max) max = v;
for (int l = 0; l < QK4_3/2; l++) {
const float v0 = x[i*QK4_3 + l];
const float v1 = x[i*QK4_3 + l + QK4_3/2];
if (v0 < min0) min0 = v0;
if (v0 > max0) max0 = v0;
if (v1 < min1) min1 = v1;
if (v1 > max1) max1 = v1;
}

const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
const float min = MIN(min0, min1);

y[i].d = GGML_FP32_TO_FP16(d);
const float d0 = (max0 - min) / ((1 << 4) - 1);
const float d1 = (max1 - min) / ((1 << 4) - 1);
const float id0 = d0 ? 1.0f/d0 : 0.0f;
const float id1 = d1 ? 1.0f/d1 : 0.0f;

y[i].d0 = GGML_FP32_TO_FP16(d0);
y[i].d1 = GGML_FP32_TO_FP16(d1);
y[i].m = GGML_FP32_TO_FP16(min);

for (int l = 0; l < QK4_3; l += 2) {
const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
for (int l = 0; l < QK4_3/2; l += 2) {
const float v0_0 = (x[i*QK4_3 + l + 0] - min)*id0;
const float v0_1 = (x[i*QK4_3 + l + 1] - min)*id0;

const uint8_t vi0 = (int) (v0 + 0.5f);
const uint8_t vi1 = (int) (v1 + 0.5f);
const float v1_0 = (x[i*QK4_3 + l + QK4_3/2 + 0] - min)*id1;
const float v1_1 = (x[i*QK4_3 + l + QK4_3/2 + 1] - min)*id1;

assert(vi0 < 16);
assert(vi1 < 16);
const uint8_t vi0_0 = (int) (v0_0 + 0.5f);
const uint8_t vi0_1 = (int) (v0_1 + 0.5f);

y[i].qs[l/2] = vi0 | (vi1 << 4);
const uint8_t vi1_0 = (int) (v1_0 + 0.5f);
const uint8_t vi1_1 = (int) (v1_1 + 0.5f);

y[i].qs[l/2] = vi0_0 | (vi0_1 << 4);
y[i].qs[l/2 + QK4_3/4] = vi1_0 | (vi1_1 << 4);
}
}
}
Expand Down Expand Up @@ -1737,25 +1752,31 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
const block_q4_3 * restrict x = vx;

for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
const float m = GGML_FP16_TO_FP32(x[i].m);
const float d0 = GGML_FP16_TO_FP32(x[i].d0);
const float d1 = GGML_FP16_TO_FP32(x[i].d1);
const float m = GGML_FP16_TO_FP32(x[i].m);

const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK4_3/2; l += 2) {
const uint8_t vi0 = x[i].qs[ l/2];
const uint8_t vi1 = x[i].qs[QK4_3/4 + l/2];

for (int l = 0; l < QK4_3; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0_0 = vi0 & 0xf;
const int8_t vi0_1 = vi0 >> 4;

const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const int8_t vi1_0 = vi1 & 0xf;
const int8_t vi1_1 = vi1 >> 4;

const float v0 = vi0*d + m;
const float v1 = vi1*d + m;
const float v0_0 = vi0_0*d0 + m;
const float v0_1 = vi0_1*d0 + m;

const float v1_0 = vi1_0*d1 + m;
const float v1_1 = vi1_1*d1 + m;

y[i*QK4_3 + l + 0] = v0;
y[i*QK4_3 + l + 1] = v1;
y[i*QK4_3 + l + 0] = v0_0;
y[i*QK4_3 + l + 1] = v0_1;

assert(!isnan(y[i*QK4_3 + l + 0]));
assert(!isnan(y[i*QK4_3 + l + 1]));
y[i*QK4_3 + QK4_3/2 + l + 0] = v1_0;
y[i*QK4_3 + QK4_3/2 + l + 1] = v1_1;
}
}
}
Expand Down Expand Up @@ -2632,35 +2653,35 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

// interleave
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);

// load y
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

// interleave
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);

sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));

const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));

const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
Expand Down Expand Up @@ -2931,7 +2952,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *

assert(n % QK8_0 == 0);
assert(nb % 2 == 0);
assert(QK8_0 == 2*QK4_2);
assert(QK8_0 == QK4_3);

const block_q4_3 * restrict x = vx;
const block_q8_0 * restrict y = vy;
Expand All @@ -2942,29 +2963,25 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);

for (int i = 0; i < nb; i += 2) {
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
float summs = 0;

for (int i = 0; i < nb; i += 2) {
const block_q4_3 * restrict x0 = &x[i + 0];
const block_q4_3 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];

const uint8x16_t m4b = vdupq_n_u8(0xf);
summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;

const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
const uint8x16_t m4b = vdupq_n_u8(0xf);

const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
const float x0_0d = GGML_FP16_TO_FP32(x0->d0);
const float x0_1d = GGML_FP16_TO_FP32(x0->d1);
const float x1_0d = GGML_FP16_TO_FP32(x1->d0);
const float x1_1d = GGML_FP16_TO_FP32(x1->d1);

const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);

// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
Expand All @@ -2984,17 +3001,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));

const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));

sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);

#if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
Expand Down Expand Up @@ -3023,7 +3029,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
#endif
}

sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs;
#else
// scalar
for (int i = 0; i < nb; i++) {
Expand Down