From 9c5b95bce4203dd39197681a11485297b8f6170c Mon Sep 17 00:00:00 2001 From: luka Date: Mon, 22 Jul 2024 11:07:48 -0400 Subject: [PATCH 01/27] DRAFT dynamic azp quant kernel - failing non-deterministically --- csrc/ops.h | 2 +- .../compressed_tensors/int8_quant_kernels.cu | 76 +++++++++++++++++-- csrc/reduction_utils.cuh | 23 ++++-- csrc/torch_bindings.cpp | 2 +- tests/kernels/test_int8_quant.py | 43 +++++++++++ 5 files changed, 134 insertions(+), 12 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 023455f8a1530..c026b738b9c69 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -149,7 +149,7 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scales); + torch::Tensor& scales, c10::optional const& azp); void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index aa9511daa2772..d4689a088381a 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -70,6 +70,64 @@ __global__ void dynamic_scaled_int8_quant_kernel( } } +template +__global__ void dynamic_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, azp_type* azp, const int hidden_size) { + int const token_idx = blockIdx.x; + + // Scan for the min and max value for this token + float max_val = 0.0f; + float min_val = std::numeric_limits::max(); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[token_idx * hidden_size + i]); + max_val = fmaxf(max_val, val); + min_val = fminf(min_val, val); + } + + // Reduce the max and min values across the block + max_val = blockReduceMax(max_val); + // if (threadIdx.x == 0 and blockIdx.x == DEBUG_TOKEN) printf("MIN:"); + min_val = blockReduceMin(min_val); + + __shared__ scale_type scale_sh; + __shared__ azp_type azp_sh; + + // Compute the scale and zero point and store them, only on the first thread + if (threadIdx.x == 0) { + float scale_val = (max_val - min_val) / 255.0f; + auto const azp_float = roundf(min_val / scale_val + 128.0f); + auto const azp_val = static_cast(azp_float); + + // Azp was rounded, which may cause the range to be slightly off. + // Expand the range to make sure all values are representable. + auto const min_nozp = static_cast(azp_val - 128); + auto const max_nozp = static_cast(azp_val + 127); + auto no_div_0 = [&](float num, float div) { + return div == 0.0f ? 
scale_val : num / div; + }; + + scale_val = fmaxf(no_div_0(max_val, max_nozp), no_div_0(min_val, min_nozp)); + + // Store the scale and azp + scale[token_idx] = scale_sh = scale_val; + azp[token_idx] = azp_sh = azp_val; + } + + // Wait for the scale and azp to be computed + __syncthreads(); + + float const scale_val = scale_sh; + azp_type const azp_val = azp_sh; + + // Quantize the values + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[token_idx * hidden_size + i]); + auto quant_val = static_cast(roundf(val / scale_val) - azp_val); + out[token_idx * hidden_size + i] = quant_val; + } +} + } // namespace vllm void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] @@ -96,7 +154,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales) { + torch::Tensor& scales, c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); @@ -107,9 +165,17 @@ void dynamic_scaled_int8_quant( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scales.data_ptr(), hidden_size); + if (!azp) { + vllm::dynamic_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), hidden_size); + } else { + vllm::dynamic_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index 08063356012b8..b107d33a765ea 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -29,6 +29,11 @@ __inline__ __device__ T _max(T a, T b) { return max(a, b); } +template +__inline__ __device__ T _min(T a, T b) { + return min(a, b); +} + template __inline__ __device__ T _sum(T a, T b) { return a + b; @@ -51,14 +56,17 @@ __inline__ __device__ T warpReduce(T val, ReduceFnType fn) { "numLanes is not a positive power of 2!"); static_assert(numLanes <= WARP_SIZE); #pragma unroll - for (int mask = numLanes >> 1; mask > 0; mask >>= 1) - val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask)); + for (int mask = numLanes >> 1; mask > 0; mask >>= 1) { + auto const other_idx = threadIdx.x ^ mask; + auto const other_val = VLLM_SHFL_XOR_SYNC(val, mask); + val = other_idx < blockDim.x ? fn(val, other_val) : val; + } return val; } template -__inline__ __device__ T blockReduce(T val, ReduceFnType fn) { +__inline__ __device__ T blockReduce(T val, ReduceFnType fn, T init = T{}) { static_assert(maxBlockSize <= 1024); if constexpr (maxBlockSize > WARP_SIZE) { val = warpReduce(val, fn); @@ -72,8 +80,7 @@ __inline__ __device__ T blockReduce(T val, ReduceFnType fn) { __syncthreads(); - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] - : (T)(0.0f); + val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? 
shared[lane] : init; val = warpReduce(val, fn); } else { // A single warpReduce is equal to blockReduce @@ -87,6 +94,12 @@ __inline__ __device__ T blockReduceMax(T val) { return blockReduce(val, detail::_max); } +template +__inline__ __device__ T blockReduceMin(T val) { + auto const max_val = std::numeric_limits::max(); + return blockReduce(val, detail::_min, max_val); +} + template __inline__ __device__ T blockReduceSum(T val) { return blockReduce(val, detail::_sum); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e26c2e28f2ecd..bdad3e65b387b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -239,7 +239,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, Tensor!? azp) -> " "()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 0b7ed26a39e1e..62aebd71cd5bb 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -33,6 +33,49 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, assert torch.allclose(ops_out, ref_out, atol=1) # big atol to account for rounding errors +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300 + + x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) + x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) + + # this is why we can't have nice things + proper_round = lambda x: torch.floor(x + 0.5) + + # calculate scale and azp, and adjust the range + scales = (x_token_max - x_token_min) / torch.tensor(255.0) + azps = proper_round(x_token_min / scales + 128.0).to(torch.int32) + min_nozp, max_nozp = azps - 128.0, azps + 127.0 + + # for all elements that are 0, make result equal to scale + no_div_by_zero = lambda x, div: torch.where(div == 0, scales, x / div) + scales = torch.max(no_div_by_zero(x_token_max, max_nozp), no_div_by_zero(x_token_min, min_nozp)) + + torch_out = (x / scales - azps).round().to(torch.int8) + assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max + + ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") + scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") + azp_out = torch.empty_like(azps, dtype=torch.int32, device="cuda") + torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out) + + if (not torch.allclose(scales_out, scales)): + print(torch.argmax(torch.abs(scales_out - scales))) + assert torch.allclose(scales_out, scales) + assert torch.allclose(azp_out, azps) + assert torch.allclose(torch_out, ops_out, + atol=1) # big atol to account for rounding errors + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) From c0916db78ce0ee80e9d1cb9c8b230195c3ecb2bf Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 23 Jul 2024 17:11:43 
-0400 Subject: [PATCH 02/27] Fixed blockReduce bug! Also using round-to-even for azp --- .../compressed_tensors/int8_quant_kernels.cu | 16 +++++++++----- csrc/reduction_utils.cuh | 22 +++++++++++++++---- tests/kernels/test_int8_quant.py | 17 ++++++-------- vllm/_custom_ops.py | 17 +++++++++----- 4 files changed, 47 insertions(+), 25 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index d4689a088381a..4233721864d2c 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -77,17 +77,17 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( int const token_idx = blockIdx.x; // Scan for the min and max value for this token - float max_val = 0.0f; + float max_val = std::numeric_limits::min(); float min_val = std::numeric_limits::max(); for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { auto val = static_cast(input[token_idx * hidden_size + i]); - max_val = fmaxf(max_val, val); - min_val = fminf(min_val, val); + max_val = std::max(max_val, val); + min_val = std::min(min_val, val); } // Reduce the max and min values across the block max_val = blockReduceMax(max_val); - // if (threadIdx.x == 0 and blockIdx.x == DEBUG_TOKEN) printf("MIN:"); + __syncthreads(); // Make sure min doesn't mess with max shared memory min_val = blockReduceMin(min_val); __shared__ scale_type scale_sh; @@ -96,7 +96,8 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( // Compute the scale and zero point and store them, only on the first thread if (threadIdx.x == 0) { float scale_val = (max_val - min_val) / 255.0f; - auto const azp_float = roundf(min_val / scale_val + 128.0f); + // Use rounding to even (same as torch.round) + auto const azp_float = std::nearbyint(min_val / scale_val + 128.0f); auto const azp_val = static_cast(azp_float); // Azp was rounded, which may cause the range to be slightly off. @@ -107,9 +108,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( return div == 0.0f ? scale_val : num / div; }; + // TODO don't adjust scale? scale_val = fmaxf(no_div_0(max_val, max_nozp), no_div_0(min_val, min_nozp)); - // Store the scale and azp + // Store the scale and azp into shared and global scale[token_idx] = scale_sh = scale_val; azp[token_idx] = azp_sh = azp_val; } @@ -157,6 +159,8 @@ void dynamic_scaled_int8_quant( torch::Tensor& scales, c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(!azp || azp->is_contiguous()); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index b107d33a765ea..2bb0ec43559fb 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -50,6 +50,11 @@ static constexpr int _nextPow2(unsigned int num) { return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } +template +__device__ __host__ static constexpr T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + template __inline__ __device__ T warpReduce(T val, ReduceFnType fn) { static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, @@ -65,6 +70,8 @@ __inline__ __device__ T warpReduce(T val, ReduceFnType fn) { return val; } +// Make sure you call __syncthreads() between different blockReduce calls, as +// they are allowed to use the same shared memory. 
template __inline__ __device__ T blockReduce(T val, ReduceFnType fn, T init = T{}) { static_assert(maxBlockSize <= 1024); @@ -72,7 +79,10 @@ __inline__ __device__ T blockReduce(T val, ReduceFnType fn, T init = T{}) { val = warpReduce(val, fn); // Calculates max number of lanes that need to participate in the last // warpReduce - constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; + constexpr int maxActiveLanes = + ceil_div(maxBlockSize, WARP_SIZE); + // shared memory can be reused between function calls, make static + // explicitly. static __shared__ T shared[maxActiveLanes]; int lane = threadIdx.x % WARP_SIZE; int wid = threadIdx.x / WARP_SIZE; @@ -80,8 +90,11 @@ __inline__ __device__ T blockReduce(T val, ReduceFnType fn, T init = T{}) { __syncthreads(); - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : init; - val = warpReduce(val, fn); + auto const num_sh_lanes = ceil_div(blockDim.x, WARP_SIZE); + val = threadIdx.x < num_sh_lanes ? shared[lane] : init; + if (wid == 0) { + val = warpReduce(val, fn); + } } else { // A single warpReduce is equal to blockReduce val = warpReduce(val, fn); @@ -91,7 +104,8 @@ __inline__ __device__ T blockReduce(T val, ReduceFnType fn, T init = T{}) { template __inline__ __device__ T blockReduceMax(T val) { - return blockReduce(val, detail::_max); + auto const min_val = std::numeric_limits::lowest(); + return blockReduce(val, detail::_max, min_val); } template diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 62aebd71cd5bb..418affd4d04a7 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -4,10 +4,10 @@ from tests.kernels.quant_utils import ref_dynamic_per_token_quant from vllm._custom_ops import scaled_int8_quant -DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, +DTYPES = [torch.float] +HIDDEN_SIZES = [16, 32, 64, 67, 128, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing -NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +NUM_TOKENS = [1, 7, 83, 4096, 16384] # Arbitrary values for testing SEEDS = [0] SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] @@ -49,12 +49,9 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) - # this is why we can't have nice things - proper_round = lambda x: torch.floor(x + 0.5) - # calculate scale and azp, and adjust the range scales = (x_token_max - x_token_min) / torch.tensor(255.0) - azps = proper_round(x_token_min / scales + 128.0).to(torch.int32) + azps = torch.round(x_token_min / scales + 128.0).to(torch.int32) min_nozp, max_nozp = azps - 128.0, azps + 127.0 # for all elements that are 0, make result equal to scale @@ -64,9 +61,9 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, torch_out = (x / scales - azps).round().to(torch.int8) assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max - ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") - scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") - azp_out = torch.empty_like(azps, dtype=torch.int32, device="cuda") + ops_out = torch.empty_like(x, dtype=torch.int8) + scales_out = torch.empty_like(scales, dtype=torch.float32) + azp_out = torch.empty_like(azps, dtype=torch.int32) torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out) if (not 
torch.allclose(scales_out, scales)): diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4331db8ee4e82..6c354868b7181 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -392,15 +392,20 @@ def scaled_fp8_quant( # int8 def scaled_int8_quant( input: torch.Tensor, - scale: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor]: + scale: Optional[torch.Tensor] = None, + azp : Optional[torch.Tensor] = None, + symmetric: bool = True +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ - Quantize the input tensor to int8 and return the quantized tensor and scale. + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. Args: input: The input tensor to be quantized to int8. scale: Optional scaling factor for the int8 quantization. When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales. @@ -408,15 +413,17 @@ def scaled_int8_quant( output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. + assert not symmetric, "Asymmetric quantization per-tensor not supported yet." torch.ops._C.static_scaled_int8_quant(output, input, scale) - return output, scale + return output, scale, None # dynamic-per-token quantization. input_scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) - return output, input_scales + return output, input_scales, input_azp # qqq ops From 69f9493a026d00efc00b5ee0c90981a0df0babf5 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 23 Jul 2024 17:14:33 -0400 Subject: [PATCH 03/27] Remove scale adjustment --- .../compressed_tensors/int8_quant_kernels.cu | 13 +------------ tests/kernels/test_int8_quant.py | 5 ----- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 4233721864d2c..ec6ff5bc350b0 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -95,22 +95,11 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( // Compute the scale and zero point and store them, only on the first thread if (threadIdx.x == 0) { - float scale_val = (max_val - min_val) / 255.0f; + float const scale_val = (max_val - min_val) / 255.0f; // Use rounding to even (same as torch.round) auto const azp_float = std::nearbyint(min_val / scale_val + 128.0f); auto const azp_val = static_cast(azp_float); - // Azp was rounded, which may cause the range to be slightly off. - // Expand the range to make sure all values are representable. - auto const min_nozp = static_cast(azp_val - 128); - auto const max_nozp = static_cast(azp_val + 127); - auto no_div_0 = [&](float num, float div) { - return div == 0.0f ? scale_val : num / div; - }; - - // TODO don't adjust scale? 
- scale_val = fmaxf(no_div_0(max_val, max_nozp), no_div_0(min_val, min_nozp)); - // Store the scale and azp into shared and global scale[token_idx] = scale_sh = scale_val; azp[token_idx] = azp_sh = azp_val; diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 418affd4d04a7..3b95767fd6719 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -52,11 +52,6 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, # calculate scale and azp, and adjust the range scales = (x_token_max - x_token_min) / torch.tensor(255.0) azps = torch.round(x_token_min / scales + 128.0).to(torch.int32) - min_nozp, max_nozp = azps - 128.0, azps + 127.0 - - # for all elements that are 0, make result equal to scale - no_div_by_zero = lambda x, div: torch.where(div == 0, scales, x / div) - scales = torch.max(no_div_by_zero(x_token_max, max_nozp), no_div_by_zero(x_token_min, min_nozp)) torch_out = (x / scales - azps).round().to(torch.int8) assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max From 15e4c7232013aa7fbdb0d93a45ba5441eebfb722 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 23 Jul 2024 18:07:48 -0400 Subject: [PATCH 04/27] Fixed saturation in kernel --- .../compressed_tensors/int8_quant_kernels.cu | 42 ++++++++++++++++++- tests/kernels/test_int8_quant.py | 5 +-- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index ec6ff5bc350b0..4f392b26d9237 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -24,6 +24,43 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #endif } +static inline __device__ int32_t float_to_int32_rn(float x) { +#ifdef USE_ROCM + static const float i32_min = + static_cast(std::numeric_limits::min()); + static const float i32_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i32_min, i32_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static inline __device__ int8_t int32_to_int8(int32_t x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + + // saturate + int32_t dst = std::clamp(x, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x)); + return reinterpret_cast(dst); +#endif +} + namespace vllm { template @@ -113,8 +150,9 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( // Quantize the values for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto val = static_cast(input[token_idx * hidden_size + i]); - auto quant_val = static_cast(roundf(val / scale_val) - azp_val); + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = + int32_to_int8(float_to_int32_rn(val / scale_val) - azp_val); out[token_idx * hidden_size + i] = quant_val; } } diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 3b95767fd6719..49193c864d790 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -64,9 +64,8 @@ def 
test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, if (not torch.allclose(scales_out, scales)): print(torch.argmax(torch.abs(scales_out - scales))) assert torch.allclose(scales_out, scales) - assert torch.allclose(azp_out, azps) - assert torch.allclose(torch_out, ops_out, - atol=1) # big atol to account for rounding errors + assert torch.allclose(azp_out, azps, atol=1) # azp rounding error + assert torch.allclose(torch_out, ops_out, atol=1) # azp rounding error @pytest.mark.parametrize("num_tokens", NUM_TOKENS) From 6353a8b62d0e8d29e3b13cf4cebb655d38f990d0 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 23 Jul 2024 19:01:07 -0400 Subject: [PATCH 05/27] Integer allclose comparison --- tests/kernels/test_int8_quant.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 49193c864d790..3be51ee9d19e8 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -12,6 +12,14 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] +def allclose_int(input, other, atol: int = 0, rtol: float = 1e-5): + INT_DTYPES = [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.uint16, torch.uint32, + torch.uint64] + assert input.dtype in INT_DTYPES and other.dtype in INT_DTYPES + diff = torch.abs(input.to(torch.int64) - other.to(torch.int64)) + return torch.all(diff <= atol + torch.ceil(rtol * torch.abs(other).to(torch.float32)).to(torch.int64)) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -33,6 +41,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, assert torch.allclose(ops_out, ref_out, atol=1) # big atol to account for rounding errors + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -64,8 +73,8 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, if (not torch.allclose(scales_out, scales)): print(torch.argmax(torch.abs(scales_out - scales))) assert torch.allclose(scales_out, scales) - assert torch.allclose(azp_out, azps, atol=1) # azp rounding error - assert torch.allclose(torch_out, ops_out, atol=1) # azp rounding error + assert allclose_int(azp_out, azps, atol=1) # azp rounding error + assert allclose_int(torch_out, ops_out, atol=1) # azp rounding error @pytest.mark.parametrize("num_tokens", NUM_TOKENS) From a95790a83d227f076fcb201bee337728e8d142fb Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 24 Jul 2024 12:27:54 -0400 Subject: [PATCH 06/27] utils fix --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index dbe86902853cd..5d1a289a6690a 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -199,7 +199,7 @@ def apply_int8_linear( # ops.scaled_int8_quant supports both dynamic and static quant. # * dynamic, layer.input_scale is None and x_scale computed from x. # * static, layer.input_scale is scalar and x_scale is input_scale. 
- x_q, x_scale = ops.scaled_int8_quant(input, input_scale) + x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale) return ops.cutlass_scaled_mm(x_q, weight, From 84db5cdd02adfd6d17f2cabae473f6e124ca1b71 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 24 Jul 2024 18:01:44 -0400 Subject: [PATCH 07/27] Fixed torch ref conversion --- tests/kernels/test_int8_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 3be51ee9d19e8..2d9884a3e3b2e 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -62,7 +62,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, scales = (x_token_max - x_token_min) / torch.tensor(255.0) azps = torch.round(x_token_min / scales + 128.0).to(torch.int32) - torch_out = (x / scales - azps).round().to(torch.int8) + torch_out = (x / scales - azps).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max ops_out = torch.empty_like(x, dtype=torch.int8) From d11340c6a70329f9a14e811f27e67309ffe2c56c Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 24 Jul 2024 18:21:16 -0400 Subject: [PATCH 08/27] Format --- csrc/ops.h | 3 ++- csrc/torch_bindings.cpp | 3 ++- tests/kernels/test_int8_quant.py | 19 +++++++++++++------ vllm/_custom_ops.py | 11 ++++++----- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index c026b738b9c69..ffc112302128b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -149,7 +149,8 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scales, c10::optional const& azp); + torch::Tensor& scales, + c10::optional const& azp); void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bdad3e65b387b..db7ddf5ec3f4d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -239,7 +239,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, Tensor!? azp) -> " + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? 
azp) -> " "()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 2d9884a3e3b2e..f142ab1e0ec65 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -13,11 +13,15 @@ def allclose_int(input, other, atol: int = 0, rtol: float = 1e-5): - INT_DTYPES = [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.uint16, torch.uint32, - torch.uint64] + INT_DTYPES = [ + torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, + torch.uint16, torch.uint32, torch.uint64 + ] assert input.dtype in INT_DTYPES and other.dtype in INT_DTYPES diff = torch.abs(input.to(torch.int64) - other.to(torch.int64)) - return torch.all(diff <= atol + torch.ceil(rtol * torch.abs(other).to(torch.float32)).to(torch.int64)) + return torch.all( + diff <= atol + + torch.ceil(rtol * torch.abs(other).to(torch.float32)).to(torch.int64)) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -53,7 +57,8 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, torch.cuda.manual_seed(seed) int8_traits = torch.iinfo(torch.int8) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300 + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") * 1000 - 300 x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) @@ -62,8 +67,10 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, scales = (x_token_max - x_token_min) / torch.tensor(255.0) azps = torch.round(x_token_min / scales + 128.0).to(torch.int32) - torch_out = (x / scales - azps).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) - assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max + torch_out = (x / scales - azps).round().clamp( + int8_traits.min, int8_traits.max).to(torch.int8) + assert torch_out.min() >= int8_traits.min and torch_out.max( + ) <= int8_traits.max ops_out = torch.empty_like(x, dtype=torch.int8) scales_out = torch.empty_like(scales, dtype=torch.float32) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6c354868b7181..42788a24df0c9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -391,10 +391,10 @@ def scaled_fp8_quant( # int8 def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp : Optional[torch.Tensor] = None, - symmetric: bool = True + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. 
@@ -421,7 +421,8 @@ def scaled_int8_quant( input_scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) return output, input_scales, input_azp From fe914419edee93fe6cf67dfde925dc7b5e429b88 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 24 Jul 2024 19:15:47 -0400 Subject: [PATCH 09/27] Inverted azp sign to be consistent with RFC, unit tests, and compressed-tensors --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 4 ++-- tests/kernels/test_int8_quant.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 4f392b26d9237..3dd6ecf52f403 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -134,7 +134,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( if (threadIdx.x == 0) { float const scale_val = (max_val - min_val) / 255.0f; // Use rounding to even (same as torch.round) - auto const azp_float = std::nearbyint(min_val / scale_val + 128.0f); + auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); auto const azp_val = static_cast(azp_float); // Store the scale and azp into shared and global @@ -152,7 +152,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { auto const val = static_cast(input[token_idx * hidden_size + i]); auto const quant_val = - int32_to_int8(float_to_int32_rn(val / scale_val) - azp_val); + int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); out[token_idx * hidden_size + i] = quant_val; } } diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index f142ab1e0ec65..7d559ec965b00 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -65,9 +65,9 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, # calculate scale and azp, and adjust the range scales = (x_token_max - x_token_min) / torch.tensor(255.0) - azps = torch.round(x_token_min / scales + 128.0).to(torch.int32) + azps = torch.round(-128.0 - x_token_min / scales).to(torch.int32) - torch_out = (x / scales - azps).round().clamp( + torch_out = (x / scales + azps).round().clamp( int8_traits.min, int8_traits.max).to(torch.int8) assert torch_out.min() >= int8_traits.min and torch_out.max( ) <= int8_traits.max From f769c9974d138bdd93cd10fcf5f89546b000ab63 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 25 Jul 2024 11:17:31 -0400 Subject: [PATCH 10/27] Fix order of rounding in test (doesn't matter for small numbers, just for consistency) --- tests/kernels/test_int8_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 7d559ec965b00..afd5785e5274b 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -67,7 +67,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, scales = (x_token_max - x_token_min) / torch.tensor(255.0) azps = torch.round(-128.0 - x_token_min / scales).to(torch.int32) - torch_out = (x / scales + azps).round().clamp( + torch_out = ((x / scales).round() + 
azps).clamp( int8_traits.min, int8_traits.max).to(torch.int8) assert torch_out.min() >= int8_traits.min and torch_out.max( ) <= int8_traits.max From 9e49812335a575e2959e428886ad81c60548cf1c Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 25 Jul 2024 11:32:34 -0400 Subject: [PATCH 11/27] Fewer tests --- tests/kernels/test_int8_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index afd5785e5274b..6da8e9a3c87d2 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -5,7 +5,7 @@ from vllm._custom_ops import scaled_int8_quant DTYPES = [torch.float] -HIDDEN_SIZES = [16, 32, 64, 67, 128, 768, 2048, 5120, 5137, 8192, +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing NUM_TOKENS = [1, 7, 83, 4096, 16384] # Arbitrary values for testing SEEDS = [0] From c1ad358b890b55346b7955015b046cd3a9953501 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 25 Jul 2024 11:34:39 -0400 Subject: [PATCH 12/27] Static per-tensor kernels added --- csrc/ops.h | 3 +- .../compressed_tensors/int8_quant_kernels.cu | 37 ++++++++++++++++--- csrc/torch_bindings.cpp | 7 ++-- tests/kernels/test_int8_quant.py | 29 +++++++++++++++ vllm/_custom_ops.py | 4 +- 5 files changed, 68 insertions(+), 12 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index ffc112302128b..b8076d88c9eeb 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -146,7 +146,8 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& scale); + torch::Tensor const& scale, + c10::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 3dd6ecf52f403..4e8e76bd6d264 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -77,6 +77,23 @@ __global__ void static_scaled_int8_quant_kernel( } } +template +__global__ void static_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, azp_type const* azp_ptr, + const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + azp_type const azp = *azp_ptr; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); + out[token_idx * hidden_size + i] = quant_val; + } +} + template __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, @@ -161,10 +178,12 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale) { + torch::Tensor const& scale, + c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp || azp->numel() == 1); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -173,10 +192,18 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const 
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { - vllm::static_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scale.data_ptr(), hidden_size); + if (!azp) { + vllm::static_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), hidden_size); + } else { + vllm::static_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index db7ddf5ec3f4d..4cd51509d192d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -233,15 +233,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " - "()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> " - "()"); + "Tensor!? azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); } diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 6da8e9a3c87d2..da91769428b0c 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -106,3 +106,32 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE[2:]) # Fewer scales to reduce test time +@pytest.mark.parametrize("azp", [-255, 54]) +@torch.inference_mode() +def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int, + scale: float, azp: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") * 1000 - 300 + + out1 = ((x / scale).round() + azp).clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") + azp_argument = torch.tensor([azp], dtype=torch.int32, device="cuda") + + torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument, + azp_argument) + assert torch.allclose(out1, out2, atol=1) # atol for rounding + diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 42788a24df0c9..47d3062b8aa8a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -413,8 +413,8 @@ def scaled_int8_quant( output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - assert not symmetric, "Asymmetric quantization per-tensor not supported yet." - torch.ops._C.static_scaled_int8_quant(output, input, scale) + assert symmetric == azp is None, "azp must be only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) return output, scale, None # dynamic-per-token quantization. 
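For readers following the series: by this point both the dynamic and the static asymmetric (zero-point) paths exist, and the per-token math they implement is small enough to state directly: scale = (max - min) / 255, azp = round(-128 - min / scale), and q = clamp(round(x / scale) + azp, -128, 127), using the sign convention settled in PATCH 09/27. The sketch below is an illustrative Python reproduction of that reference computation (mirroring the math in tests/kernels/test_int8_quant.py, not the CUDA kernels themselves); the helper name is made up for clarity and is not part of any patch.

    # Minimal reference sketch (illustrative, not part of the patch series):
    # per-token asymmetric int8 quantization with a dynamic scale and zero point.
    import torch

    def ref_dynamic_int8_azp_quant(x: torch.Tensor):
        int8_info = torch.iinfo(torch.int8)
        x_f32 = x.to(torch.float32)
        x_max = x_f32.max(dim=-1, keepdim=True).values
        x_min = x_f32.min(dim=-1, keepdim=True).values
        # Spread the per-token [min, max] range across the 256 int8 levels.
        # (Assumes a non-degenerate range per token, i.e. scale > 0.)
        scales = (x_max - x_min) / 255.0
        # Zero point chosen so that x_min maps to -128; it is added during
        # quantization and subtracted during dequantization.
        azps = torch.round(-128.0 - x_min / scales).to(torch.int32)
        # Round, shift by the zero point, then saturate to the int8 range.
        q = (torch.round(x_f32 / scales) + azps).clamp(int8_info.min, int8_info.max)
        return q.to(torch.int8), scales, azps

Dequantization then recovers x approximately as scales * (q - azps), which is why the azp sign flip in PATCH 09/27 only changes whether the zero point is added or subtracted on each side.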
From 25d0f5870d9ff6c913f1ccc9d484d748e7799fac Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 26 Jul 2024 20:09:50 -0400 Subject: [PATCH 13/27] Reduced test size, fixed custom_ops wrapper --- tests/kernels/test_int8_quant.py | 2 +- vllm/_custom_ops.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index da91769428b0c..ec5f09a12b5ba 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -112,7 +112,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("scale", SCALE[2:]) # Fewer scales to reduce test time +@pytest.mark.parametrize("scale", SCALE[2:]) # Reduce test time @pytest.mark.parametrize("azp", [-255, 54]) @torch.inference_mode() def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 47d3062b8aa8a..5cee1a01f62e5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -423,7 +423,8 @@ def scaled_int8_quant( dtype=torch.float32) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) return output, input_scales, input_azp From e05068c4774866d2516607f94e13031a2fc7b0ea Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 6 Aug 2024 16:54:05 -0400 Subject: [PATCH 14/27] format --- tests/kernels/test_int8_quant.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index ec5f09a12b5ba..869f1e095573f 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -134,4 +134,3 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument, azp_argument) assert torch.allclose(out1, out2, atol=1) # atol for rounding - From d02c5684c93a43831648565c8888f14affe3315c Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 27 Aug 2024 15:04:49 -0400 Subject: [PATCH 15/27] Merge fixes --- .../compressed_tensors/int8_quant_kernels.cu | 6 ++++-- tests/kernels/test_int8_quant.py | 13 +++++++------ vllm/_custom_ops.py | 4 +++- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index bbbd01c3dae5e..cb7f103215332 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -150,9 +150,11 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( } // Reduce the max and min values across the block - max_val = blockReduceMax(max_val); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); __syncthreads(); // Make sure min doesn't mess with max shared memory - min_val = blockReduceMin(min_val); + min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); __shared__ scale_type scale_sh; __shared__ azp_type azp_sh; diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 46794236ca7c9..b3f566fb389d8 100644 --- 
a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -39,7 +39,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, # reference ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) # kernel - ops_out, ops_scales = scaled_int8_quant(x) + ops_out, ops_scales, _ = scaled_int8_quant(x) torch.testing.assert_close(ops_scales, ref_scales) torch.testing.assert_close( @@ -103,11 +103,10 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, out1 = (x / scale).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) - out2, _ = scaled_int8_quant(x, scale) + out2, _, _ = scaled_int8_quant(x, scale) - torch.testing.assert_close( - out1, out2, atol=1, - rtol=0.0) # big atol to account for rounding errors + # big atol to account for rounding errors + torch.testing.assert_close(out1, out2, atol=1, rtol=0.0) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -135,4 +134,6 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument, azp_argument) - torch.testing.assert_close(out1, out2, atol=1) # atol for rounding + + # big atol to account for rounding errors + torch.testing.assert_close(out1, out2, atol=1, rtol=0.0) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3983e3e0cc847..78825a95db65d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -438,7 +438,9 @@ def scaled_int8_quant( output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - assert symmetric == azp is None, "azp must be only be provided for asymmetric quantization." + assert symmetric == ( + azp is + None), "azp must only be provided for asymmetric quantization." torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) return output, scale, None From e4dc1014d3166136e2a001e039be44fedce0761c Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 29 Aug 2024 11:57:00 -0400 Subject: [PATCH 16/27] Fix for AMD build --- .../compressed_tensors/int8_quant_kernels.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index cb7f103215332..84f5fa2d4eba2 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -14,9 +14,9 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM - static const float i8_min = + static const auto i8_min = static_cast(std::numeric_limits::min()); - static const float i8_max = + static const auto i8_max = static_cast(std::numeric_limits::max()); // round float dst = std::nearbyint(x); @@ -33,9 +33,9 @@ static inline __device__ int8_t float_to_int8_rn(float x) { static inline __device__ int32_t float_to_int32_rn(float x) { #ifdef USE_ROCM - static const float i32_min = + static const auto i32_min = static_cast(std::numeric_limits::min()); - static const float i32_max = + static const auto i32_max = static_cast(std::numeric_limits::max()); // round float dst = std::nearbyint(x); @@ -52,9 +52,9 @@ static inline __device__ int32_t float_to_int32_rn(float x) { static inline __device__ int8_t int32_to_int8(int32_t x) { #ifdef USE_ROCM - static const float i8_min = + static const auto i8_min = static_cast(std::numeric_limits::min()); - static const float i8_max = + static const auto i8_max = static_cast(std::numeric_limits::max()); // saturate From 
31b3e44fd9f1c98390b3f6a479a0153be926a782 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 10 Sep 2024 10:59:48 -0400 Subject: [PATCH 17/27] PR comments: Python nits --- tests/kernels/test_int8_quant.py | 17 +++-------------- vllm/_custom_ops.py | 2 +- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index b3f566fb389d8..b5e88272e2c28 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -12,18 +12,6 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] -def allclose_int(input, other, atol: int = 0, rtol: float = 1e-5): - INT_DTYPES = [ - torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, - torch.uint16, torch.uint32, torch.uint64 - ] - assert input.dtype in INT_DTYPES and other.dtype in INT_DTYPES - diff = torch.abs(input.to(torch.int64) - other.to(torch.int64)) - return torch.all( - diff <= atol + - torch.ceil(rtol * torch.abs(other).to(torch.float32)).to(torch.int64)) - - @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -81,8 +69,9 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, if (not torch.allclose(scales_out, scales)): print(torch.argmax(torch.abs(scales_out - scales))) torch.testing.assert_close(scales_out, scales) - assert allclose_int(azp_out, azps, atol=1) # azp rounding error - assert allclose_int(torch_out, ops_out, atol=1) # azp rounding error + torch.testing.assert_close(azp_out, azps, atol=1) # azp rounding error + torch.testing.assert_close(torch_out, ops_out, + atol=1) # azp rounding error @pytest.mark.parametrize("num_tokens", NUM_TOKENS) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 78825a95db65d..c7098c8613f98 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -433,7 +433,7 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales. + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: From 5a9762e8f94d8f68c8ab86a8ea11798582f2578f Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 10 Sep 2024 11:00:11 -0400 Subject: [PATCH 18/27] PR comments: saturation code --- .../compressed_tensors/int8_quant_kernels.cu | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 84f5fa2d4eba2..d465b601f4f68 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -14,9 +14,9 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM - static const auto i8_min = + static constexpr auto i8_min = static_cast(std::numeric_limits::min()); - static const auto i8_max = + static constexpr auto i8_max = static_cast(std::numeric_limits::max()); // round float dst = std::nearbyint(x); @@ -33,14 +33,23 @@ static inline __device__ int8_t float_to_int8_rn(float x) { static inline __device__ int32_t float_to_int32_rn(float x) { #ifdef USE_ROCM - static const auto i32_min = - static_cast(std::numeric_limits::min()); - static const auto i32_max = - static_cast(std::numeric_limits::max()); + // int32_max is not exactly representable as float. 
+ // Therefore, we need to be careful and manually return int32_max on overflow. + // For symmetry, we also do the same for int32_min, even though it is exactly + // representable as float and the conversion should be exact. + static constexpr auto i32_min = std::numeric_limits::min(); + static constexpr auto i32_min_f = static_cast(i32_min); + static constexpr auto i32_max = std::numeric_limits::max(); + static constexpr auto i32_max_f = static_cast(i32_max); + // round float dst = std::nearbyint(x); - // saturate - dst = std::clamp(dst, i32_min, i32_max); + + // saturate on the higher end. + if (dst >= i32_max_f) { return i32_max; } + // saturate on the lower end. + if (dst <= i32_min_f) { return i32_min; } + return static_cast(dst); #else // CUDA path @@ -52,9 +61,9 @@ static inline __device__ int32_t float_to_int32_rn(float x) { static inline __device__ int8_t int32_to_int8(int32_t x) { #ifdef USE_ROCM - static const auto i8_min = + static constexpr auto i8_min = static_cast(std::numeric_limits::min()); - static const auto i8_max = + static constexpr auto i8_max = static_cast(std::numeric_limits::max()); // saturate From 8aed02a8eb3c8239512626354787ae172aca8a1d Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 10 Sep 2024 11:37:14 -0400 Subject: [PATCH 19/27] explicit nearest rounding mode --- .../compressed_tensors/int8_quant_kernels.cu | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index d465b601f4f68..a9d099cd7103e 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include "../../dispatch_utils.h" @@ -12,6 +13,18 @@ #include #endif +#define USE_ROCM + +// Explicitly set the rounding mode to nearest +template +auto __device__ to_nearest(T x) { + auto const mode = std::fegetround(); + std::fesetround(FE_TONEAREST); + auto const result = std::nearbyint(x); + std::fesetround(mode); + return result; +} + static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM static constexpr auto i8_min = @@ -19,7 +32,7 @@ static inline __device__ int8_t float_to_int8_rn(float x) { static constexpr auto i8_max = static_cast(std::numeric_limits::max()); // round - float dst = std::nearbyint(x); + float dst = to_nearest(x); // saturate dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); @@ -43,12 +56,16 @@ static inline __device__ int32_t float_to_int32_rn(float x) { static constexpr auto i32_max_f = static_cast(i32_max); // round - float dst = std::nearbyint(x); + float dst = to_nearest(x); // saturate on the higher end. - if (dst >= i32_max_f) { return i32_max; } + if (dst >= i32_max_f) { + return i32_max; + } // saturate on the lower end. 
- if (dst <= i32_min_f) { return i32_min; } + if (dst <= i32_min_f) { + return i32_min; + } return static_cast(dst); #else From 557db879470a1e90383fa309c2bfbeab04632821 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 10 Sep 2024 12:08:09 -0400 Subject: [PATCH 20/27] Added rounding mode guard --- .../compressed_tensors/int8_quant_kernels.cu | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index a9d099cd7103e..67ff1dfa1a562 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -13,17 +13,21 @@ #include #endif -#define USE_ROCM - -// Explicitly set the rounding mode to nearest -template -auto __device__ to_nearest(T x) { - auto const mode = std::fegetround(); - std::fesetround(FE_TONEAREST); - auto const result = std::nearbyint(x); - std::fesetround(mode); - return result; -} +namespace { +// RAII class to temporarily set the rounding mode and restore it at the end of +// the scope. +class rounding_mode_guard { + int old_mode; + + public: + __device__ rounding_mode_guard(int mode) { + old_mode = std::fegetround(); + std::fesetround(mode); + } + + __device__ ~rounding_mode_guard() { std::fesetround(old_mode); } +}; +} // namespace static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM @@ -32,7 +36,7 @@ static inline __device__ int8_t float_to_int8_rn(float x) { static constexpr auto i8_max = static_cast(std::numeric_limits::max()); // round - float dst = to_nearest(x); + float dst = std::nearbyint(x); // saturate dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); @@ -56,7 +60,7 @@ static inline __device__ int32_t float_to_int32_rn(float x) { static constexpr auto i32_max_f = static_cast(i32_max); // round - float dst = to_nearest(x); + float dst = std::nearbyint(x); // saturate on the higher end. 
   if (dst >= i32_max_f) {
@@ -104,6 +108,7 @@ __global__ void static_scaled_int8_quant_kernel(
   int const token_idx = blockIdx.x;
   scale_type const scale = *scale_ptr;
 
+  rounding_mode_guard guard(FE_TONEAREST);
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] = float_to_int8_rn(
         static_cast<float>(input[token_idx * hidden_size + i]) / scale);
@@ -120,6 +125,7 @@ __global__ void static_scaled_int8_azp_quant_kernel(
   scale_type const scale = *scale_ptr;
   azp_type const azp = *azp_ptr;
 
+  rounding_mode_guard guard(FE_TONEAREST);
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
     auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
@@ -154,6 +160,9 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   __syncthreads();
 
   float const tmp_scale = 127.0f / block_absmax_val;
+
+  // Quantize the values
+  rounding_mode_guard guard(FE_TONEAREST);
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] = float_to_int8_rn(
         static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
@@ -204,6 +213,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   azp_type const azp_val = azp_sh;
 
   // Quantize the values
+  rounding_mode_guard guard(FE_TONEAREST);
   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
     auto const quant_val =

From 2b24032028356a730fb0e2c19c14e59d5ebd09e1 Mon Sep 17 00:00:00 2001
From: luka
Date: Tue, 10 Sep 2024 13:54:35 -0400
Subject: [PATCH 21/27] Rounding mode stuff removed, added comment

---
 .../compressed_tensors/int8_quant_kernels.cu  | 27 ++-----------------
 1 file changed, 2 insertions(+), 25 deletions(-)

diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index 67ff1dfa1a562..d59bcb670a4f8 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -1,7 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
 #include <cmath>
-#include <cfenv>
 
 #include "../../dispatch_utils.h"
 
@@ -13,29 +12,13 @@
   #include <hipcub/hipcub.hpp>
 #endif
 
-namespace {
-// RAII class to temporarily set the rounding mode and restore it at the end of
-// the scope.
-class rounding_mode_guard {
-  int old_mode;
-
- public:
-  __device__ rounding_mode_guard(int mode) {
-    old_mode = std::fegetround();
-    std::fesetround(mode);
-  }
-
-  __device__ ~rounding_mode_guard() { std::fesetround(old_mode); }
-};
-}  // namespace
-
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
   static constexpr auto i8_min =
       static_cast<float>(std::numeric_limits<int8_t>::min());
   static constexpr auto i8_max =
       static_cast<float>(std::numeric_limits<int8_t>::max());
-  // round
+  // rounding mode is always FE_TONEAREST on HIP
   float dst = std::nearbyint(x);
   // saturate
   dst = std::clamp(dst, i8_min, i8_max);
@@ -59,7 +42,7 @@ static inline __device__ int32_t float_to_int32_rn(float x) {
   static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
   static constexpr auto i32_max_f = static_cast<float>(i32_max);
 
-  // round
+  // rounding mode is always FE_TONEAREST on HIP
   float dst = std::nearbyint(x);
 
   // saturate on the higher end.
@@ -108,7 +91,6 @@ __global__ void static_scaled_int8_quant_kernel(
   int const token_idx = blockIdx.x;
   scale_type const scale = *scale_ptr;
 
-  rounding_mode_guard guard(FE_TONEAREST);
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] = float_to_int8_rn(
         static_cast<float>(input[token_idx * hidden_size + i]) / scale);
@@ -125,7 +107,6 @@ __global__ void static_scaled_int8_azp_quant_kernel(
   scale_type const scale = *scale_ptr;
   azp_type const azp = *azp_ptr;
 
-  rounding_mode_guard guard(FE_TONEAREST);
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
     auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
@@ -160,9 +141,6 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   __syncthreads();
 
   float const tmp_scale = 127.0f / block_absmax_val;
-
-  // Quantize the values
-  rounding_mode_guard guard(FE_TONEAREST);
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] = float_to_int8_rn(
         static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
@@ -213,7 +191,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   azp_type const azp_val = azp_sh;
 
   // Quantize the values
-  rounding_mode_guard guard(FE_TONEAREST);
   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
     auto const quant_val =

From 5e9a0cbd70280c1d445c3a805ceaffad7970ee8e Mon Sep 17 00:00:00 2001
From: luka
Date: Tue, 10 Sep 2024 16:04:36 -0400
Subject: [PATCH 22/27] Fixed test

---
 tests/kernels/test_int8_quant.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index b5e88272e2c28..74f288d16e8b1 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -30,9 +30,8 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
     ops_out, ops_scales, _ = scaled_int8_quant(x)
 
     torch.testing.assert_close(ops_scales, ref_scales)
-    torch.testing.assert_close(
-        ops_out, ref_out, atol=1,
-        rtol=0.0)  # big atol to account for rounding errors
+    # big atol to account for rounding errors
+    torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0)
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -69,9 +68,9 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     if (not torch.allclose(scales_out, scales)):
         print(torch.argmax(torch.abs(scales_out - scales)))
     torch.testing.assert_close(scales_out, scales)
-    torch.testing.assert_close(azp_out, azps, atol=1)  # azp rounding error
-    torch.testing.assert_close(torch_out, ops_out,
-                               atol=1)  # azp rounding error
+    # big atol to account for rounding errors
+    torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
+    torch.testing.assert_close(ops_out, torch_out, atol=1, rtol=0.0)
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
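
For readers following the test fix above, the dynamic asymmetric path can be summarized with a small PyTorch sketch of the kind of reference the test builds: a per-token scale of (max - min) / 255 and a zero point that maps the per-token minimum to -128. This is an illustration written for this note, not code from the test file, and it assumes no token row is constant (that would make the scale zero).

    import torch

    def dynamic_int8_azp_quant_ref(x: torch.Tensor):
        # x: [num_tokens, hidden_size], float32
        x_min = x.min(dim=1, keepdim=True).values
        x_max = x.max(dim=1, keepdim=True).values

        # Spread the observed range over the 256 int8 buckets.
        scale = (x_max - x_min) / 255.0
        # Zero point chosen so that x_min quantizes to -128.
        azp = torch.round(-128.0 - x_min / scale).to(torch.int32)

        q = (torch.round(x / scale) + azp).clamp(-128, 127).to(torch.int8)
        return q, scale, azp

    x = torch.randn(4, 16) * 10
    q, scale, azp = dynamic_int8_azp_quant_ref(x)
    x_hat = (q.to(torch.float32) - azp) * scale   # dequantize
    print((x - x_hat).abs().max())                # bounded by roughly one scale step
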

From 65b2f9cf60aefcd6a2b10e2f571b8f757b2de3be Mon Sep 17 00:00:00 2001
From: luka
Date: Tue, 10 Sep 2024 16:13:07 -0400
Subject: [PATCH 23/27] Improved nearbyint rounding comment

---
 .../compressed_tensors/int8_quant_kernels.cu  | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index d59bcb670a4f8..aec9fa002f96e 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -18,8 +18,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
       static_cast<float>(std::numeric_limits<int8_t>::min());
   static constexpr auto i8_max =
       static_cast<float>(std::numeric_limits<int8_t>::max());
-  // rounding mode is always FE_TONEAREST on HIP
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
+
   // saturate
   dst = std::clamp(dst, i8_min, i8_max);
   return static_cast<int8_t>(dst);
@@ -42,7 +47,10 @@ static inline __device__ int32_t float_to_int32_rn(float x) {
   static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
   static constexpr auto i32_max_f = static_cast<float>(i32_max);
 
-  // rounding mode is always FE_TONEAREST on HIP
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
 
   // saturate on the higher end.

From 45e1d9e9903423c6862ee5d9e9806e3d41533047 Mon Sep 17 00:00:00 2001
From: luka
Date: Tue, 10 Sep 2024 16:44:25 -0400
Subject: [PATCH 24/27] Added saturating cast test

---
 tests/kernels/test_int8_quant.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index 74f288d16e8b1..e7fab363bc07b 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -125,3 +125,35 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
+
+
+@pytest.mark.parametrize("is_max", [True, False])
+@torch.inference_mode()
+def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
+    # Test that the saturating cast works correctly for values near i32 max/min
+
+    from numpy import inf, nextafter
+
+    int32_traits = torch.iinfo(torch.int32)
+    val = float(int32_traits.max if is_max else int32_traits.min)
+
+    x_vals = [[
+        nextafter(val, inf), val + 1, val, val - 1,
+        nextafter(val, -inf)
+    ]]
+    x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")
+
+    # The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)
+    # where cast<T> is a saturating cast to type T.
+    # Scale is set to 1.0 so that the input values are the ones that are cast.
+    # AZP is set to 0 to make sure the int8 saturating cast is tested as well.
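
Since the comments above lean on nearbyint and FE_TONEAREST, a short reminder of what that rounding mode does may help; the following is a standalone illustration (assuming NumPy and PyTorch are available, not code from the patches): halfway cases go to the nearest even integer, both in np.rint / std::nearbyint under FE_TONEAREST and in CUDA's __float2int_rn, and torch.round behaves the same way, which is why the Python references in the tests can simply use .round().

    import numpy as np
    import torch

    halves = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5], dtype=np.float32)

    # Round-half-to-even ("banker's rounding"), the default FE_TONEAREST behavior.
    print(np.rint(halves))                         # [-2. -2. -0.  0.  2.  2.]

    # torch.round matches, so (x / scale).round() in the tests agrees with the kernel.
    print(torch.round(torch.from_numpy(halves)))   # tensor([-2., -2., -0., 0., 2., 2.])
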
+    scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda")
+    azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda")
+
+    int8_traits = torch.iinfo(torch.int8)
+    val_i8 = int8_traits.max if is_max else int8_traits.min
+    expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
+
+    out = torch.empty_like(expected)
+    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
+    torch.testing.assert_close(expected, out, atol=0, rtol=0)

From 2232b6dd8ced9aa850c394e29d32cbbd253fd79c Mon Sep 17 00:00:00 2001
From: luka
Date: Wed, 11 Sep 2024 11:02:59 -0400
Subject: [PATCH 25/27] Fixed scaled_int8_quant in qqq

---
 vllm/model_executor/layers/quantization/qqq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py
index c3434214a1cde..5bc3737520865 100644
--- a/vllm/model_executor/layers/quantization/qqq.py
+++ b/vllm/model_executor/layers/quantization/qqq.py
@@ -260,7 +260,7 @@ def apply(
         size_k = x_2d.shape[1]
         size_n = s_ch.shape[1]
 
-        x_int8, s_tok = ops.scaled_int8_quant(x_2d)
+        x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d)
 
         output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
                                         workspace, size_m, size_n, size_k)

From 04a539e201036f7fc4cfef43a672c94792efc8bd Mon Sep 17 00:00:00 2001
From: luka
Date: Wed, 11 Sep 2024 21:25:11 -0400
Subject: [PATCH 26/27] Fixed ops_check & azp test atol

---
 tests/kernels/test_int8_quant.py | 44 +++++++++++++++++++-------------------
 1 file changed, 26 insertions(+), 18 deletions(-)

diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index 6fdafadf449fe..e93cb535d715a 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -15,21 +15,26 @@
 
 def opcheck_int8_quant_static(output, input, scale, azp=None):
     if azp is None:
-        opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale))
+        opcheck(torch.ops._C.static_scaled_int8_quant,
+                (output, input, scale, None))
     else:
-        opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp))
+        opcheck(torch.ops._C.static_scaled_int8_quant,
+                (output, input, scale, azp))
+
 
 def opcheck_int8_quant_dynamic(output, input, symmetric=True):
     scale = torch.empty((input.numel() // input.shape[-1], 1),
                         device=input.device,
                         dtype=torch.float32)
     if symmetric:
-        opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale))
+        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
+                (output, input, scale, None))
     else:
         azp = torch.empty((input.numel() // input.shape[-1], 1),
-                         device=input.device,
-                         dtype=torch.int32)
+                          device=input.device,
+                          dtype=torch.int32)
-        opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp))
+        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
+                (output, input, scale, azp))
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -75,7 +80,8 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     # calculate scale and azp, and adjust the range
     scales = (x_token_max - x_token_min) / torch.tensor(255.0)
-    azps = torch.round(-128.0 - x_token_min / scales).to(torch.int32)
+    azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
+        torch.int32)
 
     torch_out = ((x / scales).round() + azps).clamp(
         int8_traits.min, int8_traits.max).to(torch.int8)
@@ -92,10 +98,12 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     if (not torch.allclose(scales_out, scales)):
         print(torch.argmax(torch.abs(scales_out - scales)))
     torch.testing.assert_close(scales_out, scales)
     # big atol to account for rounding errors
     torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
-    torch.testing.assert_close(ops_out, torch_out, atol=1, rtol=0.0)
+    # if AZP is off by 1, after rounding-to-even, the output may be off by 2
+    torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0)
 
     opcheck_int8_quant_dynamic(ops_out, x, False)
+
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -110,16 +118,17 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
     int8_traits = torch.iinfo(torch.int8)
 
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
-    scale = torch.tensor([scale], dtype=torch.float32, device="cuda")
+    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
 
-    out1 = (x / scale).round().clamp(int8_traits.min,
-                                     int8_traits.max).to(torch.int8)
-    out2, _, _ = scaled_int8_quant(x, scale)
+    out1 = (x / scale_arg).round().clamp(int8_traits.min,
+                                         int8_traits.max).to(torch.int8)
+    out2, _, _ = scaled_int8_quant(x, scale_arg)
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
 
-    opcheck_int8_quant_static(out2, x, scale)
+    opcheck_int8_quant_static(out2, x, scale_arg)
+
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@@ -141,16 +150,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
                                              int8_traits.max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
-    scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
-    azp_argument = torch.tensor([azp], dtype=torch.int32, device="cuda")
+    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
+    azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
 
-    torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument,
-                                          azp_argument)
+    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
 
-    opcheck_int8_quant_static(out2, x, scale, azp)
+    opcheck_int8_quant_static(out2, x, scale_arg, azp_arg)
 
 
 @pytest.mark.parametrize("is_max", [True, False])
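
The atol=2 in the patch above can be illustrated with made-up numbers (a toy sketch, not taken from the tests): when the reference and the kernel compute zero points that differ by one, a quotient that sits exactly on a rounding boundary can also flip by one, so the two int8 outputs can end up two apart.

    import torch

    x = torch.tensor([12.5])

    # Reference path: 12.5 / 1.0 is exactly 12.5, which rounds down to 12
    # under round-half-to-even, and the zero point comes out as 10.
    q_ref = torch.round(x / 1.0) + 10        # -> 22

    # Kernel path: a marginally smaller scale pushes the quotient just above
    # 12.5 (rounds up to 13), and the zero point rounds to 11 instead of 10.
    q_ker = torch.round(x / 0.99999) + 11    # -> 24

    print(int(q_ref), int(q_ker))            # 22 24: two apart
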

From a3b9f6a6f4743254c297776ebf70eca284867df7 Mon Sep 17 00:00:00 2001
From: luka
Date: Thu, 12 Sep 2024 12:00:44 -0400
Subject: [PATCH 27/27] Fixed cpu bindings

---
 csrc/cpu/quant.cpp          | 9 ++++++---
 csrc/cpu/torch_bindings.cpp | 9 +++++----
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
index 0cfc19097fded..2d7abe6145fee 100644
--- a/csrc/cpu/quant.cpp
+++ b/csrc/cpu/quant.cpp
@@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c,  // [M, OC], row-major
 // static-per-tensor quantization.
 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               const torch::Tensor& input,  // [..., hidden_size]
-                              const torch::Tensor& scale) {
+                              const torch::Tensor& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");
 
   const int hidden_size = input.size(-1);
   const int num_tokens = input.numel() / hidden_size;
@@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     const torch::Tensor& input,  // [..., hidden_size]
-    torch::Tensor& scale         // [..., 1]
-) {
+    torch::Tensor& scale,        // [..., 1]
+    c10::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");
 
   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index b45da1b386b5b..ab697e3e6aef7 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #ifdef __AVX512F__
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
-      "()");
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
+
   // Compute int8 quantized tensor and scaling factor
   ops.def(
-      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
-      "()");
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
            &dynamic_scaled_int8_quant);
   // W8A8 GEMM, supporting symmetric per-tensor or per-row/column
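
One numeric fact behind the earlier float_to_int32_rn change and the saturating-cast test: float32 cannot represent 2**31 - 1, so converting int32_max to float already lands on 2**31, and a plain float-to-int conversion of that value would overflow. A standalone NumPy illustration (not part of the patches):

    import numpy as np

    i32_max = np.iinfo(np.int32).max            # 2147483647
    as_f32 = np.float32(i32_max)                # nearest float32 is 2147483648.0
    print(as_f32 == np.float32(2**31))          # True: one past int32_max

    # Largest float32 strictly below 2**31; the saturating-cast test probes
    # this kind of boundary value with nextafter.
    print(np.nextafter(np.float32(2**31), np.float32(0)))   # 2147483520.0
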