diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py
new file mode 100644
index 0000000000000..198d40a155ccb
--- /dev/null
+++ b/tests/kernels/test_awq_triton.py
@@ -0,0 +1,169 @@
+"""Tests for the AWQ Triton kernel.
+
+Run `pytest tests/kernels/test_awq_triton.py`.
+"""
+import pytest
+import torch
+
+from vllm.model_executor.layers.quantization.awq_triton import (
+    AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
+
+device = "cuda"
+
+
+def reverse_awq_order(t: torch.Tensor):
+    bits = 4
+    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+    reverse_order_tensor = torch.arange(
+        t.shape[-1],
+        dtype=torch.int32,
+        device=t.device,
+    )
+    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
+    reverse_order_tensor = reverse_order_tensor.view(-1)
+
+    t = t[:, reverse_order_tensor] & 0xF
+    return t
+
+
+# qweights - [R     , C // 8], int32
+# scales   - [R // G, C     ], float16
+# zeros    - [R // G, C // 8], int32
+def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
+                         qzeros: torch.Tensor,
+                         group_size: int) -> torch.Tensor:
+
+    if group_size == -1:
+        group_size = qweight.shape[0]
+
+    bits = 4
+    shifts = torch.arange(0, 32, bits, device=qzeros.device)
+
+    iweights = torch.bitwise_right_shift(qweight[:, :, None],
+                                         shifts[None, None, :]).to(torch.int8)
+
+    iweights = iweights.view(iweights.shape[0], -1)
+
+    zeros = torch.bitwise_right_shift(qzeros[:, :, None],
+                                      shifts[None, None, :]).to(torch.int8)
+    zeros = zeros.view(qzeros.shape[0], -1)
+    zeros = reverse_awq_order(zeros)
+
+    iweights = reverse_awq_order(iweights)
+
+    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
+    zeros = torch.bitwise_and(zeros, (2**bits) - 1)
+
+    scales = scales.repeat_interleave(group_size, dim=0)
+    zeros = zeros.repeat_interleave(group_size, dim=0)
+    return (iweights - zeros) * scales
+
+
+# qweights - [R     , C // 8], int32
+# scales   - [R // G, C     ], float16
+# zeros    - [R // G, C // 8], int32
+@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
+@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
+@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
+def test_dequantize(qweight_rows, qweight_cols, group_size):
+
+    if group_size == -1:
+        group_size = qweight_rows
+
+    qweight_dtype = torch.int32
+    scales_rows = qweight_rows // group_size
+    scales_cols = qweight_cols * 8
+    scales_dtype = torch.float16
+    zeros_rows = scales_rows
+    zeros_cols = qweight_cols
+    zeros_dtype = torch.int32
+
+    torch.manual_seed(0)
+
+    qweight = torch.randint(0,
+                            torch.iinfo(torch.int32).max,
+                            (qweight_rows, qweight_cols),
+                            dtype=qweight_dtype,
+                            device=device)
+    scales = torch.rand(scales_rows,
+                        scales_cols,
+                        dtype=scales_dtype,
+                        device=device)
+    zeros = torch.randint(0,
+                          torch.iinfo(torch.int32).max,
+                          (zeros_rows, zeros_cols),
+                          dtype=zeros_dtype,
+                          device=device)
+
+    iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
+
+    assert (not torch.any(torch.isinf(iweights_triton))
+            and not torch.any(torch.isnan(iweights_triton)))
+
+    iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
+
+    torch.testing.assert_close(iweights_triton, iweights_torch)
+
+
+# input   - [N, K]
+# qweight - [K, M // 8]
+# qzeros  - [K // G, M // 8]
+# scales  - [K // G, M]
+@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
+@pytest.mark.parametrize("K", [128])
+@pytest.mark.parametrize("M", [16, 24, 32])
+@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("splitK", [1, 8]) +def test_gemm(N, K, M, splitK, group_size): + + if group_size == -1: + group_size = K + + split_k_iters = splitK + + input_rows = N + input_cols = K + input_dtype = torch.float32 + qweight_rows = input_cols + qweight_cols = M // 8 + scales_rows = qweight_rows // group_size + scales_cols = M + scales_dtype = torch.float32 + qzeros_rows = scales_rows + qzeros_cols = qweight_cols + + torch.manual_seed(0) + + input = torch.rand((input_rows, input_cols), + dtype=input_dtype, + device=device) + qweight = torch.randint(0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + device=device) + qzeros = torch.randint(0, + torch.iinfo(torch.int32).max, + (qzeros_rows, qzeros_cols), + device=device) + scales = torch.rand((scales_rows, scales_cols), + dtype=scales_dtype, + device=device) + + output_triton = awq_gemm_triton(input, qweight, scales, qzeros, + split_k_iters) + + assert (not torch.any(torch.isinf(output_triton)) + and not torch.any(torch.isnan(output_triton))) + + dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros) + + output_torch = torch.matmul(input, dequantized_weights) + + assert (not torch.any(torch.isinf(output_torch)) + and not torch.any(torch.isnan(output_torch))) + + torch.testing.assert_close(output_triton.cpu(), + output_torch.cpu(), + atol=1e-1, + rtol=1e-1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ae90af563c0cf..e5e7bb6963973 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm._core_ext import ScalarType from vllm.logger import init_logger from vllm.platforms import current_platform @@ -177,12 +178,20 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_dequantize_triton) + return awq_dequantize_triton(qweight, scales, zeros) return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_gemm_triton) + return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) diff --git a/vllm/config.py b/vllm/config.py index 4e014e43d849a..0a34dabf57e7c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -267,7 +267,7 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["gptq", "squeezellm", "fp8"] + rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"] optimized_quantization_methods = [ "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", @@ -322,6 +322,12 @@ def _verify_quantization(self) -> None: "%s quantization is not fully " "optimized yet. 
                     "optimized yet. The speed can be slower than "
                     "non-quantized models.", self.quantization)
+        if (self.quantization == "awq" and is_hip()
+                and not envs.VLLM_USE_TRITON_AWQ):
+            logger.warning(
+                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
+                " is not set, enabling VLLM_USE_TRITON_AWQ.")
+            envs.VLLM_USE_TRITON_AWQ = True
 
     def _verify_cuda_graph(self) -> None:
         if self.max_seq_len_to_capture is None:
diff --git a/vllm/envs.py b/vllm/envs.py
index 24e09ee0e055f..4faafd9daf304 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -400,6 +400,10 @@ def get_default_config_root():
     "VLLM_TORCH_PROFILER_DIR":
     lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
              .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
+
+    # If set, vLLM will use Triton implementations of AWQ.
+    "VLLM_USE_TRITON_AWQ":
+    lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
 }
 
 # end-env-vars-definition
diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py
new file mode 100644
index 0000000000000..ad706f28a742b
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/awq_triton.py
@@ -0,0 +1,304 @@
+import torch
+import triton
+import triton.language as tl
+
+AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+
+@triton.jit
+def awq_dequantize_kernel(
+        qweight_ptr,  # quantized matrix
+        scales_ptr,  # scales, per group
+        zeros_ptr,  # zeros, per group
+        group_size,  # Should always be one of the supported group sizes
+        result_ptr,  # Output matrix
+        num_cols,  # input num cols in qweight
+        num_rows,  # input num rows in qweight
+        BLOCK_SIZE_X: tl.constexpr,
+        BLOCK_SIZE_Y: tl.constexpr):
+    # Setup the pids.
+    pid_x = tl.program_id(axis=0)
+    pid_y = tl.program_id(axis=1)
+
+    # Compute offsets and masks for qweight_ptr.
+    offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
+    offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
+    offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
+
+    masks_y = offsets_y < num_rows
+    masks_x = offsets_x < num_cols
+
+    masks = masks_y[:, None] & masks_x[None, :]
+
+    # Compute offsets and masks for result output ptr.
+    result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
+    result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(
+        0, BLOCK_SIZE_X * 8)
+    result_offsets = (8 * num_cols * result_offsets_y[:, None] +
+                      result_offsets_x[None, :])
+
+    result_masks_y = result_offsets_y < num_rows
+    result_masks_x = result_offsets_x < num_cols * 8
+    result_masks = result_masks_y[:, None] & result_masks_x[None, :]
+
+    # Load the weights.
+    iweights = tl.load(qweight_ptr + offsets, masks)
+
+    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
+    # that will map given indices to the correct order.
+    reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
+                                tl.arange(0, 4)[:, None]).reshape(8)
+
+    # Use this to compute a set of shifts that can be used to unpack and
+    # reorder the values in iweights and zeros.
+    shifts = reverse_awq_order_tensor * 4
+    shifts = tl.broadcast_to(shifts[None, :],
+                             (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
+    shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
+
+    # Unpack and reorder: shift out the correct 4-bit value and mask.
+    iweights = (iweights >> shifts) & 0xF
+
+    # Compute zero offsets and masks.
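+    # The zeros are packed like qweight (eight 4-bit zero points per int32)
+    # and there is one zeros row per `group_size` qweight rows, so the row
+    # index is taken at group granularity while the column layout matches
+    # qweight.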
+    zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
+                      tl.arange(0, BLOCK_SIZE_Y) // group_size)
+    zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
+    zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
+
+    zero_masks_y = zero_offsets_y < num_rows // group_size
+    zero_masks_x = zero_offsets_x < num_cols
+    zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
+
+    # Load the zeros.
+    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
+
+    # Unpack and reorder: shift out the correct 4-bit value and mask.
+    zeros = (zeros >> shifts) & 0xF
+
+    # Compute scale offsets and masks.
+    scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
+                       tl.arange(0, BLOCK_SIZE_Y) // group_size)
+    scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
+                       tl.arange(0, BLOCK_SIZE_X * 8))
+    scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
+                     scale_offsets_x[None, :])
+    scale_masks_y = scale_offsets_y < num_rows // group_size
+    scale_masks_x = scale_offsets_x < num_cols * 8
+    scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
+
+    # Load the scales.
+    scales = tl.load(scales_ptr + scale_offsets, scale_masks)
+
+    # Dequantize.
+    iweights = (iweights - zeros) * scales
+    iweights = iweights.to(result_ptr.type.element_ty)
+
+    # Finally, store.
+    tl.store(result_ptr + result_offsets, iweights, result_masks)
+
+
+@triton.jit
+def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
+                    group_size, BLOCK_SIZE_M: tl.constexpr,
+                    BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+                    SPLIT_K: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    pid_z = tl.program_id(1)
+
+    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
+    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    accumulator_dtype = c_ptr.type.element_ty
+
+    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
+    # accumulator = tl.arange(0, BLOCK_SIZE_N)
+    # accumulator = tl.broadcast_to(accumulator[None, :],
+    #                               (BLOCK_SIZE_M, BLOCK_SIZE_N))
+    # accumulator = accumulator & 0x0
+    # accumulator = accumulator.to(accumulator_dtype)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
+                           dtype=accumulator_dtype)
+
+    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
+    # that will map given indices to the correct order.
+    reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
+                                tl.arange(0, 4)[:, None]).reshape(8)
+
+    # Create the necessary shifts to use to unpack.
+    shifts = reverse_awq_order_tensor * 4
+    shifts = tl.broadcast_to(shifts[None, :],
+                             (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
+    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))
+
+    # Offsets and masks.
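+    # Naming: *_am indexes rows of A (M), *_bn the packed columns of B
+    # (N // 8), *_zn / *_sn the packed zeros / scales columns, and *_k the
+    # K-slice owned by this split-K program (pid_z).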
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    masks_am = offsets_am < M
+
+    offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +
+                  tl.arange(0, BLOCK_SIZE_N) // 8)
+    masks_bn = offsets_bn < N // 8
+
+    offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +
+                  tl.arange(0, BLOCK_SIZE_N) // 8)
+    masks_zn = offsets_zn < N // 8
+
+    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    masks_sn = offsets_sn < N
+
+    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
+    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]
+
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
+
+    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
+    # block_offset = BLOCK_SIZE_K * SPLIT_K
+    # for k in range(0, (K + block_offset - 1) // (block_offset)):
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)
+
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)
+
+        # Dequantize b.
+        offsets_szk = (
+            (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
+            tl.arange(0, BLOCK_SIZE_K) // group_size)
+        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
+        masks_zk = offsets_szk < K // group_size
+        masks_z = masks_zk[:, None] & masks_zn[None, :]
+        zeros_ptrs = zeros_ptr + offsets_z
+        zeros = tl.load(zeros_ptrs, mask=masks_z)
+
+        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
+        masks_sk = offsets_szk < K // group_size
+        masks_s = masks_sk[:, None] & masks_sn[None, :]
+        scales_ptrs = scales_ptr + offsets_s
+        scales = tl.load(scales_ptrs, mask=masks_s)
+
+        b = (b >> shifts) & 0xF
+        zeros = (zeros >> shifts) & 0xF
+        b = (b - zeros) * scales
+        b = b.to(c_ptr.type.element_ty)
+
+        # Accumulate results.
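+        # tl.dot fuses the multiply-accumulate into `accumulator`; with
+        # SPLIT_K > 1 each program only covers its own K-slice and the
+        # partial sums are combined via tl.atomic_add after the loop, which
+        # is why the host zero-initializes the result tensor.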
+        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+        offsets_k += BLOCK_SIZE_K * SPLIT_K
+        a_ptrs += BLOCK_SIZE_K * SPLIT_K
+        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)
+
+    c = accumulator.to(c_ptr.type.element_ty)
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    if SPLIT_K == 1:
+        tl.store(c_ptrs, c, mask=c_mask)
+    else:
+        tl.atomic_add(c_ptrs, c, mask=c_mask)
+
+
+# qweights - [K     , M // 8], int32
+# scales   - [K // G, M     ], float16
+# zeros    - [K // G, M // 8], int32
+def awq_dequantize_triton(qweight: torch.Tensor,
+                          scales: torch.Tensor,
+                          zeros: torch.Tensor,
+                          block_size_x: int = 32,
+                          block_size_y: int = 32) -> torch.Tensor:
+    K = qweight.shape[0]
+    M = scales.shape[1]
+    group_size = qweight.shape[0] // scales.shape[0]
+
+    assert K > 0 and M > 0
+    assert scales.shape[0] == K // group_size and scales.shape[1] == M
+    assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
+    assert group_size <= K
+    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
+
+    # Result tensor:
+    # number of rows = same as input tensor
+    # number of cols = 8 x input tensor num cols
+    result = torch.empty(qweight.shape[0],
+                         qweight.shape[1] * 8,
+                         device=qweight.device,
+                         dtype=scales.dtype)
+
+    Y = qweight.shape[0]  # num rows
+    X = qweight.shape[1]  # num cols
+
+    grid = lambda META: (
+        triton.cdiv(X, META['BLOCK_SIZE_X']),
+        triton.cdiv(Y, META['BLOCK_SIZE_Y']),
+    )
+    awq_dequantize_kernel[grid](qweight,
+                                scales,
+                                zeros,
+                                group_size,
+                                result,
+                                X,
+                                Y,
+                                BLOCK_SIZE_X=block_size_x,
+                                BLOCK_SIZE_Y=block_size_y)
+
+    return result
+
+
+# input   - [M, K]
+# qweight - [K, N // 8]
+# qzeros  - [K // G, N // 8]
+# scales  - [K // G, N]
+# split_k_iters - parallelism along K-dimension, int, power of 2.
+def awq_gemm_triton(input: torch.Tensor,
+                    qweight: torch.Tensor,
+                    scales: torch.Tensor,
+                    qzeros: torch.Tensor,
+                    split_k_iters: int,
+                    block_size_m: int = 32,
+                    block_size_n: int = 32,
+                    block_size_k: int = 32) -> torch.Tensor:
+    M, K = input.shape
+    N = qweight.shape[1] * 8
+    group_size = qweight.shape[0] // qzeros.shape[0]
+
+    assert N > 0 and K > 0 and M > 0
+    assert qweight.shape[0] == K and qweight.shape[1] == N // 8
+    assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
+    assert scales.shape[0] == K // group_size and scales.shape[1] == N
+    assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
+    assert split_k_iters <= 32
+    assert group_size <= K
+    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
+
+    grid = lambda META: (
+        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
+            N, META['BLOCK_SIZE_N']),
+        split_k_iters,
+    )
+
+    result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)
+
+    # A = input, B = qweight, C = result
+    # A = M x K, B = K x N, C = M x N
+    awq_gemm_kernel[grid](input,
+                          qweight,
+                          result,
+                          qzeros,
+                          scales,
+                          M,
+                          N,
+                          K,
+                          group_size,
+                          BLOCK_SIZE_M=block_size_m,
+                          BLOCK_SIZE_N=block_size_n,
+                          BLOCK_SIZE_K=block_size_k,
+                          SPLIT_K=split_k_iters)
+
+    return result
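
The [0, 4, 1, 5, 2, 6, 3, 7] tables above (reverse_awq_order in the test and reverse_awq_order_tensor in both kernels) undo AWQ's nibble packing, in which nibble slot i of each int32 holds logical element [0, 2, 4, 6, 1, 3, 5, 7][i]. A minimal standalone sketch of that round trip (the names here are illustrative and not taken from the diff):

import torch

PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]  # nibble slot i holds element PACK_ORDER[i]
REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]  # inverse permutation used by the kernels

vals = torch.arange(8, dtype=torch.int32)  # stand-in 4-bit values 0..7
packed = torch.zeros((), dtype=torch.int32)
for slot, elem in enumerate(PACK_ORDER):
    packed |= vals[elem] << (4 * slot)  # pack 8 nibbles into one int32

shifts = torch.tensor(REVERSE_ORDER, dtype=torch.int32) * 4
unpacked = (packed >> shifts) & 0xF  # same shift trick as the Triton kernels
assert torch.equal(unpacked, vals)  # recovers 0..7 in logical order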
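
With this change, setting VLLM_USE_TRITON_AWQ=1 (or running AWQ on ROCm, where config.py enables it automatically) routes ops.awq_dequantize and ops.awq_gemm to the Triton kernels. A minimal sketch of exercising the kernels directly, mirroring the small shapes used in the tests (assumes a CUDA/ROCm device and randomly generated packed data; the numbers below are illustrative):

import torch

from vllm.model_executor.layers.quantization.awq_triton import (
    awq_dequantize_triton, awq_gemm_triton)

device = "cuda"
M, K, N, group_size = 4, 128, 32, 32

torch.manual_seed(0)
# Random AWQ-style packed tensors: eight 4-bit values per int32.
qweight = torch.randint(0, torch.iinfo(torch.int32).max, (K, N // 8),
                        dtype=torch.int32, device=device)
qzeros = torch.randint(0, torch.iinfo(torch.int32).max,
                       (K // group_size, N // 8),
                       dtype=torch.int32, device=device)
scales = torch.rand((K // group_size, N), dtype=torch.float32, device=device)
x = torch.rand((M, K), dtype=torch.float32, device=device)

# Split-K GEMM on the packed weights vs. dequantize-then-matmul.
y_triton = awq_gemm_triton(x, qweight, scales, qzeros, split_k_iters=8)
y_ref = torch.matmul(x, awq_dequantize_triton(qweight, scales, qzeros))
torch.testing.assert_close(y_triton, y_ref, atol=1e-1, rtol=1e-1)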