diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e982518377d0..bf00a36edc500 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -172,6 +172,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" + "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" "csrc/quantization/fp8/fp8_marlin.cu" "csrc/custom_all_reduce.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" diff --git a/csrc/ops.h b/csrc/ops.h index 6541b4d46d7f6..9ef1fcb465bf3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -89,15 +89,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int64_t size_k); torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full); + torch::Tensor& b_scales, torch::Tensor& b_zeros, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp); torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, + int64_t size_n, int64_t num_bits); + torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, diff --git a/csrc/quantization/fp8/fp8_marlin.cu b/csrc/quantization/fp8/fp8_marlin.cu index 51ff071987f80..eef6dc6ebdf4a 100644 --- a/csrc/quantization/fp8/fp8_marlin.cu +++ b/csrc/quantization/fp8/fp8_marlin.cu @@ -19,10 +19,10 @@ * Adapted from https://github.com/IST-DASLab/marlin */ -#include "../gptq_marlin/gptq_marlin.cuh" -#include "../gptq_marlin/gptq_marlin_dtypes.cuh" +#include "../gptq_marlin/marlin.cuh" +#include "../gptq_marlin/marlin_dtypes.cuh" -using namespace gptq_marlin; +using namespace marlin; #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", marlin::tile_size); + TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", gptq_marlin::tile_size); - int actual_size_n = - (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + " is not divisible by tile_size = ", marlin::tile_size); + int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); @@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, num_groups = b_scales.size(0); // Verify workspace size - TORCH_CHECK( - size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); - int min_workspace_size = - (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", marlin::min_thread_n); + int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); @@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, b_scales.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { fp8_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + marlin::max_par); } else { TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu new file mode 100644 index 0000000000000..c58216d8e00c5 --- /dev/null +++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu @@ -0,0 +1,269 @@ +#include "marlin.cuh" + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +namespace marlin { + +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) {} + +} // namespace marlin + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +namespace marlin { + +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4* sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; + #pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; + #pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; + #pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; + #pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace marlin + + #define CALL_IF(NUM_BITS) \ + else if (num_bits == NUM_BITS) { \ + cudaFuncSetAttribute( \ + marlin::awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::awq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ + } + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, + int64_t size_n, int64_t num_bits) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", marlin::tile_k_size); + TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", marlin::tile_n_size); + + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / num_bits; + + // Verify B + TORCH_CHECK(b_q_weight.size(0) == size_k, + "b_q_weight.size(0) = ", b_q_weight.size(0), + " is not size_k = ", size_k); + TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), + ", size_n = ", size_n, ", pack_factor = ", pack_factor); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + torch::Tensor out = torch::empty( + {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, + options); + + // Get ptrs + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (false) { + } + CALL_IF(4) + CALL_IF(8) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); + } + + return out; +} + +#endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 0beb9de14c687..122c5c16b58ce 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -19,8 +19,8 @@ * Adapted from https://github.com/IST-DASLab/marlin */ -#include "gptq_marlin.cuh" -#include "gptq_marlin_dtypes.cuh" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -32,7 +32,7 @@ inline std::string str(T x) { return std::to_string(x); } -namespace gptq_marlin { +namespace marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -72,10 +72,11 @@ __global__ void Marlin( } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { + torch::Tensor& b_scales, torch::Tensor& b_zeros, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -264,6 +265,114 @@ dequant_8bit(int q) { return frag_b; } +// Zero-point dequantizers + +template +__device__ inline typename ScalarType::FragB dequant_4bit_zp(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit_zp( + int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_4bit_zp(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template +__device__ inline typename ScalarType::FragB dequant_8bit_zp(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit_zp( + int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit_zp(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. template @@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType::FragB& frag_b, frag_b[1] = __hmul2(frag_b[1], s); } +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, + typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + // Same as above, but for act_order (each K is multiplied individually) template __device__ inline void scale4(typename ScalarType::FragB& frag_b, @@ -404,6 +524,7 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -413,6 +534,8 @@ __global__ void Marlin( int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m @@ -437,6 +560,7 @@ __global__ void Marlin( using FragB = typename ScalarType::FragB; using FragC = typename ScalarType::FragC; using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; constexpr int pack_factor = 32 / num_bits; @@ -566,6 +690,13 @@ __global__ void Marlin( int tb_n_warps = thread_n_blocks / 4; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + // Global A read index of current thread. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); @@ -605,6 +736,19 @@ __global__ void Marlin( int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. @@ -616,6 +760,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + // Precompute which thread should not read memory in which iterations; this is // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. @@ -664,14 +820,17 @@ __global__ void Marlin( int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 // Zero accumulators. auto zero_accums = [&]() { @@ -777,6 +936,28 @@ __global__ void Marlin( } } } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } } } // Insert a fence even when we are winding down the pipeline to ensure that @@ -784,6 +965,12 @@ __global__ void Marlin( cp_async_fence(); }; + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + // Wait until the next thread tile has been loaded to shared memory. auto wait_for_stage = [&]() { // We only have `stages - 2` active fetches since we are double buffering @@ -932,8 +1119,73 @@ __global__ void Marlin( } }; + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + if constexpr (!has_zp) { + return; + } + + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + }; + // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + if constexpr (num_bits == 4) { + int zp_quant = frag_qzp[k % 2][0]; + int zp_quant_shift = zp_quant >> 8; + frag_zp_0 = dequant_4bit_zp(zp_quant); + frag_zp_1 = dequant_4bit_zp(zp_quant_shift); + + } else { + int zp_quant_0 = frag_qzp[k % 2][0]; + int zp_quant_1 = frag_qzp[k % 2][1]; + frag_zp_0 = dequant_8bit_zp(zp_quant_0); + frag_zp_1 = dequant_8bit_zp(zp_quant_1); + } + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -944,16 +1196,32 @@ __global__ void Marlin( int b_quant = frag_b_quant[k % 2][0][j]; int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); + if constexpr (has_zp) { + frag_b0 = dequant_4bit_zp(b_quant); + frag_b1 = dequant_4bit_zp(b_quant_shift); + + } else { + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + } } else { int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); + if constexpr (has_zp) { + frag_b0 = dequant_8bit_zp(b_quant_0); + frag_b1 = dequant_8bit_zp(b_quant_1); + } else { + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + } + + // Apply zero-point to frag_b0 + if constexpr (has_zp) { + sub_zp(frag_b0, frag_zp[j], 0); } // Apply scale to frag_b0 @@ -967,6 +1235,11 @@ __global__ void Marlin( } } + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], @@ -1189,6 +1462,12 @@ __global__ void Marlin( } fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } fetch_to_shared(i, i, i < slice_iters); } @@ -1197,6 +1476,7 @@ __global__ void Marlin( init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1); }; @@ -1217,6 +1497,7 @@ __global__ void Marlin( for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); if (k == b_sh_wr_iters - 2) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -1354,6 +1635,7 @@ __global__ void Marlin( } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; } start_pipes(); @@ -1363,22 +1645,24 @@ __global__ void Marlin( } #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ Marlin, \ + HAS_ZP, GROUP_BLOCKS>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ - prob_k, locks); \ + HAS_ZP, GROUP_BLOCKS> \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \ + prob_m, prob_n, prob_k, locks); \ } typedef struct { @@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } - #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + #define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + + #define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k, void* workspace, int num_bits, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int dev, cudaStream_t stream, int thread_k, - int thread_n, int sms, int max_par) { + bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par) { TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; const int4* s_ptr = (const int4*)s; + const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; int4* a_tmp_ptr = (int4*)a_tmp; @@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, thread_m_blocks = exec_cfg.max_m_blocks; } - // Define kernel configurations if (false) { } - CALL_IF(4, 32, 2, 256) - CALL_IF(4, 16, 4, 256) - CALL_IF(4, 8, 8, 256) - CALL_IF(4, 8, 4, 128) - CALL_IF(4, 4, 8, 128) - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) + GPTQ_CALL_IF(4, 16, 4, 256) + GPTQ_CALL_IF(4, 8, 8, 256) + GPTQ_CALL_IF(4, 8, 4, 128) + GPTQ_CALL_IF(4, 4, 8, 128) + GPTQ_CALL_IF(8, 16, 4, 256) + GPTQ_CALL_IF(8, 8, 8, 256) + GPTQ_CALL_IF(8, 8, 4, 128) + GPTQ_CALL_IF(8, 4, 8, 128) + + AWQ_CALL_IF(4, 16, 4, 256) + AWQ_CALL_IF(4, 8, 8, 256) + AWQ_CALL_IF(4, 8, 4, 128) + AWQ_CALL_IF(4, 4, 8, 128) + AWQ_CALL_IF(8, 16, 4, 256) + AWQ_CALL_IF(8, 8, 8, 256) + AWQ_CALL_IF(8, 8, 4, 128) + AWQ_CALL_IF(8, 4, 8, 128) else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); + TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_bits = ", num_bits); } A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; @@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { + torch::Tensor& b_scales, torch::Tensor& b_zeros, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp) { // Verify num_bits TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", marlin::tile_size); + TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", gptq_marlin::tile_size); - int actual_size_n = - (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + " is not divisible by tile_size = ", marlin::tile_size); + int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); @@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); @@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int group_size = -1; bool has_act_order = g_idx.size(0) != 0; - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + int rank = b_scales.sizes().size(); + TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2"); TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), " is not size_n = ", size_n); num_groups = b_scales.size(0); @@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } } + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); + TORCH_CHECK(b_zeros.size(0) == num_groups, + "b_zeros dim 0 = ", b_zeros.size(0), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", b_scales.size(1), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + // Verify workspace size - TORCH_CHECK( - size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); - int min_workspace_size = - (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", marlin::min_thread_n); + int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { - gptq_marlin::marlin_mm_f16i4( + marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, - group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, gptq_marlin::max_par); + b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { - gptq_marlin::marlin_mm_f16i4( + marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, - is_k_full, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, marlin::max_par); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index 4adc158eb14ea..c71b1bf573263 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -1,23 +1,16 @@ -#include "gptq_marlin.cuh" - -namespace gptq_marlin { - -static constexpr int repack_stages = 8; - -static constexpr int repack_threads = 256; - -static constexpr int tile_k_size = tile_size; -static constexpr int tile_n_size = tile_k_size * 4; +#include "marlin.cuh" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +namespace marlin { + template -__global__ void marlin_repack_kernel( +__global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {} -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, @@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, #else +namespace marlin { + template -__global__ void marlin_repack_kernel( +__global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { @@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel( } } -} // namespace gptq_marlin - - #define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin::marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - gptq_marlin::marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ +} // namespace marlin + + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::gptq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, - " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); - TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, - " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", marlin::tile_k_size); + TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", marlin::tile_n_size); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, auto options = torch::TensorOptions() .dtype(b_q_weight.dtype()) .device(b_q_weight.device()); - torch::Tensor out = - torch::empty({size_k / gptq_marlin::tile_size, - size_n * gptq_marlin::tile_size / pack_factor}, - options); + torch::Tensor out = torch::empty( + {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, + options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/marlin.cuh similarity index 88% rename from csrc/quantization/gptq_marlin/gptq_marlin.cuh rename to csrc/quantization/gptq_marlin/marlin.cuh index 42af44951efda..74ccbac57bd3c 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/marlin.cuh @@ -9,7 +9,9 @@ #include #include -namespace gptq_marlin { +namespace marlin { + +// Marlin params // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, @@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers template struct Vec { T elems[n]; @@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace gptq_marlin +} // namespace marlin diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh similarity index 93% rename from csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh rename to csrc/quantization/gptq_marlin/marlin_dtypes.cuh index ca1b7099d6ec7..be06c09bee331 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh @@ -1,11 +1,11 @@ #ifndef _data_types_cuh #define _data_types_cuh -#include "gptq_marlin.cuh" +#include "marlin.cuh" #include #include -namespace gptq_marlin { +namespace marlin { template class ScalarType {}; @@ -23,6 +23,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragZP = Vec; static __device__ float inline num2float(const half x) { return __half2float(x); @@ -51,6 +52,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragZP = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { @@ -72,6 +74,6 @@ class ScalarType { #endif }; -} // namespace gptq_marlin +} // namespace marlin #endif diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index d124c0149912d..37339b84ae25b 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -30,7 +30,7 @@ inline std::string str(T x) { return std::to_string(x); } -namespace marlin { +namespace marlin_dense { constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } @@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, } } -} // namespace marlin +} // namespace marlin_dense torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, @@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(size_k == a.size(1), "Shape mismatch: a.size(1) = " + str(a.size(1)) + ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin::tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(marlin::tile_size)); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK(size_k % marlin_dense::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(marlin_dense::tile_size)); + TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin::tile_size)); + ", tile_size = " + str(marlin_dense::tile_size)); // Verify N TORCH_CHECK(b_scales.size(1) == size_n, "b_scales.size(1) = " + str(b_scales.size(1)) + ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin::tile_size)); + TORCH_CHECK( + b_q_weight.size(1) % marlin_dense::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - int actual_size_n = - (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; + int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * + marlin_dense::pack_factor_4bit; TORCH_CHECK( size_n == actual_size_n, "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); @@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, "Unexpected groupsize = " + str(groupsize)); // Verify workspace size - TORCH_CHECK( - size_n % marlin::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(marlin::min_thread_n)); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; + TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(marlin_dense::min_thread_n)); + int min_workspace_size = + (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = " + str(workspace.numel()) + " is below min_workspace_size = " + str(min_workspace_size)); int dev = a.get_device(); - marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, - sms, marlin::max_par); + marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, marlin_dense::max_par); return c; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d5136e45e781e..0df9bdb75018f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -141,6 +141,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gptq_marlin_repack", &gptq_marlin_repack); ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); + // awq_marlin repack from AWQ. + ops.def("awq_marlin_repack", &awq_marlin_repack); + ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); + // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 3bd6680cf8134..42087fdcce959 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -12,16 +12,18 @@ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS, - marlin_permute_scales) + MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS, + marlin_make_empty_g_idx, marlin_permute_scales) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( pack_fp8_to_int32) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace, get_weight_perm, marlin_quantize, marlin_weights) + MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize, + marlin_weights) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, quantize_weights, sort_weights) + awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp, + sort_weights) ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] @@ -57,12 +59,12 @@ def rand_data(shape, dtype=torch.float16): reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) -@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, - mnk_factors): +def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, + mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor @@ -120,12 +122,60 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) -@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, + mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Create input + b_weight = rand_data((size_k, size_n)) + + # Quantize + w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits, + group_size) + + # Pack to AWQ format + q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) + + # Pack to Marlin format + weight_perm = get_weight_perm(num_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + + # Run Marlin repack GPU kernel + marlin_q_w_2 = ops.awq_marlin_repack( + q_w_awq, + size_k, + size_n, + num_bits, + ) + torch.cuda.synchronize() + + assert torch.allclose(marlin_q_w_1, marlin_q_w_2) + + +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) +@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) -def test_marlin_gemm( +def test_gptq_marlin_gemm( k_chunk, n_chunk, num_bits, @@ -155,6 +205,8 @@ def test_marlin_gemm( w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( b_weight, num_bits, group_size, act_order) + marlin_zp = marlin_make_empty_g_idx(marlin_s.device) + workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL) @@ -162,6 +214,7 @@ def test_marlin_gemm( a_input, marlin_q_w, marlin_s, + marlin_zp, g_idx, sort_indices, workspace.scratch, @@ -170,6 +223,7 @@ def test_marlin_gemm( b_weight.shape[1], a_input.shape[1], is_k_full, + has_zp=False, ) output_ref = torch.matmul(a_input, w_ref) @@ -188,7 +242,8 @@ def test_marlin_gemm( @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors): +def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, + mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor @@ -301,3 +356,65 @@ def test_fp8_marlin_gemm( print("max_diff = {}".format(max_diff)) assert max_diff < 0.04 + + +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) +@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_awq_marlin_gemm( + k_chunk, + n_chunk, + num_bits, + group_size, + mnk_factors, +): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b_weight, num_bits, group_size) + + g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) + sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) + is_k_full = True + has_zp = True + + workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + output = ops.gptq_marlin_gemm( + a_input, + marlin_q_w, + marlin_s, + marlin_zp, + g_idx, + sort_indices, + workspace.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full, + has_zp, + ) + output_ref = torch.matmul(a_input, w_ref) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index b63a8d01d6621..d18233fe1aeae 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -44,9 +44,9 @@ class ModelPair: ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"), # AUTOAWQ - ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq_marlin"), ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"), - ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "awq_marlin"), ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"), ] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 80ca357e8b293..e5151c070f2f7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -276,14 +276,22 @@ def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits) +# gptq_marlin +def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, size_k: int, - is_k_full: bool) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, num_bits, size_m, size_n, - size_k, is_k_full) + b_scales: torch.Tensor, b_zeros: torch.Tensor, + g_idx: torch.Tensor, perm: torch.Tensor, + workspace: torch.Tensor, num_bits: int, size_m: int, + size_n: int, size_k: int, is_k_full: bool, + has_zp: bool) -> torch.Tensor: + return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, + g_idx, perm, workspace, num_bits, + size_m, size_n, size_k, is_k_full, + has_zp) # fp8 marlin diff --git a/vllm/config.py b/vllm/config.py index 46528a548de1e..9d60f07579217 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -251,7 +251,7 @@ def _verify_quantization(self) -> None: f"supported in ROCm.") if (self.quantization not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", - "fbgemm_fp8", "compressed_tensors")): + "awq_marlin", "fbgemm_fp8", "compressed_tensors")): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index c1bb45224fcc1..bd574512e3431 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -2,6 +2,7 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.bitsandbytes import ( @@ -31,6 +32,7 @@ "marlin": MarlinConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, + "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py new file mode 100644 index 0000000000000..092f87b623e7f --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -0,0 +1,268 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_awq_marlin_linear, awq_to_marlin_zero_points, + check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, + marlin_permute_scales, replace_tensor, verify_awq_marlin_supported, + verify_marlin_supports_shape) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead + +logger = init_logger(__name__) + + +class AWQMarlinConfig(QuantizationConfig): + """Config class for AWQ Marlin""" + + def __init__(self, weight_bits: int, group_size: int, has_zp: bool, + lm_head_quantized: bool) -> None: + self.weight_bits = weight_bits + self.pack_factor = 32 // self.weight_bits # packed into int32 + self.group_size = group_size + self.has_zp = has_zp + self.lm_head_quantized = lm_head_quantized + + verify_awq_marlin_supported(num_bits=self.weight_bits, + group_size=self.group_size, + has_zp=self.has_zp) + + def __repr__(self) -> str: + return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"has_zp={self.has_zp}, " + f"lm_head_quantized={self.lm_head_quantized})") + + @classmethod + def get_name(cls) -> str: + return "awq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + has_zp = cls.get_from_keys(config, ["zero_point"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, has_zp, lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) + is_valid_user_quant = (user_quant is None or user_quant == "marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "awq": + logger.info("Detected that the model can run with awq_marlin" + ", however you specified quantization=awq explicitly," + " so forcing awq. Use quantization=awq_marlin for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AWQMarlinLinearMethod"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + return AWQMarlinLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits", None) + group_size = quant_config.get("group_size", None) + has_zp = quant_config.get("zero_point", None) + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or has_zp is None): + return False + + return check_awq_marlin_supported( + num_bits=num_bits, + group_size=group_size, + has_zp=has_zp, + min_capability=cls.get_min_capability()) + + +class AWQMarlinLinearMethod(LinearMethodBase): + """Linear method for AWQ Marlin. + + Args: + quant_config: The AWQ Marlin quantization config. + """ + + def __init__(self, quant_config: AWQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size) + + qweight = Parameter( + torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + + num_groups = input_size_per_partition // group_size + + qzeros = Parameter( + torch.empty( + num_groups, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qzeros, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + + scales = Parameter( + torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": 0, + "output_dim": 1, + }) + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.num_groups = num_groups + + # TODO: Update this docs + # Checkpoints are serialized in AutoAWQ format, which is different from the + # marlin format. This function is called after the weights are loaded. + # Here, we handle the repacking + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.qweight.device + + # Allocate marlin workspace + layer.workspace = marlin_make_workspace( + layer.output_size_per_partition, device) + + # Repack weights from AWQ format to marlin format. + marlin_qweight = ops.awq_marlin_repack( + layer.qweight, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "qweight", marlin_qweight) + + # Permute scales from AWQ format to marlin format. + marlin_scales = marlin_permute_scales( + layer.scales, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + group_size=self.quant_config.group_size) + replace_tensor(layer, "scales", marlin_scales) + + # Permute zero-points from AWQ format to marlin format. + marlin_zp = awq_to_marlin_zero_points( + layer.qzeros, + size_k=layer.num_groups, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "qzeros", marlin_zp) + + # Not-used + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_awq_marlin_linear( + input=x, + weight=layer.qweight, + weight_scale=layer.scales, + weight_zp=layer.qzeros, + g_idx=layer.g_idx, + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=layer.workspace, + num_bits=self.quant_config.weight_bits, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 3f3febcad4f85..e4cf0c0b5d95b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -7,8 +7,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_marlin_supported, + apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, + marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.utils import set_weight_attrs @@ -38,9 +38,9 @@ def __init__(self, self.group_size = group_size # Verify supported on platform. - verify_marlin_supported(num_bits=self.num_bits, - group_size=self.group_size, - is_sym=True) + verify_gptq_marlin_supported(num_bits=self.num_bits, + group_size=self.group_size, + is_sym=True) def get_min_capability(self) -> int: # ampere and up @@ -135,6 +135,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.g_idx = marlin_make_empty_g_idx(device) layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + # No zero-point + layer.weight_zp = marlin_make_empty_g_idx(device) + # Repack weights from compressed-tensors format to marlin format. marlin_qweight = ops.gptq_marlin_repack( layer.weight_packed.t().contiguous(), @@ -155,10 +158,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - return apply_marlin_linear( + return apply_gptq_marlin_linear( input=x, weight=layer.weight_packed, weight_scale=layer.weight_scale, + weight_zp=layer.weight_zp, g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index bb9644dbc9947..5b4d614ae2e74 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,10 +10,10 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_marlin_linear, check_marlin_supported, marlin_is_k_full, + apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + verify_gptq_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead logger = init_logger(__name__) @@ -37,9 +37,9 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, self.lm_head_quantized = lm_head_quantized # Verify supported on platform. - verify_marlin_supported(num_bits=self.weight_bits, - group_size=self.group_size, - is_sym=self.is_sym) + verify_gptq_marlin_supported(num_bits=self.weight_bits, + group_size=self.group_size, + is_sym=self.is_sym) def __repr__(self) -> str: return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " @@ -77,7 +77,7 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - can_convert = cls.is_marlin_compatible(hf_quant_cfg) + can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) is_valid_user_quant = (user_quant is None or user_quant == "marlin") @@ -105,22 +105,27 @@ def get_scaled_act_names(self) -> List[str]: return [] @classmethod - def is_marlin_compatible(cls, quant_config: Dict[str, Any]): + def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits", None) group_size = quant_config.get("group_size", None) sym = quant_config.get("sym", None) desc_act = quant_config.get("desc_act", None) + if quant_method != "gptq": + return False + # If we cannot find the info needed in the config, cannot convert. if (num_bits is None or group_size is None or sym is None or desc_act is None): return False - return check_marlin_supported(num_bits=num_bits, - group_size=group_size, - is_sym=sym, - min_capability=cls.get_min_capability()) + return check_gptq_marlin_supported( + num_bits=num_bits, + group_size=group_size, + is_sym=sym, + min_capability=cls.get_min_capability()) class GPTQMarlinLinearMethod(LinearMethodBase): @@ -278,6 +283,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.g_idx = marlin_make_empty_g_idx(device) layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + # No zero-point + layer.zp = marlin_make_empty_g_idx(device) + # Repack weights from autogptq format to marlin format. marlin_qweight = ops.gptq_marlin_repack( layer.qweight, @@ -302,10 +310,11 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return apply_marlin_linear( + return apply_gptq_marlin_linear( input=x, weight=layer.qweight, weight_scale=layer.scales, + weight_zp=layer.zp, g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 764f0a6f3b71c..25a7cd7bde653 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,54 +1,92 @@ from typing import List, Optional, Tuple +import numpy import torch from vllm import _custom_ops as ops from vllm.platforms import current_platform +from .quant_utils import pack_cols, unpack_cols + GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 -GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] -GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] -GPTQ_MARLIN_SUPPORTED_SYM = [True] -GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1] - - -def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, - min_capability: int) -> bool: - - # If the capability of the device is too low, cannot convert. - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor - if device_capability < min_capability: - return False - - return (device_capability >= min_capability - and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS - and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES - and is_sym in GPTQ_MARLIN_SUPPORTED_SYM) - - -def verify_marlin_supported(num_bits: int, group_size: Optional[int], - is_sym: bool) -> None: - - if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: - raise ValueError( - f"Marlin does not support weight_bits = {num_bits}. " - f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " - "are supported.") - if (group_size is None - or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES): - raise ValueError( - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.") - if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: - raise ValueError( - f"Marlin does not support is_sym = is_sym. " - f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") +MARLIN_SUPPORTED_NUM_BITS = [4, 8] +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, + min_capability: Optional[int], + has_zp: bool) -> Tuple[bool, Optional[str]]: + if min_capability is not None: + major, minor = current_platform.get_device_capability() + device_capability = major * 10 + minor + if device_capability < min_capability: + return (False, "Marlin does not support device_capability = {}" + ", the min_capability required is {}".format( + device_capability, min_capability)) + + if num_bits not in MARLIN_SUPPORTED_NUM_BITS: + return (False, "Marlin does not support weight_bits = {}. " + "Only weight_bits = {} are supported.".format( + num_bits, MARLIN_SUPPORTED_NUM_BITS)) + + if group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return (False, "Marlin does not support group_size = {}. Only " + "group_sizes = {} are supported.".format( + group_size, MARLIN_SUPPORTED_GROUP_SIZES)) + + if not has_zp and not is_sym: + return (False, + "Marlin without zero_points must have symmetric quantization") + + return True, None + + +def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool, + min_capability: int) -> bool: + cond, _ = _check_marlin_supported(num_bits, + group_size, + is_sym, + min_capability, + has_zp=False) + return cond + + +def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool, + min_capability: int) -> bool: + cond, _ = _check_marlin_supported(num_bits, + group_size, + False, + min_capability, + has_zp=has_zp) + return cond + + +def verify_gptq_marlin_supported(num_bits: int, group_size: int, + is_sym: bool) -> None: + cond, err_msg = _check_marlin_supported(num_bits, + group_size, + is_sym, + min_capability=None, + has_zp=False) + if not cond: + assert err_msg is not None + raise ValueError("GPTQ" + err_msg) + + +def verify_awq_marlin_supported(num_bits: int, group_size: int, + has_zp: bool) -> None: + cond, err_msg = _check_marlin_supported(num_bits, + group_size, + False, + min_capability=None, + has_zp=has_zp) + if not cond: + assert err_msg is not None + raise ValueError("AWQ" + err_msg) def verify_marlin_supports_shape(output_size_per_partition: int, @@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + + +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + # Newly generated tensors need to replace existing tensors that are # already registered as parameters by vLLM (and won't be freed) def replace_tensor(layer: torch.nn.Module, name: str, @@ -149,23 +232,61 @@ def replace_tensor(layer: torch.nn.Module, name: str, del new_t -def apply_marlin_linear(input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + num_bits, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition, ) output = ops.gptq_marlin_gemm(reshaped_x, weight, weight_scale, + weight_zp, g_idx, g_idx_sort_indices, workspace, @@ -173,7 +294,8 @@ def apply_marlin_linear(input: torch.Tensor, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, - is_k_full=is_k_full) + is_k_full=True, + has_zp=True) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 1773748a0f228..541d148c761fc 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -2,11 +2,13 @@ from typing import List -import numpy +import numpy as np import torch -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales -from .quant_utils import get_pack_factor, quantize_weights, sort_weights +from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, + marlin_zero_points) +from .quant_utils import (get_pack_factor, quantize_weights, + quantize_weights_with_zp, sort_weights) class MarlinWorkspace: @@ -46,14 +48,14 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm): pack_factor = get_pack_factor(num_bits) orig_device = q_w.device - q_w = q_w.cpu().numpy().astype(numpy.uint32) + q_w = q_w.cpu().numpy().astype(np.uint32) - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=np.uint32) for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) return q_packed @@ -74,12 +76,12 @@ def get_weight_perm(num_bits: int): for j in range(4): perm_list.extend([p + 256 * j for p in perm1]) - perm = numpy.array(perm_list) + perm = np.array(perm_list) if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) + interleave = np.array([0, 2, 1, 3]) else: raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) @@ -118,3 +120,32 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, res_list[i] = res_list[i].to(w.device) return res_list + + +def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 177cb23f63cf4..7abe919f859ca 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -106,6 +106,67 @@ def reshape_w(w): ) +def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + min_q_val = 0 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max = torch.max(w, 0, keepdim=True)[0] + min = torch.min(w, 0, keepdim=True)[0] + s = (max - min).clamp(min=1e-5) / max_q_val + + # Compute zero-point for each group + zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int() + + # Quantize + q_w = torch.round(w / s).int() + zp + q_w = torch.clamp(q_w, min_q_val, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - zp).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + zp = zp.reshape((-1, size_n)).contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + zp.to(device=orig_device), + ) + + def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device @@ -122,7 +183,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): ) -def gptq_pack( +def pack_rows( q_w: torch.Tensor, num_bits: int, size_k: int, @@ -144,3 +205,90 @@ def gptq_pack( q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, size_n // pack_factor + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n)