[Kernel] [Triton] [AMD] Adding Triton implementations awq_dequantize and awq_gemm to support AWQ (vllm-project#7386)

rasmith authored Aug 28, 2024
1 parent b98cc28 commit e5697d1
Showing 5 changed files with 493 additions and 1 deletion.
169 changes: 169 additions & 0 deletions tests/kernels/test_awq_triton.py
@@ -0,0 +1,169 @@
"""Tests for the AWQ Triton kernel.
Run `pytest tests/kernels/test_awq_triton.py`.
"""
import pytest
import torch

from vllm.model_executor.layers.quantization.awq_triton import (
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)

device = "cuda"


def reverse_awq_order(t: torch.Tensor):
bits = 4
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
reverse_order_tensor = torch.arange(
t.shape[-1],
dtype=torch.int32,
device=t.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)

t = t[:, reverse_order_tensor] & 0xF
return t


# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor,
group_size: int) -> torch.Tensor:

if group_size == -1:
group_size = qweight.shape[0]

bits = 4
shifts = torch.arange(0, 32, bits, device=qzeros.device)

iweights = torch.bitwise_right_shift(qweight[:, :, None],
shifts[None, None, :]).to(torch.int8)

iweights = iweights.view(iweights.shape[0], -1)

zeros = torch.bitwise_right_shift(qzeros[:, :, None],
shifts[None, None, :]).to(torch.int8)
zeros = zeros.view(qzeros.shape[0], -1)
zeros = reverse_awq_order(zeros)

iweights = reverse_awq_order(iweights)

iweights = torch.bitwise_and(iweights, (2**bits) - 1)
zeros = torch.bitwise_and(zeros, (2**bits) - 1)

scales = scales.repeat_interleave(group_size, dim=0)
zeros = zeros.repeat_interleave(group_size, dim=0)
return (iweights - zeros) * scales


# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
def test_dequantize(qweight_rows, qweight_cols, group_size):

if group_size == -1:
group_size = qweight_rows

qweight_dtype = torch.int32
scales_rows = qweight_rows // group_size
scales_cols = qweight_cols * 8
scales_dtype = torch.float16
zeros_rows = scales_rows
zeros_cols = qweight_cols
zeros_dtype = torch.int32

torch.manual_seed(0)

qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
(qweight_rows, qweight_cols),
dtype=qweight_dtype,
device=device)
scales = torch.rand(scales_rows,
scales_cols,
dtype=scales_dtype,
device=device)
zeros = torch.randint(0,
torch.iinfo(torch.int32).max,
(zeros_rows, zeros_cols),
dtype=zeros_dtype,
device=device)

iweights_triton = awq_dequantize_triton(qweight, scales, zeros)

assert (not torch.any(torch.isinf(iweights_triton))
and not torch.any(torch.isnan(iweights_triton)))

iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)

torch.testing.assert_close(iweights_triton, iweights_torch)


# input - [N, K]
# qweight - [K, M // 8]
# qzeros - [K // G, M // 8]
# scales - [K // G, M]
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
@pytest.mark.parametrize("K", [128])
@pytest.mark.parametrize("M", [16, 24, 32])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("splitK", [1, 8])
def test_gemm(N, K, M, splitK, group_size):

if group_size == -1:
group_size = K

split_k_iters = splitK

input_rows = N
input_cols = K
input_dtype = torch.float32
qweight_rows = input_cols
qweight_cols = M // 8
scales_rows = qweight_rows // group_size
scales_cols = M
scales_dtype = torch.float32
qzeros_rows = scales_rows
qzeros_cols = qweight_cols

torch.manual_seed(0)

input = torch.rand((input_rows, input_cols),
dtype=input_dtype,
device=device)
qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
(qweight_rows, qweight_cols),
device=device)
qzeros = torch.randint(0,
torch.iinfo(torch.int32).max,
(qzeros_rows, qzeros_cols),
device=device)
scales = torch.rand((scales_rows, scales_cols),
dtype=scales_dtype,
device=device)

output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
split_k_iters)

assert (not torch.any(torch.isinf(output_triton))
and not torch.any(torch.isnan(output_triton)))

dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)

output_torch = torch.matmul(input, dequantized_weights)

assert (not torch.any(torch.isinf(output_torch))
and not torch.any(torch.isnan(output_torch)))

torch.testing.assert_close(output_triton.cpu(),
output_torch.cpu(),
atol=1e-1,
rtol=1e-1)
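
The reference implementation above undoes AWQ's interleaved int4 packing. Below is a minimal, self-contained sketch (not part of this commit; helper names are illustrative) of that layout: physical nibble i of each int32 holds logical column [0, 2, 4, 6, 1, 3, 5, 7][i], so indexing the unpacked nibbles with AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] restores logical column order.

import torch

AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def pack_awq_int4(w: torch.Tensor) -> torch.Tensor:
    """Pack a [R, C] tensor of 4-bit values (0..15) into [R, C // 8] int32."""
    w = w.to(torch.int32).view(w.shape[0], -1, 8)
    packed = torch.zeros(w.shape[0], w.shape[1], dtype=torch.int32)
    for nibble, logical in enumerate(AWQ_PACK_ORDER):
        packed |= w[:, :, logical] << (nibble * 4)
    return packed


def unpack_awq_int4(qweight: torch.Tensor) -> torch.Tensor:
    """Unpack [R, C // 8] int32 back to [R, C] values in logical column order."""
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    nibbles = (qweight[:, :, None] >> shifts) & 0xF  # [R, C // 8, 8], physical order
    nibbles = nibbles[:, :, AWQ_REVERSE_ORDER]       # undo the pack interleaving
    return nibbles.reshape(qweight.shape[0], -1)


w = torch.randint(0, 16, (4, 16), dtype=torch.int32)
assert torch.equal(unpack_awq_int4(pack_awq_int4(w)), w)

With zero-valued zero-points and unit scales, awq_dequantize_torch reduces to exactly this unpack step (cast to the scales dtype).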
9 changes: 9 additions & 0 deletions vllm/_custom_ops.py
@@ -4,6 +4,7 @@

import torch

import vllm.envs as envs
from vllm._core_ext import ScalarType
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -177,12 +178,20 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
if envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton)
return awq_dequantize_triton(qweight, scales, zeros)
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
if envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_gemm_triton)
return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


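A hedged usage sketch, not part of this commit: when the new VLLM_USE_TRITON_AWQ flag is set (see vllm/envs.py below), the wrapper above routes to awq_dequantize_triton and ignores the CUDA-only split_k_iters/thx/thy arguments; awq_gemm dispatches the same way. Shapes and dtypes follow tests/kernels/test_awq_triton.py, the sizes are illustrative, and a CUDA or ROCm device is assumed.

import os
os.environ["VLLM_USE_TRITON_AWQ"] = "1"  # set before importing vLLM so the flag is seen

import torch
from vllm import _custom_ops as ops

K, M, group_size = 128, 32, 32
qweight = torch.randint(0, torch.iinfo(torch.int32).max, (K, M // 8),
                        dtype=torch.int32, device="cuda")
scales = torch.rand(K // group_size, M, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, torch.iinfo(torch.int32).max, (K // group_size, M // 8),
                       dtype=torch.int32, device="cuda")

# The trailing 0, 0, 0 are split_k_iters/thx/thy; only the CUDA kernel reads them.
weights = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
print(weights.shape)  # torch.Size([128, 32])
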
8 changes: 7 additions & 1 deletion vllm/config.py
@@ -267,7 +267,7 @@ def _parse_quant_hf_config(self):

def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors",
@@ -322,6 +322,12 @@ def _verify_quantization(self) -> None:
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization)
if (self.quantization == "awq" and is_hip()
and not envs.VLLM_USE_TRITON_AWQ):
logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True

def _verify_cuda_graph(self) -> None:
if self.max_seq_len_to_capture is None:
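A hedged end-to-end sketch, not part of this commit (the model name is a placeholder): on a ROCm build, requesting AWQ quantization now takes the warning path above and enables VLLM_USE_TRITON_AWQ automatically, so nothing needs to be exported by hand.

from vllm import LLM, SamplingParams

# On ROCm this logs the warning above, flips VLLM_USE_TRITON_AWQ on, and the
# AWQ linear layers then run through the Triton kernels added by this commit.
llm = LLM(model="TheBloke/Llama-2-7B-AWQ", quantization="awq")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
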
4 changes: 4 additions & 0 deletions vllm/envs.py
@@ -400,6 +400,10 @@ def get_default_config_root():
"VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
}

# end-env-vars-definition
(The fifth changed file, vllm/model_executor/layers/quantization/awq_triton.py, which adds the Triton kernels themselves, did not render on this page.)