Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix][Kernel] Add IQ1_M quantization implementation to GGUF kernel #8357

Merged
merged 18 commits into from
Sep 15, 2024
Merged
55 changes: 46 additions & 9 deletions csrc/quantization/gguf/dequantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -353,18 +353,47 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

const int i = blockIdx.x;
const int64_t i = blockIdx.x;
const block_iq1_s * x = (const block_iq1_s *) vx;

const int tid = threadIdx.x;
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const float d = __half2float(x[i].d) * (2*((x[i].qh[ib] >> 12) & 7) + 1);
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta));
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {

const int64_t i = blockIdx.x;
const block_iq1_m * x = (const block_iq1_m *) vx;

const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const int i8 = 4*ib+il;
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
const float d = __half2float(x[i].d) * (2*(h & 7) + 1);
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]);
const uint16_t * sc = (const uint16_t *)x[i].scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
const float d = __half2float(scale.f16) * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta));
}
}

template<typename dst_t>
Expand Down Expand Up @@ -475,6 +504,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
Expand Down Expand Up @@ -525,6 +560,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
return dequantize_row_iq2_s_cuda;
case 23:
return dequantize_row_iq4_xs_cuda;
case 29:
return dequantize_row_iq1_m_cuda;
default:
return nullptr;
}
Expand Down
408 changes: 277 additions & 131 deletions csrc/quantization/gguf/ggml-common.h

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions csrc/quantization/gguf/gguf_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 29:
mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
}
return Y;
}
Expand Down
8 changes: 8 additions & 0 deletions csrc/quantization/gguf/mmvq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half *
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}

static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}

static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
Expand Down
101 changes: 81 additions & 20 deletions csrc/quantization/gguf/vecdotq.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh
// and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

int x32 = x16[2*i32 + 0] << 0;
x32 |= x16[2*i32 + 1] << 16;

return x32;
}

static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}

static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
int x32 = 0;
Expand Down Expand Up @@ -1658,28 +1671,76 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(

static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;

const int ib32 = iqs;
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
const uint8_t h1 = bq1->scales[2*ib32+0];
const uint8_t h2 = bq1->scales[2*ib32+1];
const int * q8 = (const int *)bq8_1[ib32].qs;
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
for (int j = 0; j < 2; ++j) {
sumi1 = __dp4a(q8[j+0], grid1[j], sumi1);
sumi2 = __dp4a(q8[j+2], grid2[j], sumi2);
sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
sumi4 = __dp4a(q8[j+6], grid4[j], sumi4);
}
const float d = __half2float(bq1->d) * __low2float(bq8_1[ib32].ds);
return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
#endif
const int qs_packed = get_int_b2(bq1->qs, iqs);
const uint8_t * qs = (const uint8_t *) &qs_packed;

const int qh = bq1->qh[iqs];

int sumi = 0;
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];

const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;

const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

sumi = __dp4a(grid0, u0, sumi);
sumi = __dp4a(grid1, u1, sumi);
}

const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
const float2 ds = __half22float2(bq8_1[iqs].ds);
return d1q * (ds.x*sumi + ds.y*delta);
}

static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

const block_iq1_m * bq1 = (const block_iq1_m *) vbq;

const int qs_packed = get_int_b4(bq1->qs, iqs);
const uint8_t * qs = (const uint8_t *) &qs_packed;

int sumi[2] = {0};
float sumf[2] = {0.0f};
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));

const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];

const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;

const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

sumi[l0/4] = __dp4a(grid0, u0, sumi[l0/4]);
sumi[l0/4] = __dp4a(grid1, u1, sumi[l0/4]);

const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
int sumy = 0;
sumy = __dp4a(u0, 0x01010101, sumy);
sumy = __dp4a(u1, 0x01010101, sumy);
sumf[l0/4] += delta*sumy;
}

const uint16_t * sc = (const uint16_t *) bq1->scales;

iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);

const int tmp = sc[iqs/2] >> (6*(iqs%2));
const int sc0 = 2*((tmp >> 0) & 0x07) + 1;
const int sc1 = 2*((tmp >> 3) & 0x07) + 1;
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}

static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
Expand Down
2 changes: 1 addition & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
partial-json-parser # used for parsing partial JSON outputs
pyzmq
msgspec
gguf == 0.9.1
gguf == 0.10.0
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved
importlib_metadata
mistral_common >= 1.4.0
pyyaml
Expand Down
126 changes: 126 additions & 0 deletions tests/kernels/test_gguf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from pathlib import Path
from typing import List

import pytest
import torch
from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download

import vllm._custom_ops as ops

GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")


def get_gguf_sample_tensors(
hidden_size: int,
quant_type: GGMLQuantizationType) -> List[ReaderTensor]:
sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
return GGUFReader(sample_file).tensors
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved


DTYPES = [torch.half]
# Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently.
HIDDEN_SIZES = [256, 1024]
NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing
SEEDS = [0]
QUANT_TYPES = [
# i-matrix
GGMLQuantizationType.IQ1_M,
GGMLQuantizationType.IQ1_S,
GGMLQuantizationType.IQ2_S,
GGMLQuantizationType.IQ2_XS,
GGMLQuantizationType.IQ3_S,
GGMLQuantizationType.IQ3_XXS,
GGMLQuantizationType.IQ4_NL,
GGMLQuantizationType.IQ4_XS,
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quantization
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
]


@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
for tensor in tensors:
shape_str = tensor.name.split("_")[-1]
shape = map(int, shape_str.split("x"))

ref_output = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
quant_type, *list(shape)).to(dtype)

torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)


@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
torch.cuda.manual_seed_all(0)

tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
ref_output = x @ weight.T

qweight = torch.tensor(tensor.data, device="cuda")
output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type,
qweight.shape[0]).to(dtype)

torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"quant_type",
[
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quants
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
])
@torch.inference_mode()
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
torch.cuda.manual_seed_all(0)

tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
ref_output = x @ weight.T

qweight = torch.tensor(tensor.data, device="cuda")
output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
qweight.shape[0]).to(dtype)

torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/quantization/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def get_scaled_act_names(self) -> List[str]:
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
# use dequantize mulmat for IQmatrix, mmq for k-quants
if qweight_type >= 16:
if x.shape[0] == 1:
# enable mmvq in contiguous batching
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
elif qweight_type >= 16:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
Expand Down
Loading