From e434c8861909f5dd336d6f017069f91470e1371f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 3 Sep 2024 17:21:44 -0400 Subject: [PATCH] [Misc] Update `GPTQ` to use `vLLMParameters` (#7976) --- tests/weight_loading/models.txt | 6 + tests/weight_loading/test_weight_loading.py | 7 +- vllm/model_executor/layers/linear.py | 25 +++-- .../layers/quantization/gptq.py | 103 ++++++++++-------- .../layers/vocab_parallel_embedding.py | 9 +- vllm/model_executor/parameter.py | 5 +- 6 files changed, 93 insertions(+), 62 deletions(-) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 7deb2880145ca..5eee2cc534445 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -4,6 +4,12 @@ gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main +gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main +gptq, TheBloke/Llama-2-7B-GPTQ, main +gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main +gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True +gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True +gptq, TechxGenus/gemma-1.1-2b-it-GPTQ, main compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py index c13313df93f66..d8bca05e204c0 100644 --- a/tests/weight_loading/test_weight_loading.py +++ b/tests/weight_loading/test_weight_loading.py @@ -1,5 +1,7 @@ import os +import torch + MAX_MODEL_LEN = 1024 MODEL_NAME = os.environ.get("MODEL_NAME", "robertgshaw2/zephyr-7b-beta-channelwise-gptq") @@ -8,9 +10,12 @@ def test_weight_loading(vllm_runner): + """ + Test parameter weight loading with tp>1. + """ with vllm_runner(model_name=MODEL_NAME, revision=REVISION, - dtype="auto", + dtype=torch.half if QUANTIZATION == "gptq" else "auto", quantization=QUANTIZATION, max_model_len=MAX_MODEL_LEN, tensor_parallel_size=2) as model: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1163cc727762d..8df1d7595f026 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -14,8 +14,10 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.parameter import (BasevLLMParameter, + PackedColumnParameter, PackedvLLMParameter, - PerTensorScaleParameter) + PerTensorScaleParameter, + RowvLLMParameter) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -24,7 +26,7 @@ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", - "TPUInt8LinearMethod" + "TPUInt8LinearMethod", "GPTQLinearMethod" ] @@ -574,8 +576,8 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, PackedvLLMParameter - ) and param.packed_dim == param.output_dim: + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: shard_size, shard_offset = \ param.adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset) @@ -594,9 +596,10 @@ def weight_loader_v2(self, param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) return - elif type(param) is BasevLLMParameter: + elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return @@ -724,8 +727,8 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, PackedvLLMParameter - ) and param.packed_dim == param.output_dim: + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: shard_size, shard_offset = \ param.adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset) @@ -741,12 +744,12 @@ def weight_loader_v2(self, loaded_shard_id: Optional[str] = None): if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): - param.load_merged_column_weight(loaded_weight=loaded_weight, - shard_id=0) + param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) return - elif type(param) is BasevLLMParameter: - param.load_merged_column_weight(loaded_weight=loaded_weight) + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_qkv_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f456286899a53..c067a76405df6 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -11,7 +11,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) class GPTQConfig(QuantizationConfig): @@ -108,6 +112,7 @@ def create_weights( **extra_weight_attrs, ): del output_size # Unused. + weight_loader = extra_weight_attrs.get("weight_loader") if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -138,73 +143,81 @@ def create_weights( scale_and_zero_size = input_size_per_partition // group_size scale_and_zero_input_dim = 0 - qweight = Parameter( - torch.empty( + qweight = PackedvLLMParameter( + data=torch.empty( input_size_per_partition // self.quant_config.pack_factor, output_size_per_partition, dtype=torch.int32, ), - requires_grad=False, - ) - set_weight_attrs( - qweight, { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, - }) - g_idx = Parameter( - torch.tensor( - [ - i // self.quant_config.group_size - for i in range(input_size_per_partition) - ], - dtype=torch.int32, - ), - requires_grad=False, - ) - # Ignore warning from fused linear layers such as QKVParallelLinear. - set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True}) - qzeros = Parameter( + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + g_idx = RowvLLMParameter(data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + qzeros_args = { + "data": torch.empty( scale_and_zero_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - requires_grad=False, - ) - set_weight_attrs( - qzeros, { - "input_dim": scale_and_zero_input_dim, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - }) - scales = Parameter( + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": torch.empty( scale_and_zero_size, output_size_per_partition, dtype=params_dtype, ), - requires_grad=False, - ) - set_weight_attrs(scales, { - "input_dim": scale_and_zero_input_dim, - "output_dim": 1, - }) + "weight_loader": + weight_loader + } + if scale_and_zero_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) layer.register_parameter("qweight", qweight) - set_weight_attrs(qweight, extra_weight_attrs) layer.register_parameter("g_idx", g_idx) - set_weight_attrs(g_idx, extra_weight_attrs) layer.register_parameter("qzeros", qzeros) - set_weight_attrs(qzeros, extra_weight_attrs) layer.register_parameter("scales", scales) - set_weight_attrs(scales, extra_weight_attrs) layer.exllama_state = exllama_state def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # for torch.compile + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass if layer.exllama_state == ExllamaState.UNINITIALIZED: diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b26a3227e6931..ef6d401be2070 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -10,6 +10,7 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -370,10 +371,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # If param packed on the same dim we are sharding on, then # need to adjust offsets of loaded weight by pack_factor. if packed_dim is not None and packed_dim == output_dim: + packed_factor = param.packed_factor if isinstance( + param, BasevLLMParameter) else param.pack_factor assert loaded_weight.shape[output_dim] == (self.org_vocab_size // - param.pack_factor) - start_idx = start_idx // param.pack_factor - shard_size = shard_size // param.pack_factor + param.packed_factor) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor else: assert loaded_weight.shape[output_dim] == self.org_vocab_size diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 326b6ae8fee64..9ffb339ffeab3 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -1,3 +1,4 @@ +from fractions import Fraction from typing import Callable, Optional, Union import torch @@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnvLLMParameter): """ def __init__(self, - packed_factor: int, + packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, **kwargs): @@ -298,7 +299,7 @@ class PackedvLLMParameter(ModelWeightParameter): """ def __init__(self, - packed_factor: int, + packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, **kwargs):