From 4d27a2c4680a40fc5509a8ec006b9e2e7c40b80f Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 30 Apr 2024 13:10:42 -0400
Subject: [PATCH 01/73] Initial `CompressedTensors` config + Activation Quantization support for static W8A8 per tensor (#195)

- Depending on how we end up parsing `ignore` and `targets` (layer_name vs layer_type) we may not need layer_name to be added to the linear_method. Will experiment using a compressed-tensors function in a follow-up PR
- Initial implementation for Compressed Config support + Activation Quantization for static per tensor w8a8
- Includes fused kernels added by @varun-sundar-rabindranath

```python
from vllm import LLM, SamplingParams
import torch

prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The US president is",
    "The future of AI is"
]
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test",
          enforce_eager=True,
          dtype=torch.float32,
          quantization="sparseml")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- Verification of the different inputs expected for `targets` and `ignore` --> use functions to parse the layer names which can be shared by sparseml and vllm; would live in compressed tensors (https://github.com/neuralmagic/compressed-tensors/blob/67005d76107d4659787f1efd53fe7e6b1d192818/src/compressed_tensors/quantization/lifecycle/apply.py#L86)
- Updates to further optimize fake quant

---------

Co-authored-by: Varun Sundar Rabindranath
Co-authored-by: Varun Sundar Rabindranath
---
 CMakeLists.txt | 1 +
 csrc/ops.h | 5 +
 csrc/pybind.cpp | 9 +
 .../compressed_tensors/int8_quant_kernels.cu | 50 ++++++
 requirements-cuda.txt | 1 +
 vllm/model_executor/layers/linear.py | 169 ++++++++++++++----
 .../layers/quantization/__init__.py | 5 +-
 .../layers/quantization/aqlm.py | 1 +
 .../model_executor/layers/quantization/awq.py | 1 +
 .../compressed_tensors/__init__.py | 0
 .../compressed_tensors/compressed_tensors.py | 159 ++++++++++++++++
 .../compressed_tensors/cutlass_gemm.py | 91 ++++++++++
 .../compressed_tensors/schemes/__init__.py | 3 +
 .../schemes/compressed_tensors_scheme.py | 32 ++++
 .../schemes/compressed_tensors_unquantized.py | 36 ++++
 .../compressed_tensors_w8a8_statictensor.py | 137 ++++++++++++++
 .../model_executor/layers/quantization/fp8.py | 1 +
 .../layers/quantization/gptq.py | 1 +
 .../layers/quantization/marlin.py | 1 +
 .../layers/quantization/squeezellm.py | 1 +
 vllm/model_executor/models/llama.py | 39 ++--
 vllm/worker/model_runner.py | 2 +-
 22 files changed, 691 insertions(+), 54 deletions(-)
 create mode 100644 csrc/quantization/compressed_tensors/int8_quant_kernels.cu
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py

diff
--git a/CMakeLists.txt b/CMakeLists.txt index e9262b57d0867..261e57274f8c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/fp8_cuda_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" diff --git a/csrc/ops.h b/csrc/ops.h index 03bb1e24dc68e..823dabf90c307 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -156,6 +156,11 @@ void dynamic_scaled_fp8_quant( torch::Tensor& input, torch::Tensor& scale); +void quant_per_tensor( + torch::Tensor& out, + torch::Tensor& input, + float scale); + void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 2250c7f69f0ab..13514065456ce 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -80,6 +80,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &moe_align_block_size, "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); + ops.def( + "quant_per_tensor", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + float>(&quant_per_tensor), + "Per-tensor Quantization"); + + // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def( diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu new file mode 100644 index 0000000000000..e1af55dc225a2 --- /dev/null +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -0,0 +1,50 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} + +namespace vllm { + +template +__global__ void quant_kernel( + const scalar_t* __restrict__ input, + int8_t* __restrict__ out, + scale_type scale, + const int hidden_size) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + } +} +} // namespace vllm + +void quant_per_tensor( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + assert(input.is_contiguous()); + assert(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { + vllm::quant_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + scale, + hidden_size); + }); +} diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 1bddae4c6f40f..f4c04afd55c70 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 +nvidia-cutlass == 3.5.0 diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4d43ed4c5f14a..5469898972e49 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ -from abc 
import abstractmethod -from typing import List, Optional +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -30,11 +30,13 @@ class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + layer_name: Optional[str] = None, + **extra_weight_attrs) -> Dict[str, Any]: + """Create weights for a linear layer. The weights will be set as attributes of the layer. @@ -47,6 +49,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: Size of the input dim of the weight across all ranks. output_size: Size of the output dim of the weight across all ranks. params_dtype: Datatype of the parameters. + layer_name: name of the layer in the state dict. """ raise NotImplementedError @@ -56,7 +59,6 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights in layer to the input tensor. - Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -76,9 +78,9 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - output_size_per_partition = sum(output_partition_sizes) - weight = Parameter(torch.empty(output_size_per_partition, + layer_name: Optional[str] = None, + **extra_weight_attrs) -> Dict[str, Any]: + weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), requires_grad=False) @@ -108,6 +110,7 @@ class LinearBase(torch.nn.Module): skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + layer_name: name of the layer in the state dict. """ def __init__( @@ -117,10 +120,12 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.skip_bias_add = skip_bias_add @@ -157,15 +162,16 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, layer_name) # All the linear layer supports quant method. assert self.quant_method is not None self.quant_method.create_weights(self, self.input_size, [self.output_size], self.input_size, - self.output_size, self.params_dtype) + self.output_size, self.params_dtype, layer_name=self.layer_name) if bias: self.bias = Parameter( @@ -202,6 +208,7 @@ class ColumnParallelLinear(LinearBase): quant_config: Quantization configure. output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. + layer_name: name of the layer in the state dict. 
""" def __init__( @@ -214,6 +221,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, + layer_name: Optional[str] = None ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -222,18 +230,27 @@ def __init__( # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, tp_size) + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) + for output_size in self.output_sizes + ] + if output_sizes is None: output_sizes = [output_size] - # All the linear layer supports quant method. - assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size, - [x // tp_size for x in output_sizes], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + layer=self, + layer_name=self.layer_name, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -302,13 +319,19 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, quant_config, - self.output_sizes) + super().__init__(layer_name=layer_name, + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -318,6 +341,19 @@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) is_metadata = getattr(param, "is_metadata", False) + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -375,6 +411,12 @@ def weight_loader(self, shard_size = loaded_weight.shape[0] shard_offset = loaded_shard_id * shard_size param_data = param_data.narrow(0, shard_offset, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. 
+ elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths") + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -382,6 +424,13 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -408,6 +457,7 @@ class QKVParallelLinear(ColumnParallelLinear): skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + layer_name: name of the layer in the state dict. """ def __init__( @@ -420,6 +470,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): self.hidden_size = hidden_size self.head_size = head_size @@ -440,14 +491,20 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - output_sizes = [ - self.num_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, quant_config, output_sizes) + super().__init__(layer_name=layer_name, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -456,6 +513,18 @@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) is_metadata = getattr(param, "is_metadata", False) + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) if loaded_shard_id is None: # Loaded weight is already packed. @@ -491,11 +560,14 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. 
if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 shard_size = self.num_heads * self.head_size elif loaded_shard_id == "k": + # shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size elif loaded_shard_id == "v": @@ -529,6 +601,12 @@ def weight_loader(self, shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths") + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -536,6 +614,13 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") + + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -564,6 +649,7 @@ class RowParallelLinear(LinearBase): We skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + layer_name: name of the layer in the state dict. """ def __init__( @@ -576,9 +662,10 @@ def __init__( params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, layer_name) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -586,16 +673,16 @@ def __init__( # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) - # All the linear layer supports quant method. 
assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size_per_partition, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) - + self.quant_method.create_weights( + layer=self, + layer_name=self.layer_name, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") @@ -619,6 +706,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 70e0a7cfe3e3b..06fb0c9056230 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,7 +4,9 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.fp8 import FP8Config +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig @@ -16,6 +18,7 @@ "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "marlin": MarlinConfig, + "sparseml": CompressedTensorsConfig } diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 83e24fadc1405..6edb3c3e9c63b 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -231,6 +231,7 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs): del output_size # Unused. del input_size # Unused. 
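The `quant_per_tensor` op wired up above (`csrc/ops.h`, `csrc/pybind.cpp`, `int8_quant_kernels.cu`) performs static symmetric per-tensor quantization: every element is divided by one fixed scale and converted to int8 with a saturating round-to-nearest instruction (`cvt.rni.sat.s8.f32`). A rough PyTorch sketch of that behavior, useful only as an illustration and not part of this patch (`quant_per_tensor_ref` is a made-up name, and the intrinsic's rounding at tie values may differ slightly from `torch.round`):

```python
import torch

def quant_per_tensor_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Divide by a single static scale, round to the nearest integer,
    # saturate to the int8 range, then cast to int8.
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

x = torch.randn(2, 8)
scale = x.abs().max().item() / 127.0  # a toy static scale
print(quant_per_tensor_ref(x, scale))
```

Later in this patch, `CompressedTensorsW8A8StaticTensor` invokes the real op with the statically calibrated `input_scale` for activations and, in fake-quant mode, with per-partition weight scales.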
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index f4fc7ce020e95..00b4a4714be16 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -86,6 +86,7 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 0000000000000..a61bec6e03236 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,159 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig) + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsW8A8StaticTensor, CompressedTensorsUnquantized, + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str], + fake_quant: bool): + self.fake_quant = fake_quant + self.ignore = ignore + self.layer_quant_details = layer_quant_details + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float32, torch.int8] + + # Need to figure it out + def get_min_capability(self) -> int: + return 60 + + def get_name(self) -> str: + return "compressed_tensors" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + layer_quant_details: Dict[str:Any] = dict() + ignore = config.get("ignore") + fake_quant = config.get("format") == "fakequant" + + for key, quant_config in config["config_groups"].items(): + targets = quant_config.get("targets") + for target in targets: + layer_quant_details[target] = {} + layer_quant_details[target]["weight"] = quant_config.get( + "weights") + layer_quant_details[target]["input"] = quant_config.get( + "input_activations") + + return cls(layer_quant_details=layer_quant_details, + ignore=ignore, + fake_quant=fake_quant) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["config.json"] + + def _get_schema(self, weight_quant: Dict, input_quant: Dict): + # TODO: Will static vs dynamic be defined in the config? 
+ # TODO: Expand conditions/break into separate fxs as other + # schemes are supported + + weight_bit = weight_quant.get("num_bits") + input_bit = input_quant.get("num_bits") + + weight_strategy = weight_quant.get("strategy") + input_strategy = input_quant.get("strategy") + + weight_symmetric = weight_quant.get("symmetric") + input_symmetric = input_quant.get("symmetric") + + is_8_bits = weight_bit == input_bit == 8 + is_tensor = weight_strategy == input_strategy == "tensor" + is_symmetric = weight_symmetric and input_symmetric + + if is_8_bits and is_tensor and is_symmetric: + return CompressedTensorsW8A8StaticTensor( + fake_quant=self.fake_quant) + raise NotImplementedError( + "Scheme not supported. Only 8-bit static symmtetric " + "per tensor quantization is currently supported") + + def get_scheme(self, layer: torch.nn.Module, + layer_name: str) -> "CompressedTensorsScheme": + + if layer_name is None: + raise ValueError("layer_name must be provided for CompressedTensorsConfig") + + if layer_name in self.ignore: + return CompressedTensorsUnquantized() + + # TODO: update with matching function from `compressed_tensors` + layer_type_name = None + layer_name_class = type(layer).__name__.lower() + for target in self.layer_quant_details: + if target.lower() in layer_name_class: + layer_type_name = target + break + + layer_quant_details = self.layer_quant_details.get(layer_type_name) + if layer_quant_details is None: + raise ValueError( + f"Could not find quantization details for {layer_name}.") + try: + return self._get_schema(weight_quant=layer_quant_details["weight"], + input_quant=layer_quant_details["input"]) + except NotImplementedError as e: + raise e + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create the + necessary parameters for the layer. + """ + weight_loader = extra_weight_attrs.get("weight_loader") + + scheme = self.quantization_config.get_scheme(layer=layer, + layer_name=layer_name) + scheme.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + layer.scheme = scheme + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme associated with + the layer to apply the forward pass with the layer input. 
+ """ + + if bias is not None: + raise ValueError("bias is not supported for this linear method") + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py new file mode 100644 index 0000000000000..1b728865641d4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -0,0 +1,91 @@ +import cutlass +from cutlass import Tensor as FakeTensor +import cutlass.epilogue + +import torch +from typing import Optional, Tuple, Dict + +from vllm.logger import init_logger + +logger = init_logger("cutlass_gemm") + +def setup_dequant_epilogue(plan : cutlass.op.Gemm, + dq: torch.Tensor, + static_scales: Optional[torch.Tensor], + activation_scales: Optional[torch.Tensor]) \ + -> Tuple[cutlass.op.Gemm, Dict]: + + if all([static_scales is None, activation_scales is None]): + return plan, None + assert static_scales is not None + + def epilog_with_scales_and_act_scales(accum, scales, act_scales): + D = accum * scales * act_scales + return D + + def epilog_with_scales(accum, scales): + D = accum * scales + return D + + epilog_tensors = {'scales': static_scales, 'D': dq} + epilogue_trace_tensors = { + "accum": + FakeTensor(element=torch.int32, + shape=dq.shape, + layout_tag=cutlass.LayoutType.RowMajor), + 'scales': + static_scales, + 'D': + dq, + } + epilog_fn = epilog_with_scales + + if activation_scales is not None: + epilog_tensors['act_scales'] = activation_scales + epilogue_trace_tensors['act_scales'] = activation_scales + epilog_fn = epilog_with_scales_and_act_scales + + plan.epilogue_visitor = cutlass.epilogue.trace(epilog_fn, + epilogue_trace_tensors) + return plan, epilog_tensors + + +def cutlass_gemm_dq( + x_q: torch.Tensor, + w_q: torch.Tensor, + dtype: torch.dtype, + static_scales: torch.Tensor, + activation_scales: Optional[torch.Tensor] = None) -> torch.Tensor: + + dq = torch.empty((x_q.shape[0], w_q.shape[0]), dtype=dtype, device="cuda") + + log_str = (f"cutlass_gemm_dq: \n" + f" - x_q {x_q.shape} {x_q.dtype} \n" + f" - w_q {w_q.shape} {w_q.dtype} \n" + f" - o_dq {dq.shape} {dq.dtype} \n") + logger.debug(log_str) + + plan = cutlass.op.Gemm( + element_A=x_q.dtype, + element_B=w_q.dtype, + element_C=dq.dtype, + element_D=dq.dtype, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.ColumnMajor, + layout_C=cutlass.LayoutType.RowMajor, + element_accumulator=torch.int32) + + plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales, + activation_scales) + + plan.run(x_q, + w_q.t(), + dq, + dq, + alpha=1, + beta=0, + visitor_args=visitor_args, + print_module=False) + + dq = dq.view(*x_q.shape[:-1], -1) + return dq diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 0000000000000..5a32069d71e26 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,3 @@ +from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_unquantized import CompressedTensorsUnquantized +from .compressed_tensors_w8a8_statictensor import CompressedTensorsW8A8StaticTensor \ No newline at end of file diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 0000000000000..1873cba9b6815 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass of different + quantization schemes supported by CompressedTensors. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + """ + Run the forward pass for the particular scheme. This is where scheme-specific + dequant/quant steps/kernels should be applied. + + :param layer: toch.nn.Module with the registered weights and other parameters + relevant to the particular scheme. + :param x: input to the layer + + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py new file mode 100644 index 0000000000000..d5b582f6176a0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -0,0 +1,36 @@ +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +import torch +from typing import List, Callable +from torch.nn import Parameter +from vllm.model_executor.utils import set_weight_attrs +import torch.nn.functional as F + +__all__ = ["CompressedTensorsUnquantized"] + + +class CompressedTensorsUnquantized(CompressedTensorsScheme): + """ + Implements the scheme for all layers which are ignored in the CompressedTensors + config. The input and loaded weight are used in a linear transformation. 
+ """ + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + return F.linear(x, weight) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py new file mode 100644 index 0000000000000..9698e97f91f46 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -0,0 +1,137 @@ +import torch +from typing import List, Union, Tuple, Callable +from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( + cutlass_gemm_dq) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs +from torch.nn import Parameter +from vllm._C import ops + +__all__ = ["CompressedTensorsW8A8StaticTensor"] + + +class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): + + def __init__(self, fake_quant): + self.fake_quant = fake_quant + + def _quantize(self, + x: torch.Tensor, + scales: torch.Tensor, + logical_widths: List[int], + split_dim: int = 0) -> torch.Tensor: + + x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") + x_q_split = x_q.split(logical_widths, dim=split_dim) + x_split = x.split(logical_widths, dim=split_dim) + + for q, dq, scale in zip(x_q_split, x_split, scales): + ops.quant_per_tensor(q, dq, scale.item()) + + return x_q + + def _quantize_single(self, x: torch.Tensor, scale: float): + x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") + ops.quant_per_tensor(x_q, x, scale) + return x_q + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self._shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. 
+ loaded_weight = loaded_weight.repeat(size) + return param[offset:offset + size], loaded_weight + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + # TODO: remove zero_point parameters once the configs given remove them + is_tensor_partitioned = len(output_partition_sizes) != 1 + dim = sum(output_partition_sizes) if is_tensor_partitioned else 1 + + input_scale = Parameter(torch.empty(1, + device="cuda", + dtype=torch.float32), + requires_grad=False) + input_zero_point = Parameter(torch.empty(1, + device="cuda", + dtype=torch.int8), + requires_grad=False) + + weight_scale = Parameter(torch.empty(dim, + device="cuda", + dtype=torch.float32), + requires_grad=False) + weight_zero_point = Parameter(torch.empty(1, + device="cuda", + dtype=torch.int8), + requires_grad=False) + + if not self.fake_quant: + params_dtype = torch.int8 + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + # Register parameter with the layer; register weight loader with each parameter + set_weight_attrs(weight, {"weight_loader": weight_loader}) + set_weight_attrs(weight, + {"logical_widths": output_partition_sizes}) + + layer.register_parameter("input_scale", input_scale) + set_weight_attrs(input_scale, {"weight_loader": weight_loader}) + layer.register_parameter("input_zero_point", input_zero_point) + set_weight_attrs(input_zero_point, {"weight_loader": weight_loader}) + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) + set_weight_attrs( + weight_scale, { + "shard_splitter": self.scales_shard_splitter, + "logical_widths": output_partition_sizes + }) + layer.register_parameter("weight_zero_point", weight_zero_point) + set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + weight_scale = layer.weight_scale + act_scale = layer.input_scale + logical_widths = weight.logical_widths + + # Input quantize + x_q = self._quantize_single(x, act_scale[0].item()) + + # Weight quantize + # TODO : try not to remove device-to-host copy. i.e. 
keep the non-duplicated version + # of scales in the CPU + if self.fake_quant: + w_scales = [ + weight_scale[sum(logical_widths[:i])].item() + for i in range(len(logical_widths)) + ] + w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) + w_q = self._quantize(weight, w_scales, logical_widths) + # GEMM and dq + return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, act_scale) + return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, act_scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ba9f3149649c1..05347ba177072 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -79,6 +79,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ): output_size_per_partition = sum(output_partition_sizes) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index ae9f7019f0592..b58db2ae7e7f7 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -98,6 +98,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ): del output_size # Unused. diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 94aba620ea083..6644d1e269ff0 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -100,6 +100,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ): del output_size # Unused. diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 207dbcee8afc5..4a4627f7e8968 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -75,6 +75,7 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f6d7fc8733fce..b723ce43b89ed 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,6 +54,7 @@ class LlamaMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -61,13 +62,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -84,6 +89,7 @@ class LlamaAttention(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -127,16 +133,18 @@ def __init__( self.kv_scale = 1.0 self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=bias, quant_config=quant_config, ) @@ -174,6 +182,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -193,6 +202,7 @@ def __init__( attention_bias = getattr(config, "attention_bias", False) or getattr( config, "bias", False) self.self_attn = LlamaAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", @@ -205,6 +215,7 @@ def __init__( sliding_window=sliding_window, ) self.mlp = LlamaMLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -265,8 +276,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + quant_config=quant_config) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0704f5fec54d0..5b6a299123965 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1101,4 +1101,4 @@ def _prepare_fake_inputs( else: prompt_tokens = [0] * seq_len fake_image_input = None - return SequenceData(prompt_tokens), fake_image_input + return SequenceData(prompt_tokens), fake_image_input \ No newline at end of file From 92b370393c4187be8deb25ae8f0c090b2bda736f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 18:52:02 +0000 Subject: [PATCH 02/73] add get_quant method to compressed tensors config --- .../quantization/compressed_tensors/compressed_tensors.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index a61bec6e03236..d9eec5da6c9fa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -2,7 +2,7 @@ import torch -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) @@ -35,6 +35,12 @@ def get_min_capability(self) -> int: def get_name(self) -> str: return "compressed_tensors" + + def get_quant_method( + self, layer: torch.nn.Module) -> 
Optional["CompressedTensorsLinearMethod"]: + if isinstance(layer, LinearBase): + return CompressedTensorsLinearMethod(self) + return None @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": From 2a3eb8385ebf8cd7749ffec05481919c778e7be1 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 19:06:40 +0000 Subject: [PATCH 03/73] small rebase fixed --- vllm/model_executor/layers/linear.py | 4 ++-- vllm/model_executor/layers/quantization/__init__.py | 2 +- .../quantization/compressed_tensors/compressed_tensors.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5469898972e49..9f7d9d77b0477 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Dict, List, Optional import torch @@ -224,7 +224,7 @@ def __init__( layer_name: Optional[str] = None ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, layer_name) self.gather_output = gather_output diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 06fb0c9056230..1607efddb657b 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,7 +4,7 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import FP8Config +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) from vllm.model_executor.layers.quantization.gptq import GPTQConfig diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d9eec5da6c9fa..5889ce469ae29 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -147,7 +147,7 @@ def create_weights(self, layer: torch.nn.Module, layer.scheme = scheme - def apply_weights(self, + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None): From 3dd1fe8857aa518121927060054c9f3c6f8141f8 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 19:25:58 +0000 Subject: [PATCH 04/73] format --- vllm/model_executor/layers/linear.py | 148 +++++++++--------- .../layers/quantization/aqlm.py | 9 +- .../model_executor/layers/quantization/awq.py | 9 +- .../compressed_tensors/compressed_tensors.py | 46 +++--- .../compressed_tensors/cutlass_gemm.py | 17 +- .../compressed_tensors_w8a8_statictensor.py | 3 +- .../layers/quantization/squeezellm.py | 9 +- 7 files changed, 126 insertions(+), 115 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9f7d9d77b0477..c6dcd48ac765a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -30,13 +30,15 @@ class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, layer: 
torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs) -> Dict[str, Any]: - """Create weights for a linear layer. The weights will be set as attributes of the layer. @@ -74,10 +76,13 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs) -> Dict[str, Any]: weight = Parameter(torch.empty(sum(output_partition_sizes), @@ -113,15 +118,13 @@ class LinearBase(torch.nn.Module): layer_name: name of the layer in the state dict. """ - def __init__( - self, - input_size: int, - output_size: int, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): super().__init__() # Keep input parameters @@ -154,24 +157,25 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, layer_name) # All the linear layer supports quant method. assert self.quant_method is not None - self.quant_method.create_weights(self, self.input_size, - [self.output_size], self.input_size, - self.output_size, self.params_dtype, layer_name=self.layer_name) + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + layer_name=self.layer_name) if bias: self.bias = Parameter( @@ -211,18 +215,16 @@ class ColumnParallelLinear(LinearBase): layer_name: name of the layer in the state dict. 
""" - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None, + layer_name: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, layer_name) @@ -310,17 +312,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -460,18 +460,16 @@ class QKVParallelLinear(ColumnParallelLinear): layer_name: name of the layer in the state dict. """ - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -652,18 +650,16 @@ class RowParallelLinear(LinearBase): layer_name: name of the layer in the state dict. 
""" - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + layer_name: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, layer_name) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 6edb3c3e9c63b..1215f818de90d 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -227,10 +227,13 @@ class AQLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: AQLMConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs): del output_size # Unused. diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 00b4a4714be16..58e3fd0d1d844 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -82,10 +82,13 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 5889ce469ae29..5b6001d79732b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -35,18 +35,19 @@ def get_min_capability(self) -> int: def get_name(self) -> str: return "compressed_tensors" - + def get_quant_method( - self, layer: torch.nn.Module) -> Optional["CompressedTensorsLinearMethod"]: + self, layer: torch.nn.Module + ) -> Optional["CompressedTensorsLinearMethod"]: if isinstance(layer, LinearBase): return CompressedTensorsLinearMethod(self) return None @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": - layer_quant_details: Dict[str:Any] = dict() - ignore = config.get("ignore") - fake_quant = config.get("format") == "fakequant" + layer_quant_details: Dict[str, Any] = dict() + ignore: List[str] = config.get("ignore", None) + fake_quant: bool = config.get("format") 
== "fakequant" for key, quant_config in config["config_groups"].items(): targets = quant_config.get("targets") @@ -66,9 +67,7 @@ def get_config_filenames(cls) -> List[str]: return ["config.json"] def _get_schema(self, weight_quant: Dict, input_quant: Dict): - # TODO: Will static vs dynamic be defined in the config? - # TODO: Expand conditions/break into separate fxs as other - # schemes are supported + # TODO: Refactor as additional cases are supported weight_bit = weight_quant.get("num_bits") input_bit = input_quant.get("num_bits") @@ -90,11 +89,14 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict): "Scheme not supported. Only 8-bit static symmtetric " "per tensor quantization is currently supported") - def get_scheme(self, layer: torch.nn.Module, - layer_name: str) -> "CompressedTensorsScheme": + def get_scheme( + self, + layer: torch.nn.Module, + layer_name: Optional[str] = None) -> "CompressedTensorsScheme": if layer_name is None: - raise ValueError("layer_name must be provided for CompressedTensorsConfig") + raise ValueError( + "layer_name must be provided for CompressedTensorsConfig") if layer_name in self.ignore: return CompressedTensorsUnquantized() @@ -106,8 +108,11 @@ def get_scheme(self, layer: torch.nn.Module, if target.lower() in layer_name_class: layer_type_name = target break + if layer_type_name is None: + raise ValueError(f"Could not matching target for layer {layer}") - layer_quant_details = self.layer_quant_details.get(layer_type_name) + layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( + layer_type_name, None) if layer_quant_details is None: raise ValueError( f"Could not find quantization details for {layer_name}.") @@ -123,10 +128,13 @@ class CompressedTensorsLinearMethod(LinearMethodBase): def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs): """ @@ -146,11 +154,11 @@ def create_weights(self, layer: torch.nn.Module, weight_loader=weight_loader) layer.scheme = scheme - + def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None): + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): """ Use the output of create_weights and the CompressedTensorsScheme associated with the layer to apply the forward pass with the layer input. 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py index 1b728865641d4..1766aed1d6925 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -65,15 +65,14 @@ def cutlass_gemm_dq( f" - o_dq {dq.shape} {dq.dtype} \n") logger.debug(log_str) - plan = cutlass.op.Gemm( - element_A=x_q.dtype, - element_B=w_q.dtype, - element_C=dq.dtype, - element_D=dq.dtype, - layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.ColumnMajor, - layout_C=cutlass.LayoutType.RowMajor, - element_accumulator=torch.int32) + plan = cutlass.op.Gemm(element_A=x_q.dtype, + element_B=w_q.dtype, + element_C=dq.dtype, + element_D=dq.dtype, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.ColumnMajor, + layout_C=cutlass.LayoutType.RowMajor, + element_accumulator=torch.int32) plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales, activation_scales) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 9698e97f91f46..38b810f1c9ab0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -96,8 +96,7 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) # Register parameter with the layer; register weight loader with each parameter set_weight_attrs(weight, {"weight_loader": weight_loader}) - set_weight_attrs(weight, - {"logical_widths": output_partition_sizes}) + set_weight_attrs(weight, {"logical_widths": output_partition_sizes}) layer.register_parameter("input_scale", input_scale) set_weight_attrs(input_scale, {"weight_loader": weight_loader}) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 4a4627f7e8968..6f408b491f1a3 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -71,10 +71,13 @@ class SqueezeLLMLinearMethod(QuantizeMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, + def create_weights(self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, layer_name: Optional[str] = None, **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: From f2f8c5261acb51d378b007482955ef54debaf80f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 30 Apr 2024 20:04:21 +0000 Subject: [PATCH 05/73] fix mypy complaints --- .../compressed_tensors/cutlass_gemm.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py index 1766aed1d6925..b3eccbdf6fec4 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -3,27 +3,31 @@ import cutlass.epilogue import torch -from typing import Optional, Tuple, Dict +from typing import Optional, Tuple, Dict, Union, Any from vllm.logger import init_logger logger = init_logger("cutlass_gemm") +# type alias +TF = Union[torch.Tensor, float] + def setup_dequant_epilogue(plan : cutlass.op.Gemm, - dq: torch.Tensor, - static_scales: Optional[torch.Tensor], - activation_scales: Optional[torch.Tensor]) \ - -> Tuple[cutlass.op.Gemm, Dict]: + dq : torch.Tensor, + static_scales: Optional[TF], + activation_scales: Optional[TF]) \ + -> Tuple[cutlass.op.Gemm, Optional[Dict]]: if all([static_scales is None, activation_scales is None]): return plan, None assert static_scales is not None - def epilog_with_scales_and_act_scales(accum, scales, act_scales): + def epilog_with_scales_and_act_scales(accum: torch.Tensor, scales: TF, + act_scales: TF) -> torch.Tensor: D = accum * scales * act_scales return D - def epilog_with_scales(accum, scales): + def epilog_with_scales(accum: torch.Tensor, scales: TF) -> torch.Tensor: D = accum * scales return D @@ -38,7 +42,7 @@ def epilog_with_scales(accum, scales): 'D': dq, } - epilog_fn = epilog_with_scales + epilog_fn: Any = epilog_with_scales if activation_scales is not None: epilog_tensors['act_scales'] = activation_scales From d9d49b5224dccb16eb28628ed9fb5f95b07437cc Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 21:25:31 +0000 Subject: [PATCH 06/73] format fixes --- vllm/model_executor/layers/linear.py | 10 +++++----- .../layers/quantization/__init__.py | 2 +- .../compressed_tensors/compressed_tensors.py | 15 +++++++-------- .../compressed_tensors/cutlass_gemm.py | 6 +++--- .../compressed_tensors/schemes/__init__.py | 8 +++++--- .../schemes/compressed_tensors_scheme.py | 13 +++++++------ .../schemes/compressed_tensors_unquantized.py | 15 +++++++++------ .../compressed_tensors_w8a8_statictensor.py | 16 +++++++++------- .../layers/quantization/gptq_marlin.py | 1 + 9 files changed, 47 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c6dcd48ac765a..d155b5704d5ab 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn.functional as F @@ -38,7 +38,7 @@ def create_weights(self, output_size: int, params_dtype: torch.dtype, layer_name: Optional[str] = None, - **extra_weight_attrs) -> Dict[str, Any]: + **extra_weight_attrs): """Create weights for a linear layer. The weights will be set as attributes of the layer. @@ -84,7 +84,7 @@ def create_weights(self, output_size: int, params_dtype: torch.dtype, layer_name: Optional[str] = None, - **extra_weight_attrs) -> Dict[str, Any]: + **extra_weight_attrs): weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), @@ -413,7 +413,7 @@ def weight_loader(self, param_data = param_data.narrow(0, shard_offset, shard_size) # If a param_shard_splitter is defined by the LinearMethod, use it. 
elif param_shard_splitter is not None: - logical_widths = getattr(param, "logical_widths") + logical_widths = getattr(param, "logical_widths", None) param_data, loaded_weight = param_shard_splitter( param_data, loaded_weight, loaded_shard_id, logical_widths) @@ -601,7 +601,7 @@ def weight_loader(self, shard_size) # If a param_shard_splitter is defined by the LinearMethod, use it. elif param_shard_splitter is not None: - logical_widths = getattr(param, "logical_widths") + logical_widths = getattr(param, "logical_widths", None) param_data, loaded_weight = param_shard_splitter( param_data, loaded_weight, loaded_shard_id, logical_widths) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 14052dc725830..73fd41d7656ec 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,9 +4,9 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 5b6001d79732b..599cce689d656 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -5,11 +5,9 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) - from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsW8A8StaticTensor, CompressedTensorsUnquantized, - CompressedTensorsScheme) -from vllm.model_executor.utils import set_weight_attrs + CompressedTensorsScheme, CompressedTensorsUnquantized, + CompressedTensorsW8A8StaticTensor) class CompressedTensorsConfig(QuantizationConfig): @@ -138,8 +136,8 @@ def create_weights(self, layer_name: Optional[str] = None, **extra_weight_attrs): """ - Use the CompressedTensorsScheme associated with each layer to create the - necessary parameters for the layer. + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. """ weight_loader = extra_weight_attrs.get("weight_loader") @@ -160,8 +158,9 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None): """ - Use the output of create_weights and the CompressedTensorsScheme associated with - the layer to apply the forward pass with the layer input. + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. 
""" if bias is not None: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py index b3eccbdf6fec4..72720a934227d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -1,9 +1,9 @@ +from typing import Any, Dict, Optional, Tuple, Union + import cutlass -from cutlass import Tensor as FakeTensor import cutlass.epilogue - import torch -from typing import Optional, Tuple, Dict, Union, Any +from cutlass import Tensor as FakeTensor from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 5a32069d71e26..831905b63e2c9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -1,3 +1,5 @@ -from .compressed_tensors_scheme import CompressedTensorsScheme -from .compressed_tensors_unquantized import CompressedTensorsUnquantized -from .compressed_tensors_w8a8_statictensor import CompressedTensorsW8A8StaticTensor \ No newline at end of file +from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 +from .compressed_tensors_unquantized import ( # noqa: F401 + CompressedTensorsUnquantized) +from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 + CompressedTensorsW8A8StaticTensor) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index 1873cba9b6815..3a5904208656e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod + import torch __all__ = ["CompressedTensorsScheme"] @@ -6,8 +7,8 @@ class CompressedTensorsScheme(ABC): """ - Abstract class used to describe the weight creation and forward pass of different - quantization schemes supported by CompressedTensors. + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. """ @abstractmethod @@ -21,11 +22,11 @@ def create_weights(self, *args, **kwargs): @abstractmethod def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): """ - Run the forward pass for the particular scheme. This is where scheme-specific - dequant/quant steps/kernels should be applied. + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. - :param layer: toch.nn.Module with the registered weights and other parameters - relevant to the particular scheme. + :param layer: toch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. 
:param x: input to the layer """ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py index d5b582f6176a0..0cfac13d1ca25 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -1,18 +1,21 @@ -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) +from typing import Callable, List + import torch -from typing import List, Callable +import torch.nn.functional as F from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs -import torch.nn.functional as F __all__ = ["CompressedTensorsUnquantized"] class CompressedTensorsUnquantized(CompressedTensorsScheme): """ - Implements the scheme for all layers which are ignored in the CompressedTensors - config. The input and loaded weight are used in a linear transformation. + Implements the scheme for all layers which are ignored + in the CompressedTensors config. The input and loaded weight are used + in a linear transformation. """ def create_weights(self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 38b810f1c9ab0..03252882d2ed4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -1,12 +1,14 @@ +from typing import Callable, List, Tuple, Union + import torch -from typing import List, Union, Tuple, Callable -from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( +from torch.nn import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 cutlass_gemm_dq) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs -from torch.nn import Parameter -from vllm._C import ops __all__ = ["CompressedTensorsW8A8StaticTensor"] @@ -94,7 +96,7 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight", weight) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - # Register parameter with the layer; register weight loader with each parameter + set_weight_attrs(weight, {"weight_loader": weight_loader}) set_weight_attrs(weight, {"logical_widths": output_partition_sizes}) @@ -122,8 +124,8 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): x_q = self._quantize_single(x, act_scale[0].item()) # Weight quantize - # TODO : try not to remove device-to-host copy. i.e. keep the non-duplicated version - # of scales in the CPU + # TODO : try not to remove device-to-host copy. + # i.e. 
keep the non-duplicated version of scales in the CPU if self.fake_quant: w_scales = [ weight_scale[sum(logical_widths[:i])].item() diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index efbffa0878c4b..07e57302d9a84 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -206,6 +206,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ) -> None: del output_size From c31a7af9332cd3cd78e6e7a3bc7c1c66ebc26302 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 1 May 2024 14:20:09 +0000 Subject: [PATCH 07/73] format fix post rebase --- vllm/model_executor/layers/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 767a478acbd6b..4ed9c74fc21b8 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -379,7 +379,7 @@ def weight_loader(self, "We do not currently support loaded_shard_id == None and " "shard_splitter != None for a parameter. Please open an issue." ) - + # Special case for Fp8 scales. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", None) From ca01b39af7fd6675a994100aa081d11025264dee Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 1 May 2024 15:03:44 -0400 Subject: [PATCH 08/73] lazy import CompressedTensorsW8A8StaticTensor (#220) vllm CI fixes --------- Co-authored-by: Varun Sundar Rabindranath --- .../compressed_tensors/int8_quant_kernels.cu | 10 ++++++++++ .../compressed_tensors/compressed_tensors.py | 9 ++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index e1af55dc225a2..75de15e47ef0a 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -6,9 +6,19 @@ static inline __device__ int8_t float_to_int8_rn(float x) { +#ifdef USE_ROCM + float dst; + // Round to nearest even + asm volatile("v_rndne_f32 %0, %1;\n" : "=r"(dst) : "v"(x)); + // Saturate + dst = dst < -128.0f ? -128.0f : dst: + dst = dst > 127.0f ? 
127.0f : dst; + return static_cast(dst); +#else uint32_t dst; asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); return reinterpret_cast(dst); +#endif } namespace vllm { diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 599cce689d656..0e7230e38da7b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -6,8 +6,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, CompressedTensorsUnquantized, - CompressedTensorsW8A8StaticTensor) + CompressedTensorsScheme, CompressedTensorsUnquantized) class CompressedTensorsConfig(QuantizationConfig): @@ -80,7 +79,11 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict): is_tensor = weight_strategy == input_strategy == "tensor" is_symmetric = weight_symmetric and input_symmetric - if is_8_bits and is_tensor and is_symmetric: + if is_8_bits and is_tensor and is_symmetric and \ + torch.cuda.is_available(): + # CompressedTensorsW8A8StaticTensor only supports CUDA path for now. + from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( # noqa: E501 + CompressedTensorsW8A8StaticTensor) return CompressedTensorsW8A8StaticTensor( fake_quant=self.fake_quant) raise NotImplementedError( From f0197d4429136d17c4c12bd0e35c32425da81515 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 1 May 2024 15:50:13 -0400 Subject: [PATCH 09/73] lazy cutlass_gemm_dq import (#221) lazy cutlass_gemm_dq import --------- Co-authored-by: Varun Sundar Rabindranath --- .../compressed_tensors/int8_quant_kernels.cu | 2 +- .../compressed_tensors/compressed_tensors.py | 10 +++++----- .../quantization/compressed_tensors/cutlass_gemm.py | 1 - .../schemes/compressed_tensors_w8a8_statictensor.py | 8 ++++++-- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 75de15e47ef0a..89b4a0d95cb39 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -11,7 +11,7 @@ static inline __device__ int8_t float_to_int8_rn(float x) // Round to nearest even asm volatile("v_rndne_f32 %0, %1;\n" : "=r"(dst) : "v"(x)); // Saturate - dst = dst < -128.0f ? -128.0f : dst: + dst = dst < -128.0f ? -128.0f : dst; dst = dst > 127.0f ? 
127.0f : dst; return static_cast(dst); #else diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 0e7230e38da7b..366e396850525 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -6,7 +6,8 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, CompressedTensorsUnquantized) + CompressedTensorsScheme, CompressedTensorsUnquantized, + CompressedTensorsW8A8StaticTensor) class CompressedTensorsConfig(QuantizationConfig): @@ -80,10 +81,9 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict): is_symmetric = weight_symmetric and input_symmetric if is_8_bits and is_tensor and is_symmetric and \ - torch.cuda.is_available(): - # CompressedTensorsW8A8StaticTensor only supports CUDA path for now. - from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( # noqa: E501 - CompressedTensorsW8A8StaticTensor) + torch.cuda.is_available(): + # CompressedTensorsW8A8StaticTensor only supports CUDA path for + # now. return CompressedTensorsW8A8StaticTensor( fake_quant=self.fake_quant) raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py index 72720a934227d..db9a63937de75 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Optional, Tuple, Union import cutlass -import cutlass.epilogue import torch from cutlass import Tensor as FakeTensor diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 03252882d2ed4..7d07b2a3b79d2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -4,8 +4,6 @@ from torch.nn import Parameter from vllm._C import ops -from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 - cutlass_gemm_dq) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs @@ -115,6 +113,12 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + + # Lazy import so we don't fail on cutlass imports on non-CUDA + # machines. 
+ from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 + cutlass_gemm_dq) + weight = layer.weight weight_scale = layer.weight_scale act_scale = layer.input_scale From 4624b467b946db35753979441357d7e2d87957a6 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 1 May 2024 21:24:49 +0000 Subject: [PATCH 10/73] fix asm --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 89b4a0d95cb39..668d165621db1 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -9,7 +9,7 @@ static inline __device__ int8_t float_to_int8_rn(float x) #ifdef USE_ROCM float dst; // Round to nearest even - asm volatile("v_rndne_f32 %0, %1;\n" : "=r"(dst) : "v"(x)); + asm volatile("v_rndne_f32 %0, %1;" : "=v"(dst) : "v"(x)); // Saturate dst = dst < -128.0f ? -128.0f : dst; dst = dst > 127.0f ? 127.0f : dst; From 75757d54c7143964c0bf2400bd774295331b8f71 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 2 May 2024 01:12:14 +0000 Subject: [PATCH 11/73] update shape change --- vllm/model_executor/layers/linear.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4ed9c74fc21b8..c9efaf284e426 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -460,11 +460,12 @@ def weight_loader(self, "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") - if len(param_data.shape) == 0: - param_data = param_data.reshape(1) + if fp8_scales_shard_indexer is None: + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) - if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.reshape(1) + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -758,7 +759,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data, loaded_weight = fp8_scales_shard_indexer(param_data, loaded_weight, shard_id=0) - if len(loaded_weight.shape) == 0: + + if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) assert param_data.shape == loaded_weight.shape From e1df0eb52f3a8c1fdd3a1a7033969d606687f371 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 2 May 2024 02:44:53 +0000 Subject: [PATCH 12/73] add todo --- .../schemes/compressed_tensors_w8a8_statictensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 7d07b2a3b79d2..a67ed61ff5b46 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -130,7 +130,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): # Weight quantize # TODO : try not to remove device-to-host copy. # i.e. 
keep the non-duplicated version of scales in the CPU - if self.fake_quant: + if self.fake_quant: # TODO: update w_scales = [ weight_scale[sum(logical_widths[:i])].item() for i in range(len(logical_widths)) From bc0991c8a7635f336323baf575f0035821c07c13 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 May 2024 21:12:03 +0000 Subject: [PATCH 13/73] Rename quant_per_tensor -> static_scaled_int8_quant --- csrc/ops.h | 2 +- csrc/pybind.cpp | 7 +------ .../quantization/compressed_tensors/int8_quant_kernels.cu | 8 ++++---- .../schemes/compressed_tensors_w8a8_statictensor.py | 2 +- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index ae0c59a8efde6..b67a14617b3e6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -174,7 +174,7 @@ void dynamic_scaled_fp8_quant( torch::Tensor& input, torch::Tensor& scale); -void quant_per_tensor( +void static_scaled_int8_quant( torch::Tensor& out, torch::Tensor& input, float scale); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index f597a8fc4ef54..6b2b5b7f08b49 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -83,12 +83,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); ops.def( - "quant_per_tensor", - py::overload_cast< - torch::Tensor&, - torch::Tensor&, - float>(&quant_per_tensor), - "Per-tensor Quantization"); + "static_scaled_int8_quant", &static_scaled_int8_quant, "Compute int8 quantized tensor for given scaling factor"); // Cache ops diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 668d165621db1..bffdf70a8565e 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -24,7 +24,7 @@ static inline __device__ int8_t float_to_int8_rn(float x) namespace vllm { template -__global__ void quant_kernel( +__global__ void static_scaled_int8_quant_kernel( const scalar_t* __restrict__ input, int8_t* __restrict__ out, scale_type scale, @@ -39,7 +39,7 @@ __global__ void quant_kernel( } } // namespace vllm -void quant_per_tensor( +void static_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] float scale) { @@ -50,8 +50,8 @@ void quant_per_tensor( dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { - vllm::quant_kernel<<>>( + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + vllm::static_scaled_int8_quant_kernel<<>>( input.data_ptr(), out.data_ptr(), scale, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index a67ed61ff5b46..ff08c959630ae 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -27,7 +27,7 @@ def _quantize(self, x_split = x.split(logical_widths, dim=split_dim) for q, dq, scale in zip(x_q_split, x_split, scales): - ops.quant_per_tensor(q, dq, scale.item()) + ops.static_scaled_int8_quant(q, dq, scale.item()) return x_q From 
74ad650a05f302cd5a5bb51b30619045d20e759d Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 May 2024 21:12:59 +0000 Subject: [PATCH 14/73] Remove cruft --- .../schemes/compressed_tensors_w8a8_statictensor.py | 2 +- vllm/worker/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index ff08c959630ae..fc6f153e66574 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -33,7 +33,7 @@ def _quantize(self, def _quantize_single(self, x: torch.Tensor, scale: float): x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") - ops.quant_per_tensor(x_q, x, scale) + ops.static_scaled_int8_quant(x_q, x, scale) return x_q def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5b6a299123965..0704f5fec54d0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1101,4 +1101,4 @@ def _prepare_fake_inputs( else: prompt_tokens = [0] * seq_len fake_image_input = None - return SequenceData(prompt_tokens), fake_image_input \ No newline at end of file + return SequenceData(prompt_tokens), fake_image_input From cf5600f5ba66305aa8ee0ff62002d9093336834f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 14 May 2024 20:31:19 +0000 Subject: [PATCH 15/73] fixes : typo --- CMakeLists.txt | 1 - vllm/model_executor/layers/quantization/__init__.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 235af2aa136b5..22dbfdc042845 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -168,7 +168,6 @@ set(VLLM_EXT_SRC "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" - "csrc/quantization/fp8/fp8_cuda_kernels.cu" "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 2e8a3c9c9c1d2..543c42232710b 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -23,7 +23,7 @@ "squeezellm": SqueezeLLMConfig, "gptq_marlin": GPTQMarlinConfig, "marlin": MarlinConfig, - "sparseml": CompressedTensorsConfig + "sparseml": CompressedTensorsConfig, "deepspeedfp": DeepSpeedFPConfig } From 169ce7f3ec34384b6e5c8349c49b20c88c4cd19a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 15 May 2024 19:09:05 +0000 Subject: [PATCH 16/73] py-cutlass temporary hack for num_prompts==1 --- .../compressed_tensors_w8a8_statictensor.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index fc6f153e66574..42934a1342470 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -137,6 +137,23 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): ] w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) w_q = self._quantize(weight, w_scales, logical_widths) - # GEMM and dq + else: + w_q = weight # already quantized + + # Gemm and dq + # TODO (varun) : cutlass epilogues are interpreted differently when + # there is only a single prompt. + # Consider the case, M, N {*, 2560} and weight_scale is a vector of + # 2560 elements. + # differently when, + # - When M > 1 - The result is correctly multiplied column wise. + # - When M == 1, the result is not a columnwise multiplication. + # + # The following if-statement is a temporary hack and will no + # longer be needed when https://github.com/vllm-project/vllm/pull/4749 + # (CPP cutlass kernels) land. + num_prompts = x_q.shape[0] + if num_prompts == 1: + return cutlass_gemm_dq(x_q, w_q, x.dtype, None) * weight_scale * act_scale + else: return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, act_scale) - return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, act_scale) From 03b53e7fba568b59cbcb3e4c8d171b1ed9c0567f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 15 May 2024 19:13:34 +0000 Subject: [PATCH 17/73] yapf --- .../schemes/compressed_tensors_w8a8_statictensor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 42934a1342470..076ea55ca65e1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -138,14 +138,14 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) w_q = self._quantize(weight, w_scales, logical_widths) else: - w_q = weight # already quantized + w_q = weight # already quantized # Gemm and dq # TODO (varun) : cutlass epilogues are interpreted differently when - # there is only a single prompt. + # there is only a single prompt. # Consider the case, M, N {*, 2560} and weight_scale is a vector of # 2560 elements. - # differently when, + # differently when, # - When M > 1 - The result is correctly multiplied column wise. # - When M == 1, the result is not a columnwise multiplication. # @@ -154,6 +154,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): # (CPP cutlass kernels) land. 
num_prompts = x_q.shape[0] if num_prompts == 1: - return cutlass_gemm_dq(x_q, w_q, x.dtype, None) * weight_scale * act_scale + return cutlass_gemm_dq(x_q, w_q, x.dtype, + None) * weight_scale * act_scale else: return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, act_scale) From f9df31b3bba3aaab95db3e941ffdac6ded83f18c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 16 May 2024 16:11:50 +0000 Subject: [PATCH 18/73] add test_int8_quant --- tests/kernels/test_int8_quant.py | 31 +++++++++++++++++++++++++++++ vllm/model_executor/models/llama.py | 1 - 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 tests/kernels/test_int8_quant.py diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py new file mode 100644 index 0000000000000..b9aa00ce13f56 --- /dev/null +++ b/tests/kernels/test_int8_quant.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from vllm._C import ops + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int, scale: float) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + out1 = (x / scale).round().clamp( + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + ops.static_scaled_int8_quant(out2, x, scale) + assert torch.allclose(out1, out2, + atol=1) # big atol to account for rounding errors diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4065490a98402..106235cd0788b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -288,7 +288,6 @@ def __init__( cache_config=cache_config, quant_config=quant_config) for idx in range(config.num_hidden_layers) - ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From ba4b6b3ab707a94bdaabc22171429ae45c78ad2b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 17 May 2024 15:35:41 +0000 Subject: [PATCH 19/73] call cpp cutlass --- .../compressed_tensors/compressed_tensors.py | 2 +- .../compressed_tensors_w8a8_statictensor.py | 28 +++---------------- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 366e396850525..46212fad13571 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -25,7 +25,7 @@ def get_scaled_act_names(self) -> List[str]: return [] def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.float32, torch.int8] + return [torch.float32, torch.float16, torch.int8] # Need to figure it out def get_min_capability(self) -> int: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 076ea55ca65e1..75ba410772722 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -3,7 +3,10 @@ import torch from torch.nn import Parameter +# TODO (varun) : Unify ops and custom ops from vllm._C import ops +from vllm import _custom_ops as custom_ops + from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs @@ -113,12 +116,6 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): - - # Lazy import so we don't fail on cutlass imports on non-CUDA - # machines. - from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 - cutlass_gemm_dq) - weight = layer.weight weight_scale = layer.weight_scale act_scale = layer.input_scale @@ -140,21 +137,4 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): else: w_q = weight # already quantized - # Gemm and dq - # TODO (varun) : cutlass epilogues are interpreted differently when - # there is only a single prompt. - # Consider the case, M, N {*, 2560} and weight_scale is a vector of - # 2560 elements. - # differently when, - # - When M > 1 - The result is correctly multiplied column wise. - # - When M == 1, the result is not a columnwise multiplication. - # - # The following if-statement is a temporary hack and will no - # longer be needed when https://github.com/vllm-project/vllm/pull/4749 - # (CPP cutlass kernels) land. 
- num_prompts = x_q.shape[0] - if num_prompts == 1: - return cutlass_gemm_dq(x_q, w_q, x.dtype, - None) * weight_scale * act_scale - else: - return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, act_scale) + return custom_ops.cutlass_scaled_mm_dq(x_q, w_q.t(), act_scale, weight_scale, x.dtype) From b27f31a0f0c1f43c51f6c7ac252d338bb2d2519a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 17 May 2024 15:42:26 +0000 Subject: [PATCH 20/73] remove cutlass py interface --- requirements-cuda.txt | 1 - .../compressed_tensors/cutlass_gemm.py | 93 ------------------- 2 files changed, 94 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 7201d061ca2c2..ba8c614d205d2 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,5 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 -nvidia-cutlass == 3.5.0 vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py b/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py deleted file mode 100644 index db9a63937de75..0000000000000 --- a/vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Any, Dict, Optional, Tuple, Union - -import cutlass -import torch -from cutlass import Tensor as FakeTensor - -from vllm.logger import init_logger - -logger = init_logger("cutlass_gemm") - -# type alias -TF = Union[torch.Tensor, float] - -def setup_dequant_epilogue(plan : cutlass.op.Gemm, - dq : torch.Tensor, - static_scales: Optional[TF], - activation_scales: Optional[TF]) \ - -> Tuple[cutlass.op.Gemm, Optional[Dict]]: - - if all([static_scales is None, activation_scales is None]): - return plan, None - assert static_scales is not None - - def epilog_with_scales_and_act_scales(accum: torch.Tensor, scales: TF, - act_scales: TF) -> torch.Tensor: - D = accum * scales * act_scales - return D - - def epilog_with_scales(accum: torch.Tensor, scales: TF) -> torch.Tensor: - D = accum * scales - return D - - epilog_tensors = {'scales': static_scales, 'D': dq} - epilogue_trace_tensors = { - "accum": - FakeTensor(element=torch.int32, - shape=dq.shape, - layout_tag=cutlass.LayoutType.RowMajor), - 'scales': - static_scales, - 'D': - dq, - } - epilog_fn: Any = epilog_with_scales - - if activation_scales is not None: - epilog_tensors['act_scales'] = activation_scales - epilogue_trace_tensors['act_scales'] = activation_scales - epilog_fn = epilog_with_scales_and_act_scales - - plan.epilogue_visitor = cutlass.epilogue.trace(epilog_fn, - epilogue_trace_tensors) - return plan, epilog_tensors - - -def cutlass_gemm_dq( - x_q: torch.Tensor, - w_q: torch.Tensor, - dtype: torch.dtype, - static_scales: torch.Tensor, - activation_scales: Optional[torch.Tensor] = None) -> torch.Tensor: - - dq = torch.empty((x_q.shape[0], w_q.shape[0]), dtype=dtype, device="cuda") - - log_str = (f"cutlass_gemm_dq: \n" - f" - x_q {x_q.shape} {x_q.dtype} \n" - f" - w_q {w_q.shape} {w_q.dtype} \n" - f" - o_dq {dq.shape} {dq.dtype} \n") - logger.debug(log_str) - - plan = cutlass.op.Gemm(element_A=x_q.dtype, - element_B=w_q.dtype, - element_C=dq.dtype, - element_D=dq.dtype, - layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.ColumnMajor, - 
layout_C=cutlass.LayoutType.RowMajor, - element_accumulator=torch.int32) - - plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales, - activation_scales) - - plan.run(x_q, - w_q.t(), - dq, - dq, - alpha=1, - beta=0, - visitor_args=visitor_args, - print_module=False) - - dq = dq.view(*x_q.shape[:-1], -1) - return dq From b589cdd0e833d9b41ee8a8feb585bed3853a621e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 17 May 2024 15:48:16 +0000 Subject: [PATCH 21/73] format.sh --- .../schemes/compressed_tensors_w8a8_statictensor.py | 6 +++--- vllm/model_executor/layers/quantization/deepspeedfp.py | 1 + vllm/model_executor/layers/quantization/gptq_marlin_24.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 75ba410772722..3e1c0523427de 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -4,9 +4,8 @@ from torch.nn import Parameter # TODO (varun) : Unify ops and custom ops -from vllm._C import ops from vllm import _custom_ops as custom_ops - +from vllm._C import ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs @@ -137,4 +136,5 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): else: w_q = weight # already quantized - return custom_ops.cutlass_scaled_mm_dq(x_q, w_q.t(), act_scale, weight_scale, x.dtype) + return custom_ops.cutlass_scaled_mm_dq(x_q, w_q.t(), act_scale, + weight_scale, x.dtype) diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 31cdffbcf0ab9..3031ef24221d9 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -95,6 +95,7 @@ def create_weights(self, input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, weight_loader=None, **extra_weight_attrs): del output_size diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 1bd6127104654..0ef8c97837e16 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -127,6 +127,7 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, + layer_name: Optional[str] = None, **extra_weight_attrs, ): del output_size # Unused. 
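Note: for the W8A8 static-tensor path above, the quantize -> int8 GEMM -> dequantize pipeline (`static_scaled_int8_quant` followed by `cutlass_scaled_mm_dq`) can be sketched with a torch-only reference as below. The `_ref` helper names are hypothetical, `weight_scale` is treated as a single per-tensor float (the fused linear layers keep one scale per logical shard, per the `logical_widths`/`shard_splitter` attributes), and float64 accumulation stands in for the kernel's int32 accumulator.

```python
import torch


def static_scaled_int8_quant_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Mirrors the CUDA kernel: divide by the static activation scale,
    # round to nearest even (torch.round is round-half-to-even) and
    # saturate to the int8 range.
    int8_min = torch.iinfo(torch.int8).min
    int8_max = torch.iinfo(torch.int8).max
    return (x / scale).round().clamp(int8_min, int8_max).to(torch.int8)


def w8a8_static_tensor_ref(x: torch.Tensor, w_q: torch.Tensor,
                           weight_scale: float,
                           act_scale: float) -> torch.Tensor:
    # x: [num_tokens, hidden_size] activations, w_q: [out, in] int8 weights.
    # The real path accumulates in int32 on the GPU; float64 keeps this
    # emulation exact for a small correctness check.
    x_q = static_scaled_int8_quant_ref(x, act_scale)
    accum = x_q.to(torch.float64) @ w_q.to(torch.float64).t()
    return (accum * act_scale * weight_scale).to(x.dtype)
```

The quantize step here is the same check that the new `tests/kernels/test_int8_quant.py` performs against `ops.static_scaled_int8_quant`.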
From 98159cf6ca99f524a4349cf0c60bd0245a774b9b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 17 May 2024 16:26:18 +0000 Subject: [PATCH 22/73] remove fake-quant --- .../compressed_tensors/compressed_tensors.py | 14 +++-------- .../compressed_tensors_w8a8_statictensor.py | 25 +++---------------- 2 files changed, 8 insertions(+), 31 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 46212fad13571..3dcfee8462e1b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -12,9 +12,7 @@ class CompressedTensorsConfig(QuantizationConfig): - def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str], - fake_quant: bool): - self.fake_quant = fake_quant + def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): self.ignore = ignore self.layer_quant_details = layer_quant_details @@ -25,7 +23,7 @@ def get_scaled_act_names(self) -> List[str]: return [] def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.float32, torch.float16, torch.int8] + return [torch.float16, torch.int8] # Need to figure it out def get_min_capability(self) -> int: @@ -45,7 +43,6 @@ def get_quant_method( def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": layer_quant_details: Dict[str, Any] = dict() ignore: List[str] = config.get("ignore", None) - fake_quant: bool = config.get("format") == "fakequant" for key, quant_config in config["config_groups"].items(): targets = quant_config.get("targets") @@ -56,9 +53,7 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": layer_quant_details[target]["input"] = quant_config.get( "input_activations") - return cls(layer_quant_details=layer_quant_details, - ignore=ignore, - fake_quant=fake_quant) + return cls(layer_quant_details=layer_quant_details, ignore=ignore) @classmethod def get_config_filenames(cls) -> List[str]: @@ -84,8 +79,7 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict): torch.cuda.is_available(): # CompressedTensorsW8A8StaticTensor only supports CUDA path for # now. - return CompressedTensorsW8A8StaticTensor( - fake_quant=self.fake_quant) + return CompressedTensorsW8A8StaticTensor() raise NotImplementedError( "Scheme not supported. 
Only 8-bit static symmtetric " "per tensor quantization is currently supported") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 3e1c0523427de..1bfa98b760dd6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -15,8 +15,8 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): - def __init__(self, fake_quant): - self.fake_quant = fake_quant + def __init__(self): + pass def _quantize(self, x: torch.Tensor, @@ -86,19 +86,16 @@ def create_weights(self, layer: torch.nn.Module, dtype=torch.int8), requires_grad=False) - if not self.fake_quant: - params_dtype = torch.int8 weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, device="cuda", - dtype=params_dtype), + dtype=torch.int8), requires_grad=False) layer.register_parameter("weight", weight) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, {"weight_loader": weight_loader}) - set_weight_attrs(weight, {"logical_widths": output_partition_sizes}) layer.register_parameter("input_scale", input_scale) set_weight_attrs(input_scale, {"weight_loader": weight_loader}) @@ -118,23 +115,9 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): weight = layer.weight weight_scale = layer.weight_scale act_scale = layer.input_scale - logical_widths = weight.logical_widths # Input quantize x_q = self._quantize_single(x, act_scale[0].item()) - # Weight quantize - # TODO : try not to remove device-to-host copy. - # i.e. keep the non-duplicated version of scales in the CPU - if self.fake_quant: # TODO: update - w_scales = [ - weight_scale[sum(logical_widths[:i])].item() - for i in range(len(logical_widths)) - ] - w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) - w_q = self._quantize(weight, w_scales, logical_widths) - else: - w_q = weight # already quantized - - return custom_ops.cutlass_scaled_mm_dq(x_q, w_q.t(), act_scale, + return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale, x.dtype) From 8dbeb3142326f92a4090ad43d5a415dcfd34d6df Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 17 May 2024 20:15:59 +0000 Subject: [PATCH 23/73] add compressed tensors test --- tests/quantization/test_compressed_tensors.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/quantization/test_compressed_tensors.py diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py new file mode 100644 index 0000000000000..df9a215a37694 --- /dev/null +++ b/tests/quantization/test_compressed_tensors.py @@ -0,0 +1,39 @@ +"""Test model set-up and weight loading for sparseml-quantized models. + +Run `pytest tests/quantization/test_compressed_tensors.py`. 
+""" + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor, + CompressedTensorsUnquantized) + + +def test_compressed_tensors_w8a8_static_setup(vllm_runner): + model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed" + llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True) + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) + + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) + assert isinstance(down_proj.scheme, CompressedTensorsUnquantized) + + assert qkv_proj.weight.dtype is torch.int8 + assert o_proj.weight.dtype is torch.int8 + assert gate_up_proj.weight.dtype is torch.int8 + assert down_proj.weight.dtype is torch.float16 + + assert qkv_proj.weight_scale.shard_splitter is not None + assert qkv_proj.weight_scale.logical_widths is not None + assert qkv_proj.input_scale.dtype is torch.float32 From 5eeb40a30028e55bb5980e6fdb8e03d78aba7763 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 17 May 2024 20:18:39 +0000 Subject: [PATCH 24/73] remove torch.int8 --- .../quantization/compressed_tensors/compressed_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 3dcfee8462e1b..6e03e190f6a68 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -23,7 +23,7 @@ def get_scaled_act_names(self) -> List[str]: return [] def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.float16, torch.int8] + return [torch.float16] # Need to figure it out def get_min_capability(self) -> int: From c55e023a6b224816b2d0ecd044f8a3f44c6b6a8a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 17 May 2024 20:24:11 +0000 Subject: [PATCH 25/73] format --- tests/quantization/test_compressed_tensors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index df9a215a37694..c99b171b5128a 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -6,8 +6,8 @@ import torch from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor, - CompressedTensorsUnquantized) + CompressedTensorsLinearMethod, CompressedTensorsUnquantized, + CompressedTensorsW8A8StaticTensor) def test_compressed_tensors_w8a8_static_setup(vllm_runner): From f5cbbd371d119bde39e10dde4bc1dd0bae15f69d Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 20 May 2024 18:42:22 +0000 Subject: [PATCH 26/73] fix config parsing to match new model --- 
.../quantization/compressed_tensors/compressed_tensors.py | 1 + vllm/model_executor/model_loader/weight_utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 6e03e190f6a68..fd69d4624ac2a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -42,6 +42,7 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": layer_quant_details: Dict[str, Any] = dict() + config = config["compression_config"]["quantization_config"] ignore: List[str] = config.get("ignore", None) for key, quant_config in config["config_groups"].items(): diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c1abde9af7701..21e8978fadb79 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -148,7 +148,7 @@ def get_quant_config(model_config: ModelConfig, quant_config_files = [ f for f in config_files if any( - f.endswith(x) for x in possible_config_filenames) + f.split("/")[-1] == x for x in possible_config_filenames) ] if len(quant_config_files) == 0: raise ValueError( From a685957d33bad22c506f1c08f6b4b601d57494c5 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 20 May 2024 19:39:57 +0000 Subject: [PATCH 27/73] revert parsing to use default pathway --- .../compressed_tensors/compressed_tensors.py | 3 +-- vllm/model_executor/model_loader/weight_utils.py | 9 ++++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index fd69d4624ac2a..5110ed277faae 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -42,7 +42,6 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": layer_quant_details: Dict[str, Any] = dict() - config = config["compression_config"]["quantization_config"] ignore: List[str] = config.get("ignore", None) for key, quant_config in config["config_groups"].items(): @@ -58,7 +57,7 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": @classmethod def get_config_filenames(cls) -> List[str]: - return ["config.json"] + return [] def _get_schema(self, weight_quant: Dict, input_quant: Dict): # TODO: Refactor as additional cases are supported diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 21e8978fadb79..ff828c19086e1 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -120,6 +120,13 @@ def get_quant_config(model_config: ModelConfig, # Read the quantization config from the HF model config, if available. 
hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) + if hf_quant_config is None: + compression_config = getattr(model_config.hf_config, + "compression_config", None) + if compression_config is not None: + hf_quant_config = compression_config.get("quantization_config", + None) + if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) model_name_or_path = model_config.model @@ -148,7 +155,7 @@ def get_quant_config(model_config: ModelConfig, quant_config_files = [ f for f in config_files if any( - f.split("/")[-1] == x for x in possible_config_filenames) + f.endswith(x) for x in possible_config_filenames) ] if len(quant_config_files) == 0: raise ValueError( From 4dfb37fa606088914849b70dd1534a421cbca596 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 21 May 2024 16:07:18 +0000 Subject: [PATCH 28/73] PR comments --- .../quantization/compressed_tensors/compressed_tensors.py | 8 +++----- .../schemes/compressed_tensors_w8a8_statictensor.py | 3 --- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 5110ed277faae..e6c028082d3db 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -111,11 +111,9 @@ def get_scheme( if layer_quant_details is None: raise ValueError( f"Could not find quantization details for {layer_name}.") - try: - return self._get_schema(weight_quant=layer_quant_details["weight"], - input_quant=layer_quant_details["input"]) - except NotImplementedError as e: - raise e + + return self._get_schema(weight_quant=layer_quant_details["weight"], + input_quant=layer_quant_details["input"]) class CompressedTensorsLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 1bfa98b760dd6..bc9b68c5c704b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -15,9 +15,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): - def __init__(self): - pass - def _quantize(self, x: torch.Tensor, scales: torch.Tensor, From de81f9e71a3a41649f4deb1509a069fda0d4e0b0 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 21 May 2024 17:14:41 +0000 Subject: [PATCH 29/73] Fix scales/zero-points device allocation --- .../compressed_tensors/compressed_tensors.py | 2 +- .../compressed_tensors_w8a8_statictensor.py | 51 ++++++++++--------- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e6c028082d3db..5c79f4885fee8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -81,7 +81,7 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict): # now. 
return CompressedTensorsW8A8StaticTensor() raise NotImplementedError( - "Scheme not supported. Only 8-bit static symmtetric " + "Scheme not supported. Only CUDA, 8-bit static symmtetric " "per tensor quantization is currently supported") def get_scheme( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index bc9b68c5c704b..3805557743078 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -15,23 +15,8 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): - def _quantize(self, - x: torch.Tensor, - scales: torch.Tensor, - logical_widths: List[int], - split_dim: int = 0) -> torch.Tensor: - - x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") - x_q_split = x_q.split(logical_widths, dim=split_dim) - x_split = x.split(logical_widths, dim=split_dim) - - for q, dq, scale in zip(x_q_split, x_split, scales): - ops.static_scaled_int8_quant(q, dq, scale.item()) - - return x_q - - def _quantize_single(self, x: torch.Tensor, scale: float): - x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") + def _quantize(self, x: torch.Tensor, scale: float): + x_q = torch.empty_like(x, dtype=torch.int8) ops.static_scaled_int8_quant(x_q, x, scale) return x_q @@ -62,30 +47,46 @@ def create_weights(self, layer: torch.nn.Module, **kwargs): # TODO: remove zero_point parameters once the configs given remove them + + # Note on input/weight scales and zero_points + # + # When the scales have a single value, it is required that they be + # on the CPU for 2 reasons, + # 1. Performance: + # The cutlass interface looks at the shape of the scales and if the + # scales have a single value, it does a .item() on the tensor + # and does a scalar multiply in the epilogue. `.item()` will trigger + # a GPU-CPU copy if the tensor is on the GPU. + # 2. CUDA Graphs: + # CUDA Graphs don't support `.item()` calls on a GPU tensor. + # + # TODO: zero-points are not supported yet. But we expect a similar pattern. 
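A small aside to illustrate the comment block above (not part of the patch): `.item()` on a GPU tensor forces a device-to-host copy plus a synchronization, which costs latency on every call and is rejected during CUDA graph capture, so a single-element scale is cheaper to keep on the host. Minimal sketch, assuming a CUDA device is available for the second half:

```python
import torch

scale_cpu = torch.tensor([0.02])              # host tensor: .item() is a plain read
print(scale_cpu.item())

if torch.cuda.is_available():
    scale_gpu = torch.tensor([0.02], device="cuda")
    # This line performs a GPU->CPU copy and synchronizes the stream; the same
    # call made while a CUDA graph is being captured fails with a
    # stream-capture error, which is why the scale is allocated on the CPU.
    print(scale_gpu.item())
```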
+ is_tensor_partitioned = len(output_partition_sizes) != 1 - dim = sum(output_partition_sizes) if is_tensor_partitioned else 1 + weight_scale_dim = sum( + output_partition_sizes) if is_tensor_partitioned else 1 + weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda" input_scale = Parameter(torch.empty(1, - device="cuda", + device="cpu", dtype=torch.float32), requires_grad=False) input_zero_point = Parameter(torch.empty(1, - device="cuda", + device="cpu", dtype=torch.int8), requires_grad=False) - weight_scale = Parameter(torch.empty(dim, - device="cuda", + weight_scale = Parameter(torch.empty(weight_scale_dim, + device=weight_scale_device, dtype=torch.float32), requires_grad=False) weight_zero_point = Parameter(torch.empty(1, - device="cuda", + device="cpu", dtype=torch.int8), requires_grad=False) weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, - device="cuda", dtype=torch.int8), requires_grad=False) @@ -114,7 +115,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): act_scale = layer.input_scale # Input quantize - x_q = self._quantize_single(x, act_scale[0].item()) + x_q = self._quantize(x, act_scale[0].item()) return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale, x.dtype) From 15f1863acdc51b87bdbe557855ef147ce1a896e8 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 21 May 2024 18:30:35 +0000 Subject: [PATCH 30/73] ruff --- .../schemes/compressed_tensors_w8a8_statictensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 3805557743078..c66b1309d8e76 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -60,7 +60,8 @@ def create_weights(self, layer: torch.nn.Module, # 2. CUDA Graphs: # CUDA Graphs don't support `.item()` calls on a GPU tensor. # - # TODO: zero-points are not supported yet. But we expect a similar pattern. + # TODO: zero-points are not supported yet. But we expect a similar + # pattern. is_tensor_partitioned = len(output_partition_sizes) != 1 weight_scale_dim = sum( From bd538472a2cf30dc5e4816c2b6ac6f7ebd53d4f9 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 21 May 2024 19:31:55 +0000 Subject: [PATCH 31/73] add better comments --- .../compressed_tensors_w8a8_statictensor.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index c66b1309d8e76..57a0091893618 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -53,12 +53,16 @@ def create_weights(self, layer: torch.nn.Module, # When the scales have a single value, it is required that they be # on the CPU for 2 reasons, # 1. 
Performance: - # The cutlass interface looks at the shape of the scales and if the - # scales have a single value, it does a .item() on the tensor - # and does a scalar multiply in the epilogue. `.item()` will trigger - # a GPU-CPU copy if the tensor is on the GPU. + # When the scales (input_scale/weight_scales) have only a single + # value, we perform a scalar broadcast of that value during the + # quant/dequant operations. The "quant" and the "gemm+dequant" + # kernels accept the Scalar by-value. These tensors are allocated + # on the CPU in order to avoid the GPU-to-CPU copy when passing + # by-value. + # # 2. CUDA Graphs: - # CUDA Graphs don't support `.item()` calls on a GPU tensor. + # CUDA Graphs don't support GPU-to-CPU copy operations during + # stream capture. # # TODO: zero-points are not supported yet. But we expect a similar # pattern. From b2926f3ea6bc043f2858789045c1d0673017ceb3 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 22 May 2024 14:28:11 +0000 Subject: [PATCH 32/73] add comment --- vllm/model_executor/layers/linear.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c9efaf284e426..1e3092e07087d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -374,6 +374,12 @@ def weight_loader(self, "We do not currently support output_dim != None and " "shard_splitter != None for a parameter. Please open an issue." ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. if loaded_shard_id is None and param_shard_splitter is not None: raise NotImplementedError( "We do not currently support loaded_shard_id == None and " @@ -556,6 +562,12 @@ def weight_loader(self, "We do not currently support output_dim != None and " "shard_splitter != None for a parameter. Please open an issue." ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. 
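For readers skimming this hunk, a hedged sketch of the idea behind `shard_splitter` for fused parameters such as the qkv `weight_scale`: each logical shard ("q"/"k"/"v") owns a slice of the flat parameter, and the splitter uses the loaded shard id to pick the slice that should receive the loaded value before the copy happens. The real helper lives in the scheme code and its exact signature is not shown in these patches; the names, widths, and signature below are invented for illustration.

```python
import torch

# Hypothetical stand-ins for the qkv case: one logical width per q/k/v shard.
logical_widths = [4096, 1024, 1024]
shard_id_to_index = {"q": 0, "k": 1, "v": 2}

def split_scale_param(param: torch.Tensor, shard_id: str) -> torch.Tensor:
    """Return the slice of the flat fused parameter that a loaded shard fills."""
    idx = shard_id_to_index[shard_id]
    offset = sum(logical_widths[:idx])
    return param.narrow(0, offset, logical_widths[idx])

weight_scale = torch.zeros(sum(logical_widths))
loaded_k_scale = torch.tensor([0.02])          # per-tensor scale stored for k_proj
split_scale_param(weight_scale, "k").fill_(loaded_k_scale.item())
```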
if loaded_shard_id is None and param_shard_splitter is not None: raise NotImplementedError( "We do not currently support loaded_shard_id == None and " From 18640c85c227299e426cc68362a14cab262abc19 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 22 May 2024 14:40:27 +0000 Subject: [PATCH 33/73] clang format --- .../compressed_tensors/int8_quant_kernels.cu | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index bffdf70a8565e..522efe3d25de7 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -4,20 +4,19 @@ #include "../../dispatch_utils.h" -static inline __device__ int8_t float_to_int8_rn(float x) -{ +static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM - float dst; - // Round to nearest even - asm volatile("v_rndne_f32 %0, %1;" : "=v"(dst) : "v"(x)); - // Saturate - dst = dst < -128.0f ? -128.0f : dst; - dst = dst > 127.0f ? 127.0f : dst; - return static_cast(dst); + float dst; + // Round to nearest even + asm volatile("v_rndne_f32 %0, %1;" : "=v"(dst) : "v"(x)); + // Saturate + dst = dst < -128.0f ? -128.0f : dst; + dst = dst > 127.0f ? 127.0f : dst; + return static_cast(dst); #else - uint32_t dst; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); - return reinterpret_cast(dst); + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); #endif } @@ -25,10 +24,8 @@ namespace vllm { template __global__ void static_scaled_int8_quant_kernel( - const scalar_t* __restrict__ input, - int8_t* __restrict__ out, - scale_type scale, - const int hidden_size) { + const scalar_t* __restrict__ input, int8_t* __restrict__ out, + scale_type scale, const int hidden_size) { const int tid = threadIdx.x; const int token_idx = blockIdx.x; @@ -37,12 +34,11 @@ __global__ void static_scaled_int8_quant_kernel( float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); } } -} // namespace vllm +} // namespace vllm -void static_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - float scale) { +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { assert(input.is_contiguous()); assert(out.is_contiguous()); int hidden_size = input.size(-1); @@ -50,11 +46,11 @@ void static_scaled_int8_quant( dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { - vllm::static_scaled_int8_quant_kernel<<>>( - input.data_ptr(), - out.data_ptr(), - scale, - hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + vllm::static_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), scale, + hidden_size); + }); } From 5c5dc84f2fc9a340d1bd34ff66ee018a25bea6e1 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 22 May 2024 14:42:07 +0000 Subject: [PATCH 34/73] clang format again --- csrc/ops.h | 7 ++----- csrc/pybind.cpp | 5 ++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 92179ee6fba4f..b839eaf0d26c8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -93,11 +93,8 @@ int 
cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, #endif - -void static_scaled_int8_quant( - torch::Tensor& out, - torch::Tensor& input, - float scale); +void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, + float scale); void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 3d9297f91b12c..cdbec4a34d77f 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -67,9 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Aligning the number of tokens to be processed by each expert such " "that it is divisible by the block size."); - ops.def( - "static_scaled_int8_quant", &static_scaled_int8_quant, "Compute int8 quantized tensor for given scaling factor"); - + ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, + "Compute int8 quantized tensor for given scaling factor"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); From a44b4a0f352abde0adc7241392f7ee9df3086ace Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 22 May 2024 20:20:09 +0000 Subject: [PATCH 35/73] address PR comments --- .../compressed_tensors/int8_quant_kernels.cu | 23 +++++++------------ vllm/_custom_ops.py | 18 +++++++++++++++ .../compressed_tensors_w8a8_statictensor.py | 9 +------- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 522efe3d25de7..44c5509e35766 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,23 +1,16 @@ #include #include -#include #include "../../dispatch_utils.h" static inline __device__ int8_t float_to_int8_rn(float x) { -#ifdef USE_ROCM - float dst; - // Round to nearest even - asm volatile("v_rndne_f32 %0, %1;" : "=v"(dst) : "v"(x)); - // Saturate - dst = dst < -128.0f ? -128.0f : dst; - dst = dst > 127.0f ? 127.0f : dst; + static constexpr float dt_min = static_cast(std::numeric_limits::min()); + static constexpr float dt_max = static_cast(std::numeric_limits::max()); + // round + float dst = round(x); + // saturate + dst = std::clamp(dst, dt_min, dt_max); return static_cast(dst); -#else - uint32_t dst; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); - return reinterpret_cast(dst); -#endif } namespace vllm { @@ -39,8 +32,8 @@ __global__ void static_scaled_int8_quant_kernel( void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] float scale) { - assert(input.is_contiguous()); - assert(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9e7d0d96bf004..f0fab4d8aa26d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -251,6 +251,24 @@ def scaled_fp8_quant( return output, scale +# int8 +def static_scaled_int8_quant(input: torch.Tensor, + scale: float) -> torch.Tensor: + """ + Quantize the input tensor to int8 and return the quantized tensor. + + Args: + input: The input tensor to be quantized to int8. + scale: Scaling factor for the int8 quantization. + + Returns: + torch.Tensor: Output tensor in int8. 
+ """ + q = torch.empty_like(input, dtype=torch.int8) + vllm_ops.static_scaled_int8_quant(q, input, scale) + return q + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 57a0091893618..d16e570d12202 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -3,9 +3,7 @@ import torch from torch.nn import Parameter -# TODO (varun) : Unify ops and custom ops from vllm import _custom_ops as custom_ops -from vllm._C import ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs @@ -15,11 +13,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): - def _quantize(self, x: torch.Tensor, scale: float): - x_q = torch.empty_like(x, dtype=torch.int8) - ops.static_scaled_int8_quant(x_q, x, scale) - return x_q - def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: if isinstance(shard_id, int): return shard_id @@ -120,7 +113,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): act_scale = layer.input_scale # Input quantize - x_q = self._quantize(x, act_scale[0].item()) + x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale, x.dtype) From 6f0e6e16e8f437a098d5a71a79b93767b3fb272e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 22 May 2024 20:27:06 +0000 Subject: [PATCH 36/73] clang-format --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 44c5509e35766..bc2f01ce388f6 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -4,8 +4,10 @@ #include "../../dispatch_utils.h" static inline __device__ int8_t float_to_int8_rn(float x) { - static constexpr float dt_min = static_cast(std::numeric_limits::min()); - static constexpr float dt_max = static_cast(std::numeric_limits::max()); + static constexpr float dt_min = + static_cast(std::numeric_limits::min()); + static constexpr float dt_max = + static_cast(std::numeric_limits::max()); // round float dst = round(x); // saturate From 009045448cf80f0e2f72d7842bebeb5701670831 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 23 May 2024 14:27:02 +0000 Subject: [PATCH 37/73] remove layer name --- vllm/model_executor/layers/linear.py | 79 +++++++------------ .../layers/quantization/aqlm.py | 10 +-- .../model_executor/layers/quantization/awq.py | 10 +-- .../compressed_tensors/compressed_tensors.py | 27 ++----- .../layers/quantization/deepspeedfp.py | 1 - .../model_executor/layers/quantization/fp8.py | 1 - .../layers/quantization/gptq.py | 1 - .../layers/quantization/gptq_marlin.py | 1 - .../layers/quantization/gptq_marlin_24.py | 1 - .../layers/quantization/marlin.py | 1 - .../layers/quantization/squeezellm.py | 10 +-- 
vllm/model_executor/models/llama.py | 13 +-- 12 files changed, 46 insertions(+), 109 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1e3092e07087d..3e1d7210ff733 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -30,14 +30,10 @@ class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, - layer: torch.nn.Module, + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_name: Optional[str] = None, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """Create weights for a linear layer. The weights will be set as attributes of the layer. @@ -51,7 +47,6 @@ def create_weights(self, input_size: Size of the input dim of the weight across all ranks. output_size: Size of the output dim of the weight across all ranks. params_dtype: Datatype of the parameters. - layer_name: name of the layer in the state dict. """ raise NotImplementedError @@ -76,14 +71,10 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, - layer: torch.nn.Module, + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_name: Optional[str] = None, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, @@ -115,20 +106,19 @@ class LinearBase(torch.nn.Module): skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. - layer_name: name of the layer in the state dict. """ - def __init__(self, - input_size: int, - output_size: int, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None): + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() # Keep input parameters - self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.skip_bias_add = skip_bias_add @@ -163,19 +153,15 @@ def __init__(self, bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config, layer_name) + quant_config) # All the linear layer supports quant method. 
assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size, [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - layer_name=self.layer_name) + self.quant_method.create_weights(self, self.input_size, + [self.output_size], self.input_size, + self.output_size, self.params_dtype) if bias: self.bias = Parameter( @@ -218,7 +204,6 @@ class ColumnParallelLinear(LinearBase): quant_config: Quantization configure. output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. - layer_name: name of the layer in the state dict. """ def __init__(self, @@ -229,10 +214,9 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None, - layer_name: Optional[str] = None): + output_sizes: Optional[List[int]] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config, layer_name) + quant_config) self.gather_output = gather_output @@ -252,7 +236,6 @@ def __init__(self, output_sizes = [output_size] self.quant_method.create_weights( layer=self, - layer_name=self.layer_name, input_size_per_partition=self.input_size, output_partition_sizes=self.output_partition_sizes, input_size=self.input_size, @@ -343,13 +326,11 @@ def __init__(self, gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None): + quant_config: Optional[QuantizationConfig] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(layer_name=layer_name, - input_size=input_size, + super().__init__(input_size=input_size, output_size=sum(output_sizes), bias=bias, gather_output=gather_output, @@ -499,7 +480,6 @@ class QKVParallelLinear(ColumnParallelLinear): skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. - layer_name: name of the layer in the state dict. """ def __init__(self, @@ -510,8 +490,7 @@ def __init__(self, bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None): + quant_config: Optional[QuantizationConfig] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -537,8 +516,7 @@ def __init__(self, self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__(layer_name=layer_name, - input_size=input_size, + super().__init__(input_size=input_size, output_size=output_size, bias=bias, gather_output=False, @@ -706,7 +684,6 @@ class RowParallelLinear(LinearBase): We skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. - layer_name: name of the layer in the state dict. 
""" def __init__(self, @@ -717,10 +694,9 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - layer_name: Optional[str] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config, layer_name) + quant_config) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -731,7 +707,6 @@ def __init__(self, assert self.quant_method is not None self.quant_method.create_weights( layer=self, - layer_name=self.layer_name, input_size_per_partition=self.input_size_per_partition, output_partition_sizes=[self.output_size], input_size=self.input_size, diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 1215f818de90d..83e24fadc1405 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -227,14 +227,10 @@ class AQLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: AQLMConfig): self.quant_config = quant_config - def create_weights(self, - layer: torch.nn.Module, + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_name: Optional[str] = None, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): del output_size # Unused. del input_size # Unused. diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 58e3fd0d1d844..f4fc7ce020e95 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -82,14 +82,10 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, - layer: torch.nn.Module, + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_name: Optional[str] = None, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 5c79f4885fee8..d38d066b9ef04 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -84,17 +84,7 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict): "Scheme not supported. 
Only CUDA, 8-bit static symmtetric " "per tensor quantization is currently supported") - def get_scheme( - self, - layer: torch.nn.Module, - layer_name: Optional[str] = None) -> "CompressedTensorsScheme": - - if layer_name is None: - raise ValueError( - "layer_name must be provided for CompressedTensorsConfig") - - if layer_name in self.ignore: - return CompressedTensorsUnquantized() + def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": # TODO: update with matching function from `compressed_tensors` layer_type_name = None @@ -110,7 +100,7 @@ def get_scheme( layer_type_name, None) if layer_quant_details is None: raise ValueError( - f"Could not find quantization details for {layer_name}.") + f"Could not find quantization details for {layer}.") return self._get_schema(weight_quant=layer_quant_details["weight"], input_quant=layer_quant_details["input"]) @@ -121,14 +111,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase): def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config - def create_weights(self, - layer: torch.nn.Module, + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_name: Optional[str] = None, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """ Use the CompressedTensorsScheme associated with each layer to create @@ -136,8 +122,7 @@ def create_weights(self, """ weight_loader = extra_weight_attrs.get("weight_loader") - scheme = self.quantization_config.get_scheme(layer=layer, - layer_name=layer_name) + scheme = self.quantization_config.get_scheme(layer=layer) scheme.create_weights( layer=layer, input_size_per_partition=input_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 3031ef24221d9..31cdffbcf0ab9 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -95,7 +95,6 @@ def create_weights(self, input_size: int, output_size: int, params_dtype: torch.dtype, - layer_name: Optional[str] = None, weight_loader=None, **extra_weight_attrs): del output_size diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e4e39c00f5f9e..ff996741c1d00 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -114,7 +114,6 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, - layer_name: Optional[str] = None, **extra_weight_attrs, ): del input_size, output_size diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index b58db2ae7e7f7..ae9f7019f0592 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -98,7 +98,6 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, - layer_name: Optional[str] = None, **extra_weight_attrs, ): del output_size # Unused. 
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 4b167b0877084..4374fd98012f6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -198,7 +198,6 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, - layer_name: Optional[str] = None, **extra_weight_attrs, ) -> None: del output_size diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 68d73d4008b6c..f5345c0443029 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -138,7 +138,6 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, - layer_name: Optional[str] = None, **extra_weight_attrs, ): del output_size # Unused. diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 6594ecab45d97..3613c9d9ecf2a 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -122,7 +122,6 @@ def create_weights( input_size: int, output_size: int, params_dtype: torch.dtype, - layer_name: Optional[str] = None, **extra_weight_attrs, ): del output_size # Unused. diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 6f408b491f1a3..207dbcee8afc5 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -71,14 +71,10 @@ class SqueezeLLMLinearMethod(QuantizeMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, - layer: torch.nn.Module, + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_name: Optional[str] = None, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2351910ec10af..e1fef88cdf85f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -63,17 +63,14 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - layer_name=f"{parent_name}.gate_up_proj", input_size=hidden_size, output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config) - self.down_proj = RowParallelLinear( - layer_name=f"{parent_name}.down_proj", - input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config) + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -135,7 +132,6 @@ def __init__( self.kv_scale = 1.0 self.qkv_proj = QKVParallelLinear( - layer_name=f"{parent_name}.qkv_proj", hidden_size=hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, @@ -144,7 +140,6 @@ def __init__( quant_config=quant_config, ) self.o_proj = RowParallelLinear( - layer_name=f"{parent_name}.o_proj", input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=bias, From 4b10fd770e2c052a79ca01f52648a9ed74aa33a5 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 23 May 2024 14:31:04 +0000 Subject: [PATCH 38/73] remove unused import --- .../quantization/compressed_tensors/compressed_tensors.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d38d066b9ef04..19e464bd64325 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -6,8 +6,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, CompressedTensorsUnquantized, - CompressedTensorsW8A8StaticTensor) + CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor) class CompressedTensorsConfig(QuantizationConfig): From 68a59c70e7552530b6bdc0103efa3073b9884921 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 23 May 2024 14:48:24 +0000 Subject: [PATCH 39/73] remove parent name --- vllm/model_executor/models/llama.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e1fef88cdf85f..4b673470962b0 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,7 +54,6 @@ class LlamaMLP(nn.Module): def __init__( self, - parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -87,7 +86,6 @@ class LlamaAttention(nn.Module): def __init__( self, - parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -180,7 +178,6 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, - parent_name: str, config: LlamaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -201,7 +198,6 @@ def __init__( attention_bias = getattr(config, "attention_bias", False) or getattr( config, "bias", False) self.self_attn = LlamaAttention( - parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", @@ -215,7 +211,6 @@ def __init__( cache_config=cache_config, ) self.mlp = LlamaMLP( - parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -278,8 +273,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(parent_name=f"model.layers.{idx}", - config=config, + LlamaDecoderLayer(config=config, cache_config=cache_config, quant_config=quant_config) for idx in range(config.num_hidden_layers) From b0afe676a40577eb4bea126e8755c43364f32a5f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 22 May 2024 22:30:43 +0000 Subject: [PATCH 
40/73] Fix rounding --- .../compressed_tensors/int8_quant_kernels.cu | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index bc2f01ce388f6..6e00de042469e 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,18 +1,25 @@ #include #include +#include #include "../../dispatch_utils.h" static inline __device__ int8_t float_to_int8_rn(float x) { - static constexpr float dt_min = +#ifdef USE_ROCM + static const float i8_min = static_cast(std::numeric_limits::min()); - static constexpr float dt_max = + static const float i8_max = static_cast(std::numeric_limits::max()); // round - float dst = round(x); + float dst = std::nearbyint(x); // saturate - dst = std::clamp(dst, dt_min, dt_max); + dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); +#else + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif } namespace vllm { From 4f4951ec4a2c65797f9d280d158b836bc306a806 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 23 May 2024 15:06:02 +0000 Subject: [PATCH 41/73] comment --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 6e00de042469e..4902e4c23434c 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -16,6 +16,7 @@ static inline __device__ int8_t float_to_int8_rn(float x) { dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); #else + // CUDA path uint32_t dst; asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); return reinterpret_cast(dst); From 869de3f859a4b97f399bcc4fa455aefca585b281 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 23 May 2024 15:26:14 +0000 Subject: [PATCH 42/73] cruft --- vllm/model_executor/layers/linear.py | 90 +++++++++++++++------------- 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3e1d7210ff733..56d91ee0bb2d6 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -147,13 +147,14 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. """ - def __init__(self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -206,15 +207,16 @@ class ColumnParallelLinear(LinearBase): the list would be size 3. 
""" - def __init__(self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None): + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -319,14 +321,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__(self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -482,15 +485,16 @@ class QKVParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__(self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -597,7 +601,6 @@ def weight_loader(self, shard_offset = 0 shard_size = self.num_heads * self.head_size elif loaded_shard_id == "k": - # shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size elif loaded_shard_id == "v": @@ -686,15 +689,16 @@ class RowParallelLinear(LinearBase): quant_config: Quantization configure. 
""" - def __init__(self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) From e68e3917e9de8223648c75f90837813da5c0e9eb Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 23 May 2024 15:28:43 +0000 Subject: [PATCH 43/73] yapf --- vllm/model_executor/layers/linear.py | 89 +++++++++++++--------------- 1 file changed, 42 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 56d91ee0bb2d6..34fbfa8e33ef9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -147,14 +147,13 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -207,16 +206,15 @@ class ColumnParallelLinear(LinearBase): the list would be size 3. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -321,15 +319,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -485,16 +482,15 @@ class QKVParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. 
""" - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -689,16 +685,15 @@ class RowParallelLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) From d77cf5044fa74d419d039497ef17d6c804197356 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 23 May 2024 18:07:16 +0000 Subject: [PATCH 44/73] remove unquantized check --- tests/quantization/test_compressed_tensors.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index c99b171b5128a..b83286992da3d 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -6,8 +6,7 @@ import torch from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod, CompressedTensorsUnquantized, - CompressedTensorsW8A8StaticTensor) + CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor) def test_compressed_tensors_w8a8_static_setup(vllm_runner): @@ -27,12 +26,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner): assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) - assert isinstance(down_proj.scheme, CompressedTensorsUnquantized) assert qkv_proj.weight.dtype is torch.int8 assert o_proj.weight.dtype is torch.int8 assert gate_up_proj.weight.dtype is torch.int8 - assert down_proj.weight.dtype is torch.float16 assert qkv_proj.weight_scale.shard_splitter is not None assert qkv_proj.weight_scale.logical_widths is not None From 51a4e594441ee8a446c17b20e958b91793c70166 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 2 May 2024 18:41:51 +0000 Subject: [PATCH 45/73] update parsing to use compressed-tensors; add dynamic per token parsing case; make ignore list handling more robust --- requirements-common.txt | 1 + .../compressed_tensors/compressed_tensors.py | 95 +++++++++++++------ .../compressed_tensors/schemes/__init__.py | 2 + .../compressed_tensors_w8a8_dynamictoken.py | 24 +++++ .../compressed_tensors/schemes/data.py | 12 +++ 5 files changed, 105 insertions(+), 29 deletions(-) create mode 100644 
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py diff --git a/requirements-common.txt b/requirements-common.txt index 3ea22276f63f4..259d5cb86304c 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,3 +19,4 @@ lm-format-enforcer == 0.10.1 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +compressed-tensors == 0.3.2 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 19e464bd64325..15b0590d425e6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,12 +1,18 @@ from typing import Any, Dict, List, Optional import torch +from compressed_tensors.quantization.lifecycle.apply import ( + find_first_name_or_class_match) +from compressed_tensors.quantization.quant_args import QuantizationStrategy from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.data import ( + NumBits, QuantizationFields) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor) + CompressedTensorsScheme, + CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) class CompressedTensorsConfig(QuantizationConfig): @@ -15,6 +21,25 @@ def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): self.ignore = ignore self.layer_quant_details = layer_quant_details + self.num_bits = QuantizationFields.num_bits.value + self.strategy = QuantizationFields.strategy.value + self.symmetric = QuantizationFields.symmetric.value + + llama_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj" + } + + # Update the ignore list: layer with q_proj are replaced to be qkv_proj + # drop duplicates? 
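+        # e.g. an ignore entry such as "model.layers.0.self_attn.q_proj" is
+        # meant to become "model.layers.0.self_attn.qkv_proj" so it lines up
+        # with vLLM's fused modules; note that str.replace returns a new
+        # string, so the result has to be assigned back for the list entry
+        # to actually change.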
+ for layer in self.ignore: + for k in llama_mapping: + if k in layer: + layer.replace(k, llama_mapping.get(k, k)) + def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -58,40 +83,52 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def get_config_filenames(cls) -> List[str]: return [] - def _get_schema(self, weight_quant: Dict, input_quant: Dict): - # TODO: Refactor as additional cases are supported - - weight_bit = weight_quant.get("num_bits") - input_bit = input_quant.get("num_bits") + def _is_static_tensor_w8a8(self, weight_quant: Dict, input_quant: Dict): + is_8_bits = weight_quant.get(self.num_bits) == input_quant.get( + self.num_bits) == NumBits.EIGHT + is_tensor = weight_quant.get(self.strategy) == input_quant.get( + self.strategy) == QuantizationStrategy.TENSOR + is_symmetric = weight_quant.get(self.symmetric) and input_quant.get( + self.symmetric) + + if is_8_bits and is_tensor and is_symmetric: + return True + return False + + def _is_dynamic_token_w8a8(self, weight_quant: Dict, input_quant: Dict): + is_8_bits = weight_quant.get(self.num_bits) == input_quant.get( + self.num_bits) == NumBits.EIGHT + is_token = weight_quant.get(self.strategy) == input_quant.get( + self.strategy + ) == "token" # TODO: QuantizationStrategy should have token + is_symmetric = weight_quant.get(self.symmetric) and input_quant.get( + self.symmetric) + + if is_8_bits and is_token and is_symmetric: + return True + return False - weight_strategy = weight_quant.get("strategy") - input_strategy = input_quant.get("strategy") - - weight_symmetric = weight_quant.get("symmetric") - input_symmetric = input_quant.get("symmetric") + def _get_schema(self, weight_quant: Dict, input_quant: Dict): + if self._is_static_tensor_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8StaticTensor( + fake_quant=self.fake_quant) - is_8_bits = weight_bit == input_bit == 8 - is_tensor = weight_strategy == input_strategy == "tensor" - is_symmetric = weight_symmetric and input_symmetric + elif self._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8DynamicToken( + fake_quant=self.fake_quant) - if is_8_bits and is_tensor and is_symmetric and \ - torch.cuda.is_available(): - # CompressedTensorsW8A8StaticTensor only supports CUDA path for - # now. - return CompressedTensorsW8A8StaticTensor() - raise NotImplementedError( - "Scheme not supported. Only CUDA, 8-bit static symmtetric " - "per tensor quantization is currently supported") + raise NotImplementedError("Scheme not supported.") def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": - # TODO: update with matching function from `compressed_tensors` - layer_type_name = None - layer_name_class = type(layer).__name__.lower() - for target in self.layer_quant_details: - if target.lower() in layer_name_class: - layer_type_name = target - break + # TODO: update/map layer_name for llama models before + # using find_first_name_or_class_match? 
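+        # find_first_name_or_class_match scans `targets` for the first entry
+        # matching this module's name or class; with check_contains=True a
+        # substring match (e.g. "Linear" in "ColumnParallelLinear") is enough.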
+ layer_type_name = find_first_name_or_class_match( + name=layer_name, + module=layer, + targets=self.layer_quant_details.keys(), + check_contains=True) + if layer_type_name is None: raise ValueError(f"Could not matching target for layer {layer}") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 831905b63e2c9..9a910f061f580 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -1,5 +1,7 @@ from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 from .compressed_tensors_unquantized import ( # noqa: F401 CompressedTensorsUnquantized) +from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501 + CompressedTensorsW8A8DynamicToken) from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 CompressedTensorsW8A8StaticTensor) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py new file mode 100644 index 0000000000000..18f5ec035c249 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -0,0 +1,24 @@ +from typing import Callable, List, Tuple, Union + +import torch +from torch.nn import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsW8A8DynamicToken"] + + +class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme): + + def __init__(self, fake_quant: bool): + self.fake_quant = fake_quant + + + def create_weights(self): + pass + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor ): + pass \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py new file mode 100644 index 0000000000000..378b86f88a620 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py @@ -0,0 +1,12 @@ +from enum import Enum + +__all__ = ["QuantizationFields", "NumBits"] + +class QuantizationFields(Enum): + num_bits = "num_bits" + strategy = "strategy" + symmetric = "symmetric" + +class NumBits(Enum): + EIGHT = "8" + FOUR = "4" From 6777319bf97f12374d33eebf20d8c4152d82b79d Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 2 May 2024 20:12:07 +0000 Subject: [PATCH 46/73] add dynamic quantization arg, fill out create_weights/apply --- .../compressed_tensors/compressed_tensors.py | 9 +++- .../compressed_tensors_w8a8_dynamictoken.py | 53 ++++++++++++++++--- .../compressed_tensors/schemes/data.py | 3 ++ 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 15b0590d425e6..d2e98f5120c04 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -24,6 +24,7 @@ def __init__(self, layer_quant_details: Dict[str, Any], 
ignore: List[str]): self.num_bits = QuantizationFields.num_bits.value self.strategy = QuantizationFields.strategy.value self.symmetric = QuantizationFields.symmetric.value + self.dynamic = QuantizationFields.dynamic.value llama_mapping = { "q_proj": "qkv_proj", @@ -90,8 +91,10 @@ def _is_static_tensor_w8a8(self, weight_quant: Dict, input_quant: Dict): self.strategy) == QuantizationStrategy.TENSOR is_symmetric = weight_quant.get(self.symmetric) and input_quant.get( self.symmetric) + is_static = not weight_quant.get(self.dynamic) and not input_quant.get( + self.dynamic) - if is_8_bits and is_tensor and is_symmetric: + if is_8_bits and is_tensor and is_symmetric and is_static: return True return False @@ -103,8 +106,10 @@ def _is_dynamic_token_w8a8(self, weight_quant: Dict, input_quant: Dict): ) == "token" # TODO: QuantizationStrategy should have token is_symmetric = weight_quant.get(self.symmetric) and input_quant.get( self.symmetric) + is_dynamic = weight_quant.get(self.dynamic) and input_quant.get( + self.dynamic) - if is_8_bits and is_token and is_symmetric: + if is_8_bits and is_token and is_symmetric and is_dynamic: return True return False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py index 18f5ec035c249..dcff7009d6ee4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -16,9 +16,50 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme): def __init__(self, fake_quant: bool): self.fake_quant = fake_quant - - def create_weights(self): - pass - - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor ): - pass \ No newline at end of file + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + weight_zero_point = Parameter(torch.empty(1, + device="cuda", + dtype=torch.int8), + requires_grad=False) + + weight_scale = Parameter(torch.empty(sum(output_partition_sizes), + device="cuda", + dtype=torch.float32), + requires_grad=False) + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) + + layer.register_parameter("weight_zero_point", weight_zero_point) + set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) + + # Determine per token input scales on the fly + def _quantize_activation(self, x: torch.Tensor): + x_q = torch.empty_like(x, dtype=torch.int8) + input_scales = torch.empty() + ops.quant(x_q, x, input_scales) + return x_q, input_scales + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + weight_scale = layer.weight_scale + + x_q, input_scales = self._quantize_activation(x) + if self.fake_quant: + pass + diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py index 378b86f88a620..ee7bc1f128c5b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py @@ -2,10 +2,13 @@ __all__ = ["QuantizationFields", "NumBits"] + class QuantizationFields(Enum): num_bits = "num_bits" strategy = "strategy" symmetric = "symmetric" + dynamic = "dynamic" + class NumBits(Enum): EIGHT = "8" From 54c797a89d367f132f9303f0fcbe3d9c9a42a702 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 2 May 2024 20:43:39 +0000 Subject: [PATCH 47/73] Add quant_per_token kernels --- csrc/ops.h | 5 ++ csrc/pybind.cpp | 8 ++ .../compressed_tensors/int8_quant_kernels.cu | 85 ++++++++++++++++++- 3 files changed, 97 insertions(+), 1 deletion(-) diff --git a/csrc/ops.h b/csrc/ops.h index b839eaf0d26c8..6351c6f4a6228 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -116,6 +116,11 @@ void moe_align_block_size(torch::Tensor topk_ids, int num_experts, int block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); +void quant_per_token( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scales); + #ifndef USE_ROCM using fptr_t = uint64_t; diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index cdbec4a34d77f..f2a8a96a0cd7e 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -69,6 +69,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, "Compute int8 quantized tensor for given scaling factor"); + ops.def( + "quant_per_token", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + torch::Tensor&>(&quant_per_token), + "Per-Token Quantization"); + // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 4902e4c23434c..9d568423b0045 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -25,6 +25,33 @@ static inline __device__ int8_t float_to_int8_rn(float x) { namespace vllm { +// TODO (varun) : Merge this into reduction utils and use the existing interface +// TODO (varun) : Add unit tests for this +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + val = warpReduceMax(val); // get maxx in each warp + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + __syncthreads(); + val = (threadIdx.x < (blockDim.x / 32.f)) ? 
shared[lane] : -1e20f; + val = warpReduceMax(val); + return val; +} + template __global__ void static_scaled_int8_quant_kernel( const scalar_t* __restrict__ input, int8_t* __restrict__ out, @@ -37,7 +64,43 @@ __global__ void static_scaled_int8_quant_kernel( float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); } } -} // namespace vllm + +template +__global__ void quant_per_token_kernel( + const scalar_t* __restrict__ input, + int8_t* __restrict__ out, + scale_type scale, + const int hidden_size) { + + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + float amax_val = 0.0f; + const float zero = 0.0f; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + float val = (float)input[token_idx * hidden_size + i]; + val = val > zero ? val : -val; + if (val > amax_val) + amax_val = val; + } + + __shared__ float s_amax; + const float block_amax_val = blockReduceMax(amax_val); + if (tid == 0) { + s_amax = block_amax_val; + scale[token_idx] = block_amax_val / 127.0f; + } + __syncthreads(); + + float tmp_scale = 127.0f / s_amax; + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) * tmp_scale); + } +} + +} // namespace vllm void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] @@ -57,3 +120,23 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] hidden_size); }); } + +void quant_per_token( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scales) { + assert(input.is_contiguous()); + assert(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_per_token_kernel", [&] { + vllm::quant_per_token_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + scales.data_ptr(), + hidden_size); + }); +} From 6bcab229ed512202f3a0c8a00af24011143f83a9 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 3 May 2024 22:13:53 +0000 Subject: [PATCH 48/73] make changes to config parsing based on sparseml updates; test dynamic per token model --- .../compressed_tensors/compressed_tensors.py | 29 ++++--- .../compressed_tensors/{schemes => }/data.py | 7 +- .../compressed_tensors_w8a8_dynamictoken.py | 79 ++++++++++++++++--- 3 files changed, 86 insertions(+), 29 deletions(-) rename vllm/model_executor/layers/quantization/compressed_tensors/{schemes => }/data.py (69%) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d2e98f5120c04..a206fe33b6fbf 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,15 +1,15 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional import torch -from compressed_tensors.quantization.lifecycle.apply import ( - find_first_name_or_class_match) +#from compressed_tensors.quantization.lifecycle.apply import ( +# find_first_name_or_class_match) # TODO: needed from compressed_tensors.quantization.quant_args import QuantizationStrategy from vllm.model_executor.layers.linear import 
LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.data import ( - NumBits, QuantizationFields) + QuantizationFields) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) @@ -66,6 +66,8 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + config = config["compression_config"]["quantization_config"] + layer_quant_details: Dict[str, Any] = dict() ignore: List[str] = config.get("ignore", None) @@ -86,9 +88,9 @@ def get_config_filenames(cls) -> List[str]: def _is_static_tensor_w8a8(self, weight_quant: Dict, input_quant: Dict): is_8_bits = weight_quant.get(self.num_bits) == input_quant.get( - self.num_bits) == NumBits.EIGHT + self.num_bits) == 8 is_tensor = weight_quant.get(self.strategy) == input_quant.get( - self.strategy) == QuantizationStrategy.TENSOR + self.strategy) == QuantizationStrategy.TENSOR.value is_symmetric = weight_quant.get(self.symmetric) and input_quant.get( self.symmetric) is_static = not weight_quant.get(self.dynamic) and not input_quant.get( @@ -100,16 +102,17 @@ def _is_static_tensor_w8a8(self, weight_quant: Dict, input_quant: Dict): def _is_dynamic_token_w8a8(self, weight_quant: Dict, input_quant: Dict): is_8_bits = weight_quant.get(self.num_bits) == input_quant.get( - self.num_bits) == NumBits.EIGHT - is_token = weight_quant.get(self.strategy) == input_quant.get( - self.strategy - ) == "token" # TODO: QuantizationStrategy should have token + self.num_bits) == 8 + is_token_tensor = (weight_quant.get(self.strategy) + == QuantizationStrategy.TENSOR.value) and ( + input_quant.get(self.strategy) == "token" + ) # TODO: QuantizationStrategy should have token is_symmetric = weight_quant.get(self.symmetric) and input_quant.get( self.symmetric) - is_dynamic = weight_quant.get(self.dynamic) and input_quant.get( + is_dynamic = not weight_quant.get(self.dynamic) and input_quant.get( self.dynamic) - if is_8_bits and is_token and is_symmetric and is_dynamic: + if is_8_bits and is_token_tensor and is_symmetric and is_dynamic: return True return False @@ -128,7 +131,7 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": # TODO: update/map layer_name for llama models before # using find_first_name_or_class_match? 
- layer_type_name = find_first_name_or_class_match( + layer_type_name = self.find_first_name_or_class_match( name=layer_name, module=layer, targets=self.layer_quant_details.keys(), diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py b/vllm/model_executor/layers/quantization/compressed_tensors/data.py similarity index 69% rename from vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py rename to vllm/model_executor/layers/quantization/compressed_tensors/data.py index ee7bc1f128c5b..97c7428e0f644 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/data.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/data.py @@ -7,9 +7,4 @@ class QuantizationFields(Enum): num_bits = "num_bits" strategy = "strategy" symmetric = "symmetric" - dynamic = "dynamic" - - -class NumBits(Enum): - EIGHT = "8" - FOUR = "4" + dynamic = "dynamic" \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py index dcff7009d6ee4..39a54170a35fd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -16,34 +16,64 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme): def __init__(self, fake_quant: bool): self.fake_quant = fake_quant + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self._shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. 
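+        # e.g. logical_widths [4096, 512, 512] with shard_id "k" gives
+        # offset 4096 and size 512: the loaded per-tensor k scale is repeated
+        # 512 times and written into param[4096:4608].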
+ loaded_weight = loaded_weight.repeat(size) + return param[offset:offset + size], loaded_weight + def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + is_tensor_partitioned = len(output_partition_sizes) != 1 + dim = sum(output_partition_sizes) if is_tensor_partitioned else 1 weight_zero_point = Parameter(torch.empty(1, device="cuda", dtype=torch.int8), requires_grad=False) - weight_scale = Parameter(torch.empty(sum(output_partition_sizes), + weight_scale = Parameter(torch.empty(dim, device="cuda", dtype=torch.float32), requires_grad=False) + if not self.fake_quant: + params_dtype = torch.int8 weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, device="cuda", dtype=params_dtype), requires_grad=False) - + layer.register_parameter("weight", weight) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, {"weight_loader": weight_loader}) layer.register_parameter("weight_scale", weight_scale) set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) + set_weight_attrs( + weight_scale, { + "shard_splitter": self.scales_shard_splitter, + "logical_widths": output_partition_sizes + }) layer.register_parameter("weight_zero_point", weight_zero_point) set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) @@ -51,15 +81,44 @@ def create_weights(self, layer: torch.nn.Module, # Determine per token input scales on the fly def _quantize_activation(self, x: torch.Tensor): x_q = torch.empty_like(x, dtype=torch.int8) - input_scales = torch.empty() - ops.quant(x_q, x, input_scales) + input_scales = torch.empty(x.numel() // x.shape[-1], + dtype=x.dtype, + device=x.device) + ops.quant_per_token(x_q, x, input_scales) return x_q, input_scales + def _quantize(self, + x: torch.Tensor, + scales: torch.Tensor, + logical_widths: List[int], + split_dim: int = 0) -> torch.Tensor: + + x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") + x_q_split = x_q.split(logical_widths, dim=split_dim) + x_split = x.split(logical_widths, dim=split_dim) + + for q, dq, scale in zip(x_q_split, x_split, scales): + ops.quant_per_tensor(q, dq, scale.item()) + + return x_q + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): weight = layer.weight weight_scale = layer.weight_scale + logical_widths = weight.logical_widths + + from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 + cutlass_gemm_dq) x_q, input_scales = self._quantize_activation(x) if self.fake_quant: - pass - + w_scales = [ + weight_scale[sum(logical_widths[:i])].item() + for i in range(len(logical_widths)) + ] + w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) + w_q = self._quantize(weight, w_scales, logical_widths) + return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, + input_scales) + return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, + input_scales) From ece93e18a67a1210f6bd33fefdc962e030850532 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 6 May 2024 18:23:44 +0000 Subject: [PATCH 49/73] fix shape for cutlass issues --- .../compressed_tensors/compressed_tensors.py | 2 +- .../quantization/compressed_tensors/data.py | 2 +- .../compressed_tensors_w8a8_dynamictoken.py | 33 ++++++++++--------- 
.../compressed_tensors_w8a8_statictensor.py | 24 ++++++++++++++ 4 files changed, 43 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index a206fe33b6fbf..7cda548619758 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -35,7 +35,6 @@ def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): } # Update the ignore list: layer with q_proj are replaced to be qkv_proj - # drop duplicates? for layer in self.ignore: for k in llama_mapping: if k in layer: @@ -66,6 +65,7 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + config = config["compression_config"]["quantization_config"] layer_quant_details: Dict[str, Any] = dict() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/data.py b/vllm/model_executor/layers/quantization/compressed_tensors/data.py index 97c7428e0f644..3c5af69ef2186 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/data.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/data.py @@ -1,6 +1,6 @@ from enum import Enum -__all__ = ["QuantizationFields", "NumBits"] +__all__ = ["QuantizationFields"] class QuantizationFields(Enum): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py index 39a54170a35fd..b87c2f0631201 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -66,6 +66,7 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight", weight) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, {"weight_loader": weight_loader}) + set_weight_attrs(weight, {"logical_widths": output_partition_sizes}) layer.register_parameter("weight_scale", weight_scale) set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) @@ -78,20 +79,11 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight_zero_point", weight_zero_point) set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) - # Determine per token input scales on the fly - def _quantize_activation(self, x: torch.Tensor): - x_q = torch.empty_like(x, dtype=torch.int8) - input_scales = torch.empty(x.numel() // x.shape[-1], - dtype=x.dtype, - device=x.device) - ops.quant_per_token(x_q, x, input_scales) - return x_q, input_scales - - def _quantize(self, - x: torch.Tensor, - scales: torch.Tensor, - logical_widths: List[int], - split_dim: int = 0) -> torch.Tensor: + def _quantize_weights(self, + x: torch.Tensor, + scales: torch.Tensor, + logical_widths: List[int], + split_dim: int = 0) -> torch.Tensor: x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") x_q_split = x_q.split(logical_widths, dim=split_dim) @@ -102,22 +94,31 @@ def _quantize(self, return x_q + # Determine per token input scales on the fly + def _quantize_activation(self, x: torch.Tensor): + x_q = torch.empty_like(x, dtype=torch.int8) + input_scales = torch.empty((x.numel() // x.shape[-1], 
1), + dtype=x.dtype, + device=x.device) + ops.quant_per_token(x_q, x, input_scales) + return x_q, input_scales + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): weight = layer.weight weight_scale = layer.weight_scale - logical_widths = weight.logical_widths from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 cutlass_gemm_dq) x_q, input_scales = self._quantize_activation(x) if self.fake_quant: + logical_widths = weight.logical_widths w_scales = [ weight_scale[sum(logical_widths[:i])].item() for i in range(len(logical_widths)) ] w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) - w_q = self._quantize(weight, w_scales, logical_widths) + w_q = self._quantize_weights(weight, w_scales, logical_widths) return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, input_scales) return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index d16e570d12202..f639564fa7967 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -108,12 +108,36 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): +<<<<<<< HEAD +======= + # Lazy import so we don't fail on cutlass imports on non-CUDA + # machines. + from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 + cutlass_gemm_dq) + +>>>>>>> d8155e1e (fix shape for cutlass issues) weight = layer.weight weight_scale = layer.weight_scale act_scale = layer.input_scale # Input quantize +<<<<<<< HEAD x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale, x.dtype) +======= + x_q = self._quantize_activation(x, act_scale[0].item()) + + if self.fake_quant: + logical_widths = weight.logical_widths + w_scales = [ + weight_scale[sum(logical_widths[:i])].item() + for i in range(len(logical_widths)) + ] + w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) + w_q = self._quantize_weights(weight, w_scales, logical_widths) + # GEMM and dq + return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, act_scale) + return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, act_scale) +>>>>>>> d8155e1e (fix shape for cutlass issues) From 1d87a999a1142df6957d09847b6f4bc4574412f8 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 6 May 2024 19:44:50 +0000 Subject: [PATCH 50/73] remove dicts; use quantization args directly --- .../compressed_tensors/compressed_tensors.py | 55 ++++++++----------- .../quantization/compressed_tensors/data.py | 10 ---- 2 files changed, 24 insertions(+), 41 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/data.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 7cda548619758..dba43de3cd636 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -3,13 +3,13 @@ import torch #from compressed_tensors.quantization.lifecycle.apply import ( # find_first_name_or_class_match) # TODO: needed -from compressed_tensors.quantization.quant_args import QuantizationStrategy +from compressed_tensors.quantization.quant_args import (QuantizationArgs, + QuantizationStrategy) +from pydantic import BaseModel from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) -from vllm.model_executor.layers.quantization.compressed_tensors.data import ( - QuantizationFields) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) @@ -21,11 +21,6 @@ def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): self.ignore = ignore self.layer_quant_details = layer_quant_details - self.num_bits = QuantizationFields.num_bits.value - self.strategy = QuantizationFields.strategy.value - self.symmetric = QuantizationFields.symmetric.value - self.dynamic = QuantizationFields.dynamic.value - llama_mapping = { "q_proj": "qkv_proj", "k_proj": "qkv_proj", @@ -75,10 +70,12 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": targets = quant_config.get("targets") for target in targets: layer_quant_details[target] = {} - layer_quant_details[target]["weight"] = quant_config.get( - "weights") - layer_quant_details[target]["input"] = quant_config.get( - "input_activations") + layer_quant_details[target][ + "weight"] = QuantizationArgs.parse_obj( + quant_config.get("weights")) + layer_quant_details[target][ + "input"] = QuantizationArgs.parse_obj( + quant_config.get("input_activations")) return cls(layer_quant_details=layer_quant_details, ignore=ignore) @@ -86,37 +83,33 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def get_config_filenames(cls) -> List[str]: return [] - def _is_static_tensor_w8a8(self, weight_quant: Dict, input_quant: Dict): - is_8_bits = weight_quant.get(self.num_bits) == input_quant.get( - self.num_bits) == 8 - is_tensor = weight_quant.get(self.strategy) == input_quant.get( - self.strategy) == QuantizationStrategy.TENSOR.value - is_symmetric = weight_quant.get(self.symmetric) and input_quant.get( - self.symmetric) - is_static = not weight_quant.get(self.dynamic) and not input_quant.get( - self.dynamic) + def _is_static_tensor_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel): + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + is_tensor = (weight_quant.strategy == input_quant.strategy == + QuantizationStrategy.TENSOR.value) + is_symmetric = weight_quant.symmetric and input_quant.symmetric + is_static = not weight_quant.dynamic and not input_quant.dynamic if is_8_bits and is_tensor and is_symmetric and is_static: return True return False - def _is_dynamic_token_w8a8(self, weight_quant: Dict, input_quant: Dict): - is_8_bits = weight_quant.get(self.num_bits) == input_quant.get( - self.num_bits) == 8 - is_token_tensor = (weight_quant.get(self.strategy) + def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel): + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + is_token_tensor = (weight_quant.strategy == QuantizationStrategy.TENSOR.value) and ( - input_quant.get(self.strategy) == "token" + input_quant.strategy == 
"token" ) # TODO: QuantizationStrategy should have token - is_symmetric = weight_quant.get(self.symmetric) and input_quant.get( - self.symmetric) - is_dynamic = not weight_quant.get(self.dynamic) and input_quant.get( - self.dynamic) + is_symmetric = weight_quant.symmetric and input_quant.symmetric + is_dynamic = not weight_quant.dynamic and input_quant.dynamic if is_8_bits and is_token_tensor and is_symmetric and is_dynamic: return True return False - def _get_schema(self, weight_quant: Dict, input_quant: Dict): + def _get_schema(self, weight_quant: BaseModel, input_quant: BaseModel): if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8StaticTensor( fake_quant=self.fake_quant) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/data.py b/vllm/model_executor/layers/quantization/compressed_tensors/data.py deleted file mode 100644 index 3c5af69ef2186..0000000000000 --- a/vllm/model_executor/layers/quantization/compressed_tensors/data.py +++ /dev/null @@ -1,10 +0,0 @@ -from enum import Enum - -__all__ = ["QuantizationFields"] - - -class QuantizationFields(Enum): - num_bits = "num_bits" - strategy = "strategy" - symmetric = "symmetric" - dynamic = "dynamic" \ No newline at end of file From 3dd1b5f6d2948b67875afe5dcb493699a0190345 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 7 May 2024 19:21:03 +0000 Subject: [PATCH 51/73] update compressed-tensors; add docstring --- requirements-common.txt | 2 +- .../compressed_tensors/compressed_tensors.py | 27 ++++++++++--------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 259d5cb86304c..8675f2052bfea 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,4 +19,4 @@ lm-format-enforcer == 0.10.1 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 -compressed-tensors == 0.3.2 +compressed-tensors == 0.3.3 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index dba43de3cd636..4f4689f5eeb65 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, List, Optional import torch -#from compressed_tensors.quantization.lifecycle.apply import ( -# find_first_name_or_class_match) # TODO: needed +from compressed_tensors.quantization.lifecycle.apply import ( + find_first_name_or_class_match) from compressed_tensors.quantization.quant_args import (QuantizationArgs, QuantizationStrategy) from pydantic import BaseModel @@ -29,7 +29,8 @@ def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): "up_proj": "gate_up_proj" } - # Update the ignore list: layer with q_proj are replaced to be qkv_proj + # Update the ignore list: e.g layers with q_proj are replaced + # to be qkv_proj to be compatible with vllm for layer in self.ignore: for k in llama_mapping: if k in layer: @@ -84,7 +85,7 @@ def get_config_filenames(cls) -> List[str]: return [] def _is_static_tensor_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel): + input_quant: BaseModel) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 is_tensor = (weight_quant.strategy == 
input_quant.strategy == QuantizationStrategy.TENSOR.value) @@ -96,12 +97,12 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel, return False def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel): + input_quant: BaseModel) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 is_token_tensor = (weight_quant.strategy == QuantizationStrategy.TENSOR.value) and ( - input_quant.strategy == "token" - ) # TODO: QuantizationStrategy should have token + input_quant.strategy + == QuantizationStrategy.TOKEN.value) is_symmetric = weight_quant.symmetric and input_quant.symmetric is_dynamic = not weight_quant.dynamic and input_quant.dynamic @@ -109,7 +110,8 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, return True return False - def _get_schema(self, weight_quant: BaseModel, input_quant: BaseModel): + def _get_schema(self, weight_quant: BaseModel, + input_quant: BaseModel) -> "CompressedTensorsScheme": if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8StaticTensor( fake_quant=self.fake_quant) @@ -124,7 +126,7 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": # TODO: update/map layer_name for llama models before # using find_first_name_or_class_match? - layer_type_name = self.find_first_name_or_class_match( + layer_type_name = find_first_name_or_class_match( name=layer_name, module=layer, targets=self.layer_quant_details.keys(), @@ -155,7 +157,8 @@ def create_weights(self, layer: torch.nn.Module, **extra_weight_attrs): """ Use the CompressedTensorsScheme associated with each layer to create - the necessary parameters for the layer. + the necessary parameters for the layer. See LinearMethodBase for param + details """ weight_loader = extra_weight_attrs.get("weight_loader") @@ -177,7 +180,7 @@ def apply(self, """ Use the output of create_weights and the CompressedTensorsScheme associated with the layer to apply the forward pass with the - layer input. + layer input. 
See LinearMethodBase for param details """ if bias is not None: From fed7cddc062fd7d831217d05cdc24b82435af06e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 13 May 2024 10:11:00 -0400 Subject: [PATCH 52/73] Dyn per token varun cleanup (#227) Description: - Remove inline asm for float to int8 conversion - Refactor reduction utils to add blockReduceMax --------- Co-authored-by: Varun Sundar Rabindranath --- csrc/ops.h | 3 +- csrc/pybind.cpp | 8 +-- .../compressed_tensors/int8_quant_kernels.cu | 36 ++-------- csrc/reduction_utils.cuh | 66 +++++++++++++------ .../compressed_tensors_w8a8_dynamictoken.py | 2 +- 5 files changed, 55 insertions(+), 60 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 6351c6f4a6228..883cf138def83 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -116,7 +116,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int num_experts, int block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); -void quant_per_token( + +void dynamic_scaled_int8_quant( torch::Tensor& out, torch::Tensor& input, torch::Tensor& scales); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index f2a8a96a0cd7e..05aac079ccc46 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -70,13 +70,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, "Compute int8 quantized tensor for given scaling factor"); ops.def( - "quant_per_token", - py::overload_cast< - torch::Tensor&, - torch::Tensor&, - torch::Tensor&>(&quant_per_token), - "Per-Token Quantization"); - + "dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant, "Compute int8 quantized tensor and scaling factor"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 9d568423b0045..62cc23cd242f2 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -3,6 +3,7 @@ #include #include "../../dispatch_utils.h" +#include "../../reduction_utils.cuh" static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM @@ -25,33 +26,6 @@ static inline __device__ int8_t float_to_int8_rn(float x) { namespace vllm { -// TODO (varun) : Merge this into reduction utils and use the existing interface -// TODO (varun) : Add unit tests for this -template -__inline__ __device__ T warpReduceMax(T val) -{ -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); - return val; -} - -/* Calculate the maximum of all elements in a block */ -template -__inline__ __device__ T blockReduceMax(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - val = warpReduceMax(val); // get maxx in each warp - if (lane == 0) // record in-warp maxx by warp Idx - shared[wid] = val; - __syncthreads(); - val = (threadIdx.x < (blockDim.x / 32.f)) ? 
shared[lane] : -1e20f; - val = warpReduceMax(val); - return val; -} - template __global__ void static_scaled_int8_quant_kernel( const scalar_t* __restrict__ input, int8_t* __restrict__ out, @@ -66,7 +40,7 @@ __global__ void static_scaled_int8_quant_kernel( } template -__global__ void quant_per_token_kernel( +__global__ void dynamic_scaled_int8_quant_kernel( const scalar_t* __restrict__ input, int8_t* __restrict__ out, scale_type scale, @@ -121,7 +95,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] }); } -void quant_per_token( +void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& scales) { @@ -132,8 +106,8 @@ void quant_per_token( dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_per_token_kernel", [&] { - vllm::quant_per_token_kernel<<>>( + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { + vllm::dynamic_scaled_int8_quant_kernel<<>>( input.data_ptr(), out.data_ptr(), scales.data_ptr(), diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index 9af4aae516151..e28f4c99ad926 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -21,31 +21,48 @@ #include "cuda_compat.h" namespace vllm { -template -__inline__ __device__ T warpReduceSum(T val) { - static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, - "numLanes is not a positive power of 2!"); - static_assert(numLanes <= WARP_SIZE); -#pragma unroll - for (int mask = numLanes >> 1; mask > 0; mask >>= 1) - val += VLLM_SHFL_XOR_SYNC(val, mask); - return val; + +namespace detail { + +template +__inline__ __device__ T _max(T a, T b) { + return max(a, b); } +template +__inline__ __device__ T _sum(T a, T b) { + return a + b; +} + +} // detail + +template +using ReduceFnType = T(*)(T, T); + // Helper function to return the next largest power of 2 static constexpr int _nextPow2(unsigned int num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } -/* Calculate the sum of all elements in a block */ -template -__inline__ __device__ T blockReduceSum(T val) { +template +__inline__ __device__ T warpReduce(T val, ReduceFnType fn) { + static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, + "numLanes is not a positive power of 2!"); + static_assert(numLanes <= WARP_SIZE); + #pragma unroll + for (int mask = numLanes >> 1; mask > 0; mask >>= 1) + val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask)); + + return val; +} + +template +__inline__ __device__ T blockReduce(T val, ReduceFnType fn) { static_assert(maxBlockSize <= 1024); if constexpr (maxBlockSize > WARP_SIZE) { - val = warpReduceSum(val); - // Calculates max number of lanes that need to participate in the last - // warpReduce + val = warpReduce(val, fn); + // Calculates max number of lanes that need to participate in the last warpReduce constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; static __shared__ T shared[maxActiveLanes]; int lane = threadIdx.x % WARP_SIZE; @@ -54,14 +71,23 @@ __inline__ __device__ T blockReduceSum(T val) { __syncthreads(); - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] - : (T)(0.0f); - val = warpReduceSum(val); + val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? 
shared[lane] : (T)(0.0f); + val = warpReduce(val, fn); } else { // A single warpReduce is equal to blockReduce - val = warpReduceSum(val); + val = warpReduce(val, fn); } return val; } -} // namespace vllm +template +__inline__ __device__ T blockReduceMax(T val) { + return blockReduce(val, detail::_max); +} + +template +__inline__ __device__ T blockReduceSum(T val) { + return blockReduce(val, detail::_sum); +} + +} // namespace vllm diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py index b87c2f0631201..ea47ecb464b39 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -100,7 +100,7 @@ def _quantize_activation(self, x: torch.Tensor): input_scales = torch.empty((x.numel() // x.shape[-1], 1), dtype=x.dtype, device=x.device) - ops.quant_per_token(x_q, x, input_scales) + ops.dynamic_scaled_int8_quant(x_q, x, input_scales) return x_q, input_scales def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): From 66719a9d08014be3b8e2915f6df1efd19a0e2d9a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 16 May 2024 17:52:05 +0000 Subject: [PATCH 53/73] add test_int8_quant --- tests/kernels/test_int8_quant.py | 47 ++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index b9aa00ce13f56..66f679f28b155 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -10,22 +10,53 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + x_token_max, _ = x.max(dim=1) + x_token_max = x_token_max.to(dtype=torch.float32) + scales = (x_token_max / float(127.0))[:, None].to(device="cuda", + dtype=torch.float32) + torch_out = (x / scales).round().clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + + ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") + scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") + ops.dynamic_scaled_int8_quant(ops_out, x, scales_out) + + assert torch.allclose(scales_out, scales) + assert torch.allclose(torch_out, ops_out, + atol=1) # big atol to account for rounding errors + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("scale", SCALE) @torch.inference_mode() -def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, - seed: int, scale: float) -> None: +def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int, + scale: float) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) + int8_traits = 
torch.iinfo(torch.int8) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - out1 = (x / scale).round().clamp( - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max).to(torch.int8) - out2 = torch.empty_like(x, dtype=torch.int8) - ops.static_scaled_int8_quant(out2, x, scale) - assert torch.allclose(out1, out2, + torch_out = (x / scale).round().clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + ops_out = torch.empty_like(x, dtype=torch.int8) + ops.static_scaled_int8_quant(ops_out, x, scale) + assert torch.allclose(torch_out, ops_out, atol=1) # big atol to account for rounding errors From 2ec6a2c99d9c576d2e10641abd7803db63970740 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 24 May 2024 17:39:05 +0000 Subject: [PATCH 54/73] remove fake quant --- .../compressed_tensors/compressed_tensors.py | 12 ++++------ .../compressed_tensors_w8a8_dynamictoken.py | 17 +------------ .../compressed_tensors_w8a8_statictensor.py | 24 ------------------- 3 files changed, 6 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 4f4689f5eeb65..4bf18c2b0e68f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -11,8 +11,8 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, - CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) + CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken, + CompressedTensorsW8A8StaticTensor) class CompressedTensorsConfig(QuantizationConfig): @@ -113,12 +113,10 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, def _get_schema(self, weight_quant: BaseModel, input_quant: BaseModel) -> "CompressedTensorsScheme": if self._is_static_tensor_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8StaticTensor( - fake_quant=self.fake_quant) + return CompressedTensorsW8A8StaticTensor() elif self._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8DynamicToken( - fake_quant=self.fake_quant) + return CompressedTensorsW8A8DynamicToken() raise NotImplementedError("Scheme not supported.") @@ -127,7 +125,7 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": # TODO: update/map layer_name for llama models before # using find_first_name_or_class_match? 
layer_type_name = find_first_name_or_class_match( - name=layer_name, + name="", module=layer, targets=self.layer_quant_details.keys(), check_contains=True) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py index ea47ecb464b39..95956e577e032 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -13,9 +13,6 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme): - def __init__(self, fake_quant: bool): - self.fake_quant = fake_quant - def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: if isinstance(shard_id, int): return shard_id @@ -55,12 +52,10 @@ def create_weights(self, layer: torch.nn.Module, dtype=torch.float32), requires_grad=False) - if not self.fake_quant: - params_dtype = torch.int8 weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, device="cuda", - dtype=params_dtype), + dtype=torch.int8), requires_grad=False) layer.register_parameter("weight", weight) @@ -111,15 +106,5 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): cutlass_gemm_dq) x_q, input_scales = self._quantize_activation(x) - if self.fake_quant: - logical_widths = weight.logical_widths - w_scales = [ - weight_scale[sum(logical_widths[:i])].item() - for i in range(len(logical_widths)) - ] - w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) - w_q = self._quantize_weights(weight, w_scales, logical_widths) - return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, - input_scales) return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, input_scales) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index f639564fa7967..d16e570d12202 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -108,36 +108,12 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): -<<<<<<< HEAD -======= - # Lazy import so we don't fail on cutlass imports on non-CUDA - # machines. 
- from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 - cutlass_gemm_dq) - ->>>>>>> d8155e1e (fix shape for cutlass issues) weight = layer.weight weight_scale = layer.weight_scale act_scale = layer.input_scale # Input quantize -<<<<<<< HEAD x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale, x.dtype) -======= - x_q = self._quantize_activation(x, act_scale[0].item()) - - if self.fake_quant: - logical_widths = weight.logical_widths - w_scales = [ - weight_scale[sum(logical_widths[:i])].item() - for i in range(len(logical_widths)) - ] - w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu")) - w_q = self._quantize_weights(weight, w_scales, logical_widths) - # GEMM and dq - return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, act_scale) - return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, act_scale) ->>>>>>> d8155e1e (fix shape for cutlass issues) From 34e2e12c45b11257effb848453e713e9398565fe Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 24 May 2024 18:01:34 +0000 Subject: [PATCH 55/73] format --- tests/kernels/test_int8_quant.py | 3 ++- .../quantization/compressed_tensors/compressed_tensors.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 1c7aae093ea92..66f679f28b155 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -58,4 +58,5 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, int8_traits.max).to(torch.int8) ops_out = torch.empty_like(x, dtype=torch.int8) ops.static_scaled_int8_quant(ops_out, x, scale) - assert torch.allclose(torch_out, ops_out, atol=1) # big atol to account for rounding errors + assert torch.allclose(torch_out, ops_out, + atol=1) # big atol to account for rounding errors diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 7c2dd5a9bbd21..fa25f6fe26b1b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -37,7 +37,6 @@ def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): if k in layer: layer.replace(k, llama_mapping.get(k, k)) - def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -135,7 +134,7 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( layer_type_name, None) - + if layer_quant_details is None: raise ValueError( f"Could not find quantization details for {layer}.") From e79517e15f3a08c15175a9fff5987109495b8349 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 24 May 2024 20:43:22 +0000 Subject: [PATCH 56/73] combine static and dynamic quant computation --- vllm/_custom_ops.py | 25 ++++++---- .../compressed_tensors_w8a8_dynamictoken.py | 48 ++++++------------- .../compressed_tensors_w8a8_statictensor.py | 2 +- 3 files changed, 32 insertions(+), 43 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f0fab4d8aa26d..d8bae987a7197 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -252,22 +252,31 @@ def scaled_fp8_quant( # int8 -def static_scaled_int8_quant(input: 
torch.Tensor, - scale: float) -> torch.Tensor: +def scaled_int8_quant(input: torch.Tensor, + scale: Optional[float] = None +) -> Tuple[torch.Tensor, torch.Tensor | float] : """ - Quantize the input tensor to int8 and return the quantized tensor. + Quantize the input tensor to int8 and return the quantized tensor and scale. Args: input: The input tensor to be quantized to int8. - scale: Scaling factor for the int8 quantization. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. Returns: - torch.Tensor: Output tensor in int8. + Tuple[Torch.Tensor, Torch.Tensor | float] : Output int8 tensor and scales. """ q = torch.empty_like(input, dtype=torch.int8) - vllm_ops.static_scaled_int8_quant(q, input, scale) - return q - + if scale is not None: + # Static-per-tensor quantization. + vllm_ops.static_scaled_int8_quant(q, input, scale) + return q, scale + + # Dynamic-per-token quantization. + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + dtype=torch.float32) + vllm_ops.dynamic_scaled_int8_quant(q, input, input_scales) + return q, input_scales # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py index 95956e577e032..fa629894b58ee 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -3,7 +3,7 @@ import torch from torch.nn import Parameter -from vllm._C import ops +from vllm import _custom_ops as custom_ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.utils import set_weight_attrs @@ -39,16 +39,23 @@ def create_weights(self, layer: torch.nn.Module, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): + # When the scales have a single value, it is required that they be + # on the CPU for performance and CUDA Graphs compatibility. Please + # refer to the comment in + # CompressedTensorsW8A8StaticTensor::create_weights for further + # information. 
is_tensor_partitioned = len(output_partition_sizes) != 1 - dim = sum(output_partition_sizes) if is_tensor_partitioned else 1 + weight_scale_dim = sum( + output_partition_sizes) if is_tensor_partitioned else 1 + weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda" weight_zero_point = Parameter(torch.empty(1, device="cuda", dtype=torch.int8), requires_grad=False) - weight_scale = Parameter(torch.empty(dim, - device="cuda", + weight_scale = Parameter(torch.empty(weight_scale_dim, + device=weight_scale_device, dtype=torch.float32), requires_grad=False) @@ -74,37 +81,10 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight_zero_point", weight_zero_point) set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) - def _quantize_weights(self, - x: torch.Tensor, - scales: torch.Tensor, - logical_widths: List[int], - split_dim: int = 0) -> torch.Tensor: - - x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") - x_q_split = x_q.split(logical_widths, dim=split_dim) - x_split = x.split(logical_widths, dim=split_dim) - - for q, dq, scale in zip(x_q_split, x_split, scales): - ops.quant_per_tensor(q, dq, scale.item()) - - return x_q - - # Determine per token input scales on the fly - def _quantize_activation(self, x: torch.Tensor): - x_q = torch.empty_like(x, dtype=torch.int8) - input_scales = torch.empty((x.numel() // x.shape[-1], 1), - dtype=x.dtype, - device=x.device) - ops.dynamic_scaled_int8_quant(x_q, x, input_scales) - return x_q, input_scales - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): weight = layer.weight weight_scale = layer.weight_scale - from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501 - cutlass_gemm_dq) - - x_q, input_scales = self._quantize_activation(x) - return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, - input_scales) + x_q, input_scales = custom_ops.scaled_int8_quant(x) + return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales, + weight_scale, x.dtype) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index d16e570d12202..6ecffd8a53d01 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -113,7 +113,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): act_scale = layer.input_scale # Input quantize - x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) + x_q, _ = custom_ops.scaled_int8_quant(x, act_scale[0].item()) return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale, x.dtype) From 39e66d1fd111ddd72d5d32b59e67314c884f486d Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 24 May 2024 20:50:37 +0000 Subject: [PATCH 57/73] TORCH_CHECK and nits --- csrc/ops.h | 9 +++------ csrc/pybind.cpp | 5 ++--- .../compressed_tensors/int8_quant_kernels.cu | 4 ++-- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 883cf138def83..66451e13762d1 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -96,6 +96,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, float scale); +void 
dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scales); + void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); @@ -117,12 +120,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); -void dynamic_scaled_int8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scales); - - #ifndef USE_ROCM using fptr_t = uint64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 0123c86d10edd..547823aa1b04e 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -70,9 +70,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, "Compute int8 quantized tensor for given scaling factor"); - ops.def( - "dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant, "Compute int8 quantized tensor and scaling factor"); - + ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant, + "Compute int8 quantized tensor and scaling factor"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 1affc404d63ad..b1cf22d72bc97 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -102,8 +102,8 @@ void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& scales) { - assert(input.is_contiguous()); - assert(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); From 59f8ec18d55f0f3044a552604c63713ccdcc4d6b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 24 May 2024 20:51:51 +0000 Subject: [PATCH 58/73] use Union --- vllm/_custom_ops.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d8bae987a7197..799910e87dab5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Type +from typing import Optional, Tuple, Type, Union import torch @@ -252,9 +252,10 @@ def scaled_fp8_quant( # int8 -def scaled_int8_quant(input: torch.Tensor, - scale: Optional[float] = None -) -> Tuple[torch.Tensor, torch.Tensor | float] : +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[float] = None +) -> Tuple[torch.Tensor, Union[torch.Tensor, float]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale. 
@@ -278,6 +279,7 @@ def scaled_int8_quant(input: torch.Tensor, vllm_ops.dynamic_scaled_int8_quant(q, input, input_scales) return q, input_scales + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, From 7a83601cfe57c3d7d7a7b3f49fbb4a8231f6832e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 24 May 2024 20:52:58 +0000 Subject: [PATCH 59/73] clang-format --- .../compressed_tensors/int8_quant_kernels.cu | 41 +++++++------------ csrc/reduction_utils.cuh | 28 +++++++------ 2 files changed, 30 insertions(+), 39 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index b1cf22d72bc97..cf62ecc715ffa 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -5,7 +5,6 @@ #include "../../dispatch_utils.h" #include "../../reduction_utils.cuh" - static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM static const float i8_min = @@ -40,14 +39,10 @@ __global__ void static_scaled_int8_quant_kernel( } } - template __global__ void dynamic_scaled_int8_quant_kernel( - const scalar_t* __restrict__ input, - int8_t* __restrict__ out, - scale_type scale, - const int hidden_size) { - + const scalar_t* __restrict__ input, int8_t* __restrict__ out, + scale_type scale, const int hidden_size) { const int tid = threadIdx.x; const int token_idx = blockIdx.x; @@ -57,8 +52,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( for (int i = tid; i < hidden_size; i += blockDim.x) { float val = (float)input[token_idx * hidden_size + i]; val = val > zero ? val : -val; - if (val > amax_val) - amax_val = val; + if (val > amax_val) amax_val = val; } __shared__ float s_amax; @@ -71,13 +65,11 @@ __global__ void dynamic_scaled_int8_quant_kernel( float tmp_scale = 127.0f / s_amax; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = - float_to_int8_rn(((float)input[token_idx * hidden_size + i]) * tmp_scale); + out[token_idx * hidden_size + i] = float_to_int8_rn( + ((float)input[token_idx * hidden_size + i]) * tmp_scale); } } - - void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] float scale) { @@ -97,11 +89,9 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] }); } - -void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scales) { +void dynamic_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scales) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); int hidden_size = input.size(-1); @@ -109,12 +99,11 @@ void dynamic_scaled_int8_quant( dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel<<>>( - input.data_ptr(), - out.data_ptr(), - scales.data_ptr(), - hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { + vllm::dynamic_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), + scales.data_ptr(), hidden_size); + }); } - diff --git a/csrc/reduction_utils.cuh 
b/csrc/reduction_utils.cuh index e28f4c99ad926..08063356012b8 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -24,20 +24,20 @@ namespace vllm { namespace detail { -template +template __inline__ __device__ T _max(T a, T b) { return max(a, b); } -template +template __inline__ __device__ T _sum(T a, T b) { return a + b; } -} // detail +} // namespace detail -template -using ReduceFnType = T(*)(T, T); +template +using ReduceFnType = T (*)(T, T); // Helper function to return the next largest power of 2 static constexpr int _nextPow2(unsigned int num) { @@ -45,24 +45,25 @@ static constexpr int _nextPow2(unsigned int num) { return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } -template +template __inline__ __device__ T warpReduce(T val, ReduceFnType fn) { static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, "numLanes is not a positive power of 2!"); static_assert(numLanes <= WARP_SIZE); - #pragma unroll +#pragma unroll for (int mask = numLanes >> 1; mask > 0; mask >>= 1) val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask)); return val; } -template +template __inline__ __device__ T blockReduce(T val, ReduceFnType fn) { static_assert(maxBlockSize <= 1024); if constexpr (maxBlockSize > WARP_SIZE) { val = warpReduce(val, fn); - // Calculates max number of lanes that need to participate in the last warpReduce + // Calculates max number of lanes that need to participate in the last + // warpReduce constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; static __shared__ T shared[maxActiveLanes]; int lane = threadIdx.x % WARP_SIZE; @@ -71,7 +72,8 @@ __inline__ __device__ T blockReduce(T val, ReduceFnType fn) { __syncthreads(); - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f); + val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? 
shared[lane] + : (T)(0.0f); val = warpReduce(val, fn); } else { // A single warpReduce is equal to blockReduce @@ -80,14 +82,14 @@ __inline__ __device__ T blockReduce(T val, ReduceFnType fn) { return val; } -template +template __inline__ __device__ T blockReduceMax(T val) { return blockReduce(val, detail::_max); } -template +template __inline__ __device__ T blockReduceSum(T val) { return blockReduce(val, detail::_sum); } -} // namespace vllm +} // namespace vllm From 9ea47c892d3ee5a67ce9c51d5b4ce0a1c17f4837 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 24 May 2024 21:17:13 +0000 Subject: [PATCH 60/73] fix typo --- csrc/quantization/compressed_tensors/int8_quant_kernels.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index cf62ecc715ffa..ae1e3f216fa76 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -70,6 +70,8 @@ __global__ void dynamic_scaled_int8_quant_kernel( } } +} // namespace vllm + void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] float scale) { From 7abb2c853af11aea6568fe0847165d5fdf3ddeb9 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 24 May 2024 21:42:37 +0000 Subject: [PATCH 61/73] isort --- .../layers/quantization/compressed_tensors/compressed_tensors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index fa25f6fe26b1b..9af60360d3e5a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional import torch - from compressed_tensors.quantization.lifecycle.apply import ( find_first_name_or_class_match) from compressed_tensors.quantization.quant_args import (QuantizationArgs, From eb4e119b327e087db513af6ac05e03bc6e98321a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 28 May 2024 18:41:37 +0000 Subject: [PATCH 62/73] update test case --- tests/quantization/test_compressed_tensors.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index b83286992da3d..897e78098cdef 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -6,7 +6,8 @@ import torch from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor) + CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor, + CompressedTensorsW8A8DynamicToken) def test_compressed_tensors_w8a8_static_setup(vllm_runner): @@ -34,3 +35,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner): assert qkv_proj.weight_scale.shard_splitter is not None assert qkv_proj.weight_scale.logical_widths is not None assert qkv_proj.input_scale.dtype is torch.float32 + + +def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner): + model_path = "nm-testing/tinyllama-one-shot-dynamic-test" + llm = vllm_runner(model_path, + quantization="sparseml", + enforce_eager=True, + 
dtype=torch.float16) + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken) + assert qkv_proj.weight.dtype is torch.int8 From d62930d2aa28099fb7fc23766b7dee28acc3a8a2 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 28 May 2024 18:45:43 +0000 Subject: [PATCH 63/73] fix isort --- tests/quantization/test_compressed_tensors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 897e78098cdef..8b48f418fe49f 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -6,8 +6,8 @@ import torch from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor, - CompressedTensorsW8A8DynamicToken) + CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken, + CompressedTensorsW8A8StaticTensor) def test_compressed_tensors_w8a8_static_setup(vllm_runner): From 80b6facdd3d563ddd50cd10151d8064c36f22fed Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 29 May 2024 19:36:20 +0000 Subject: [PATCH 64/73] store input scales in gpu --- vllm/_custom_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 799910e87dab5..675e622ba4835 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -275,7 +275,8 @@ def scaled_int8_quant( # Dynamic-per-token quantization. input_scales = torch.empty((input.numel() // input.shape[-1], 1), - dtype=torch.float32) + dtype=torch.float32, + device="cuda") vllm_ops.dynamic_scaled_int8_quant(q, input, input_scales) return q, input_scales From 7075318944c7325a978b2297541fdd7c6f77cd14 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 6 Jun 2024 13:28:09 +0000 Subject: [PATCH 65/73] tensor device location fixes --- vllm/_custom_ops.py | 6 +++--- .../schemes/compressed_tensors_w8a8_dynamictoken.py | 4 ---- .../schemes/compressed_tensors_w8a8_statictensor.py | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1e897794de999..c8f6150d1ee2d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -264,8 +264,8 @@ def scaled_fp8_quant( # int8 -def static_scaled_int8_quant(input: torch.Tensor, - scale: torch.Tensor) -> torch.Tensor: +def scaled_int8_quant(input: torch.Tensor, + scale: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize the input tensor to int8 and return the quantized tensor and scale. @@ -275,7 +275,7 @@ def static_scaled_int8_quant(input: torch.Tensor, When not provided, we invoke dynamic-per-token quantization. Returns: - Tuple[Torch.Tensor, Torch.Tensor | float] : Output int8 tensor and scales. + Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales. 
""" q = torch.empty_like(input, dtype=torch.int8) if scale is not None: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py index fa629894b58ee..d8317f2137451 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -47,21 +47,17 @@ def create_weights(self, layer: torch.nn.Module, is_tensor_partitioned = len(output_partition_sizes) != 1 weight_scale_dim = sum( output_partition_sizes) if is_tensor_partitioned else 1 - weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda" weight_zero_point = Parameter(torch.empty(1, - device="cuda", dtype=torch.int8), requires_grad=False) weight_scale = Parameter(torch.empty(weight_scale_dim, - device=weight_scale_device, dtype=torch.float32), requires_grad=False) weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, - device="cuda", dtype=torch.int8), requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 2dfc6e2b07782..7559fc0f95b24 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -97,7 +97,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): act_scale = layer.input_scale # Input quantize - x_q = custom_ops.static_scaled_int8_quant(x, act_scale) + x_q, _ = custom_ops.scaled_int8_quant(x, act_scale) return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale, x.dtype) From 60a6d73b7450cfdfc14bfbe8d50f684de3babe8b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 6 Jun 2024 13:31:00 +0000 Subject: [PATCH 66/73] format.sh --- tests/kernels/test_int8_quant.py | 5 ++--- vllm/_custom_ops.py | 8 +++++--- .../schemes/compressed_tensors_w8a8_dynamictoken.py | 3 +-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 1f7faf1b8f088..0f1df66322eff 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -54,9 +54,8 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - out1 = (x / scale).round().clamp( - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max).to(torch.int8) + out1 = (x / scale).round().clamp(int8_traits.min, + int8_traits.max).to(torch.int8) out2 = torch.empty_like(x, dtype=torch.int8) scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c8f6150d1ee2d..7cb5ce42cdcc0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Type, Union +from typing import Optional, Tuple, Type import torch @@ -264,8 +264,10 @@ def scaled_fp8_quant( # int8 -def scaled_int8_quant(input: torch.Tensor, - scale: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: +def scaled_int8_quant( 
+ input: torch.Tensor, + scale: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize the input tensor to int8 and return the quantized tensor and scale. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py index d8317f2137451..25b707caeef33 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py @@ -48,8 +48,7 @@ def create_weights(self, layer: torch.nn.Module, weight_scale_dim = sum( output_partition_sizes) if is_tensor_partitioned else 1 - weight_zero_point = Parameter(torch.empty(1, - dtype=torch.int8), + weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8), requires_grad=False) weight_scale = Parameter(torch.empty(weight_scale_dim, From f36519b515b5d192ca3bfe94b880e3806ea0d90a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 6 Jun 2024 14:49:31 +0000 Subject: [PATCH 67/73] remove compressed tensors --- requirements-common.txt | 1 - .../compressed_tensors/compressed_tensors.py | 6 +- .../quantization/compressed_tensors/utils.py | 96 +++++++++++++++++++ 3 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/utils.py diff --git a/requirements-common.txt b/requirements-common.txt index d1651ec3076b5..f41873570aa67 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,4 +20,3 @@ lm-format-enforcer == 0.10.1 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 -compressed-tensors == 0.3.3 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 9af60360d3e5a..db59dfbc8792d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,10 +1,6 @@ from typing import Any, Dict, List, Optional import torch -from compressed_tensors.quantization.lifecycle.apply import ( - find_first_name_or_class_match) -from compressed_tensors.quantization.quant_args import (QuantizationArgs, - QuantizationStrategy) from pydantic import BaseModel from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase @@ -13,6 +9,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match) class CompressedTensorsConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py new file mode 100644 index 0000000000000..01440cb5c6d88 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -0,0 +1,96 @@ +import re +from torch.nn import Module +from enum import Enum +from typing import Any, Dict, Iterable, Optional + +from pydantic import BaseModel, Field + +class QuantizationType(str, 
Enum): + """ + Enum storing quantization type options + """ + + INT = "int" + FLOAT = "float" + +class QuantizationStrategy(str, Enum): + """ + Enum storing quantization strategy options + """ + + TENSOR = "tensor" + CHANNEL = "channel" + GROUP = "group" + BLOCK = "block" + TOKEN = "token" + +class QuantizationArgs(BaseModel): + """ + User facing arguments used to define a quantization config for weights or + activations + + :param num_bits: quantization bit depth + :param type: dtype to quantized to, either int or float + :param symmetric: whether or not quantization scale is symmetric about zero-point + :param strategy: string id determining the scope of scale/zero-point to apply + :param group_size: group length to use for the group strategy + :param block_structure: 2d block structure to use for the block strategy, must be + of the format "2x4", "8x16", etc. + :param dynamic: set True to perform dynamic quantization - values will not be + calibrated during calibration phase, instead during inference new quantization + ranges will be observed with every sample. Defaults to False for static + quantization. Note that enabling dynamic quantization will change the default + observer to a memoryless one + """ + + num_bits: int = 8 + type: QuantizationType = QuantizationType.INT + symmetric: bool = True + group_size: Optional[int] = None + strategy: Optional[QuantizationStrategy] = None + block_structure: Optional[str] = None + dynamic: bool = False + observer: str = Field( + default="minmax", + description=( + "The class to use to compute the quantization param - " + "scale and zero-point'" + ), + ) + observer_kwargs: Dict[str, Any] = Field( + default_factory=dict, + description=( + "optional dict of kwargs to be passed directly to torch quantization " + "Observers constructor excluding quantization range or symmetry" + ), + ) + + +def find_first_name_or_class_match( + name: str, module: Module, targets: Iterable[str], check_contains: bool = False +) -> Optional[str]: + # first element of targets that matches the given name + # if no name matches returns first target that matches the class name + # returns None otherwise + return _find_first_match(name, targets) or _find_first_match( + module.__class__.__name__, targets, check_contains + ) + +def _find_first_match( + value: str, targets: Iterable[str], check_contains: bool = False +) -> Optional[str]: + # returns first element of target that matches value either + # exactly or as a regex after 're:'. if check_contains is set to True, + # additionally checks if the target string is contained with value. 
+ + for target in targets: + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return target + elif check_contains: + if target.lower() in value.lower(): + return target + elif target == value: + return target + return None \ No newline at end of file From 2c6e5804f2267a5bc3dc1198f2703f7c37cd71ee Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 6 Jun 2024 14:56:43 +0000 Subject: [PATCH 68/73] format fix --- .../quantization/compressed_tensors/utils.py | 59 ++++++++++--------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 01440cb5c6d88..2aff5aa79812f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -1,9 +1,10 @@ import re -from torch.nn import Module from enum import Enum from typing import Any, Dict, Iterable, Optional from pydantic import BaseModel, Field +from torch.nn import Module + class QuantizationType(str, Enum): """ @@ -13,6 +14,7 @@ class QuantizationType(str, Enum): INT = "int" FLOAT = "float" + class QuantizationStrategy(str, Enum): """ Enum storing quantization strategy options @@ -24,23 +26,25 @@ class QuantizationStrategy(str, Enum): BLOCK = "block" TOKEN = "token" + class QuantizationArgs(BaseModel): """ - User facing arguments used to define a quantization config for weights or - activations + User facing arguments used to define a quantization config + for weights or activations :param num_bits: quantization bit depth :param type: dtype to quantized to, either int or float - :param symmetric: whether or not quantization scale is symmetric about zero-point - :param strategy: string id determining the scope of scale/zero-point to apply + :param symmetric: whether or not quantization scale is symmetric + :param strategy: string determining the scope of scale/zero-point to apply :param group_size: group length to use for the group strategy - :param block_structure: 2d block structure to use for the block strategy, must be - of the format "2x4", "8x16", etc. - :param dynamic: set True to perform dynamic quantization - values will not be - calibrated during calibration phase, instead during inference new quantization - ranges will be observed with every sample. Defaults to False for static - quantization. Note that enabling dynamic quantization will change the default - observer to a memoryless one + :param block_structure: 2d block structure to use for the block + strategy, must be of the format "2x4", "8x16", etc. + :param dynamic: set True to perform dynamic quantization - + values will not be calibrated during calibration phase, + instead during inference new quantization ranges will be + observed with every sample. Defaults to False for static + quantization. 
Note that enabling dynamic quantization + will change the default observer to a memoryless one """ num_bits: int = 8 @@ -52,33 +56,32 @@ class QuantizationArgs(BaseModel): dynamic: bool = False observer: str = Field( default="minmax", - description=( - "The class to use to compute the quantization param - " - "scale and zero-point'" - ), + description=("The class to use to compute the quantization param - " + "scale and zero-point'"), ) observer_kwargs: Dict[str, Any] = Field( default_factory=dict, - description=( - "optional dict of kwargs to be passed directly to torch quantization " - "Observers constructor excluding quantization range or symmetry" - ), + description= + ("optional dict of kwargs to be passed directly to torch quantization " + "Observers constructor excluding quantization range or symmetry"), ) def find_first_name_or_class_match( - name: str, module: Module, targets: Iterable[str], check_contains: bool = False -) -> Optional[str]: + name: str, + module: Module, + targets: Iterable[str], + check_contains: bool = False) -> Optional[str]: # first element of targets that matches the given name # if no name matches returns first target that matches the class name # returns None otherwise return _find_first_match(name, targets) or _find_first_match( - module.__class__.__name__, targets, check_contains - ) + module.__class__.__name__, targets, check_contains) + -def _find_first_match( - value: str, targets: Iterable[str], check_contains: bool = False -) -> Optional[str]: +def _find_first_match(value: str, + targets: Iterable[str], + check_contains: bool = False) -> Optional[str]: # returns first element of target that matches value either # exactly or as a regex after 're:'. if check_contains is set to True, # additionally checks if the target string is contained with value. 
@@ -93,4 +96,4 @@ def _find_first_match( return target elif target == value: return target - return None \ No newline at end of file + return None From b3d692a983b6481828eac0e8bb6211caeef9c91a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 6 Jun 2024 16:02:03 +0000 Subject: [PATCH 69/73] add comments; some clean-up --- .../compressed_tensors/compressed_tensors.py | 19 ------------- .../quantization/compressed_tensors/utils.py | 27 ++++++++++++++----- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index db59dfbc8792d..dde7acbe81eac 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -19,21 +19,6 @@ def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): self.ignore = ignore self.layer_quant_details = layer_quant_details - llama_mapping = { - "q_proj": "qkv_proj", - "k_proj": "qkv_proj", - "v_proj": "qkv_proj", - "gate_proj": "gate_up_proj", - "up_proj": "gate_up_proj" - } - - # Update the ignore list: e.g layers with q_proj are replaced - # to be qkv_proj to be compatible with vllm - for layer in self.ignore: - for k in llama_mapping: - if k in layer: - layer.replace(k, llama_mapping.get(k, k)) - def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -59,7 +44,6 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": - layer_quant_details: Dict[str, Any] = dict() ignore: List[str] = config.get("ignore", None) @@ -118,8 +102,6 @@ def _get_schema(self, weight_quant: BaseModel, def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": - # TODO: update/map layer_name for llama models before - # using find_first_name_or_class_match? layer_type_name = find_first_name_or_class_match( name="", module=layer, @@ -131,7 +113,6 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( layer_type_name, None) - if layer_quant_details is None: raise ValueError( f"Could not find quantization details for {layer}.") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 2aff5aa79812f..fcc6649101845 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -72,9 +72,18 @@ def find_first_name_or_class_match( module: Module, targets: Iterable[str], check_contains: bool = False) -> Optional[str]: - # first element of targets that matches the given name - # if no name matches returns first target that matches the class name - # returns None otherwise + """ + Helper function to map the quantization details listed in the config + for a given list of targets against each model layer. First uses the + layer name to try and find a match. If no name match is found, uses + the layer class name. Returns None otherwise. 
+ + :param name: layer name + :param module: torch.nn.Module + :param targets: list of targets to match the layer against + :param check_contains: whether or not to do a substring match + """ + return _find_first_match(name, targets) or _find_first_match( module.__class__.__name__, targets, check_contains) @@ -82,9 +91,15 @@ def find_first_name_or_class_match( def _find_first_match(value: str, targets: Iterable[str], check_contains: bool = False) -> Optional[str]: - # returns first element of target that matches value either - # exactly or as a regex after 're:'. if check_contains is set to True, - # additionally checks if the target string is contained with value. + """ + Returns first element of target that matches value either + exactly or as a regex after 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + + :param value: string to compare the list of targets against + :param targets: list of targets to match the layer against + :param check_contains: whether or not to do a substring match + """ for target in targets: if target.startswith("re:"): From f3bf9e347385797dae961f81d8c53dac3f55d8ed Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 6 Jun 2024 16:51:14 +0000 Subject: [PATCH 70/73] review comments --- vllm/_custom_ops.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7cb5ce42cdcc0..ddcd132079e30 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -279,18 +279,18 @@ def scaled_int8_quant( Returns: Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales. """ - q = torch.empty_like(input, dtype=torch.int8) + output = torch.empty_like(input, dtype=torch.int8) if scale is not None: - # Static-per-tensor quantization. - vllm_ops.static_scaled_int8_quant(q, input, scale) - return q, scale + # static-per-tensor quantization. + vllm_ops.static_scaled_int8_quant(output, input, scale) + return output, scale - # Dynamic-per-token quantization. + # dynamic-per-token quantization. 
input_scales = torch.empty((input.numel() // input.shape[-1], 1), - dtype=torch.float32, - device="cuda") - vllm_ops.dynamic_scaled_int8_quant(q, input, input_scales) - return q, input_scales + device=input.device, + dtype=torch.float32) + vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales) + return output, input_scales # moe From 2bd62e09bbe6c0d2c0d251758108677ebdf562f9 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 6 Jun 2024 19:41:10 +0000 Subject: [PATCH 71/73] review comments and const correctness --- csrc/ops.h | 2 +- .../compressed_tensors/int8_quant_kernels.cu | 64 +++++++++---------- tests/kernels/test_int8_quant.py | 2 +- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 873de27b799b6..06b60e748886f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -97,7 +97,7 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale); -void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, +void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales); void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index c7644dfa0350c..1ca7e90d044ff 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -28,46 +28,45 @@ namespace vllm { template __global__ void static_scaled_int8_quant_kernel( - const scalar_t* __restrict__ input, int8_t* __restrict__ out, - const scale_type* scale_ptr, const int hidden_size) { - const int tid = threadIdx.x; - const int token_idx = blockIdx.x; - scale_type scale = *scale_ptr; + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; for (int i = tid; i < hidden_size; i += blockDim.x) { out[token_idx * hidden_size + i] = - float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + float_to_int8_rn(static_cast(input[token_idx * hidden_size + i]) / scale); } } template __global__ void dynamic_scaled_int8_quant_kernel( - const scalar_t* __restrict__ input, int8_t* __restrict__ out, - scale_type scale, const int hidden_size) { - const int tid = threadIdx.x; - const int token_idx = blockIdx.x; - - float amax_val = 0.0f; - const float zero = 0.0f; + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + float absmax_val = 0.0f; + float const zero = 0.0f; for (int i = tid; i < hidden_size; i += blockDim.x) { - float val = (float)input[token_idx * hidden_size + i]; + float val = static_cast(input[token_idx * hidden_size + i]); val = val > zero ? val : -val; - if (val > amax_val) amax_val = val; + absmax_val = val > absmax_val ? 
val : absmax_val; } - __shared__ float s_amax; - const float block_amax_val = blockReduceMax(amax_val); + float const block_absmax_val_maybe = blockReduceMax(absmax_val); + __shared__ float block_absmax_val; if (tid == 0) { - s_amax = block_amax_val; - scale[token_idx] = block_amax_val / 127.0f; + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / 127.0f; } __syncthreads(); - float tmp_scale = 127.0f / s_amax; + float const tmp_scale = 127.0f / block_absmax_val; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - ((float)input[token_idx * hidden_size + i]) * tmp_scale); + out[token_idx * hidden_size + i] = + float_to_int8_rn(static_cast(input[token_idx * hidden_size + i]) * tmp_scale); } } @@ -80,10 +79,10 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); - int hidden_size = input.size(-1); - int num_tokens = input.numel() / hidden_size; - dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { @@ -95,18 +94,19 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] } void dynamic_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] torch::Tensor& scales) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); - int hidden_size = input.size(-1); - int num_tokens = input.numel() / hidden_size; - dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel + vllm::dynamic_scaled_int8_quant_kernel <<>>(input.data_ptr(), out.data_ptr(), scales.data_ptr(), hidden_size); diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 0f1df66322eff..c474c0f3fed8c 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -4,7 +4,7 @@ from vllm._C import ops DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing SEEDS = [0] SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] From 460f514b15ac4591f1c0b0a52dd3512f0abe4014 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 6 Jun 2024 19:42:35 +0000 Subject: [PATCH 72/73] format.sh --- .../compressed_tensors/int8_quant_kernels.cu | 19 ++++++++++--------- tests/kernels/test_int8_quant.py | 3 ++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 1ca7e90d044ff..280b0327111da 100644 --- 
a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -35,14 +35,14 @@ __global__ void static_scaled_int8_quant_kernel( scale_type const scale = *scale_ptr; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = - float_to_int8_rn(static_cast(input[token_idx * hidden_size + i]) / scale); + out[token_idx * hidden_size + i] = float_to_int8_rn( + static_cast(input[token_idx * hidden_size + i]) / scale); } } template __global__ void dynamic_scaled_int8_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type* scale, const int hidden_size) { int const tid = threadIdx.x; int const token_idx = blockIdx.x; @@ -65,8 +65,8 @@ __global__ void dynamic_scaled_int8_quant_kernel( float const tmp_scale = 127.0f / block_absmax_val; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = - float_to_int8_rn(static_cast(input[token_idx * hidden_size + i]) * tmp_scale); + out[token_idx * hidden_size + i] = float_to_int8_rn( + static_cast(input[token_idx * hidden_size + i]) * tmp_scale); } } @@ -80,7 +80,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(scale.numel() == 1); int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; + int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -93,9 +93,10 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] }); } -void dynamic_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales) { +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor& scales) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index c474c0f3fed8c..aab7af9d2cbf6 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -4,7 +4,8 @@ from vllm._C import ops DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, + 8193] # Arbitrary values for testing NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing SEEDS = [0] SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] From dfcd61a0d51d74c432ff5e8bb9a789acb1d427a8 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 7 Jun 2024 02:19:42 +0000 Subject: [PATCH 73/73] nit fixes --- .../compressed_tensors/compressed_tensors.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index dde7acbe81eac..d2b0ce0dbbf0b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -72,9 +72,7 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel, is_symmetric = weight_quant.symmetric and input_quant.symmetric is_static = not weight_quant.dynamic and not 
input_quant.dynamic - if is_8_bits and is_tensor and is_symmetric and is_static: - return True - return False + return is_8_bits and is_tensor and is_symmetric and is_static def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: @@ -86,16 +84,14 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, is_symmetric = weight_quant.symmetric and input_quant.symmetric is_dynamic = not weight_quant.dynamic and input_quant.dynamic - if is_8_bits and is_token_tensor and is_symmetric and is_dynamic: - return True - return False + return is_8_bits and is_token_tensor and is_symmetric and is_dynamic def _get_schema(self, weight_quant: BaseModel, input_quant: BaseModel) -> "CompressedTensorsScheme": if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8StaticTensor() - elif self._is_dynamic_token_w8a8(weight_quant, input_quant): + if self._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8DynamicToken() raise NotImplementedError("Scheme not supported.")
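
For quick reference, below is a minimal usage sketch (not part of the patch series itself) of the unified `scaled_int8_quant` helper as it stands after these commits: passing a float32 scale tensor selects the static per-tensor path, while omitting the scale selects dynamic per-token quantization with one scale per token. The input shape, dtype, and scale value are illustrative assumptions; the import path and return convention follow the diffs above.

```python
# Sketch only: exercises the unified scaled_int8_quant entry point on a CUDA
# build of vLLM that includes the kernels added in this series.
import torch

from vllm import _custom_ops as custom_ops

# Illustrative input; shape, dtype, and magnitude are arbitrary.
x = torch.rand(16, 1024, dtype=torch.float16, device="cuda") * 1000

# Static per-tensor path: a precomputed float32 scale tensor is passed in and
# returned unchanged alongside the int8 output.
static_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
x_q_static, scale_out = custom_ops.scaled_int8_quant(x, static_scale)

# Dynamic per-token path: no scale is passed, so one float32 scale per token
# is computed on the GPU and returned with the int8 output.
x_q_dynamic, token_scales = custom_ops.scaled_int8_quant(x)

assert x_q_static.dtype == torch.int8
assert token_scales.shape == (x.shape[0], 1)
```

The two schemes above consume these outputs the same way: the int8 tensor and its scales are handed to `custom_ops.cutlass_scaled_mm_dq` together with the int8 weight and weight scales, as shown in the `apply_weights` implementations in the diffs.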