diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 1cad4e55f51ee..bbc01cb301e4b 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -23,7 +23,8 @@
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod"
 ]
 
 
diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py
index ae34e01497db4..be8235b468f68 100644
--- a/vllm/model_executor/layers/quantization/tpu_int8.py
+++ b/vllm/model_executor/layers/quantization/tpu_int8.py
@@ -7,7 +7,7 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import ModelWeightParameter
 
 ACTIVATION_SCHEMES = ["none"]
 
@@ -64,16 +64,16 @@ def create_weights(self, layer: Module, input_size_per_partition: int,
                        output_partition_sizes: List[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        weight = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition,
+            dtype=params_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
     def _quantize_weight(
             self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -92,6 +92,7 @@ def _quantize_weight(
         return qweight, qscale
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
         device = layer.weight.device
         qweight, qscale = self._quantize_weight(layer.weight)
         qweight = qweight.to(device)
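
To make the migration concrete, here is a minimal, self-contained sketch of the weight-loader-v2 pattern this diff adopts. `ShardedWeightParameter`, `default_weight_loader`, and `ToyLinearMethod` are hypothetical stand-ins written for this example, not vLLM's actual classes; only the constructor arguments `data`, `input_dim`, `output_dim`, and `weight_loader` mirror the `ModelWeightParameter` call in the diff.

```python
import torch
from torch.nn import Module, Parameter


class ShardedWeightParameter(Parameter):
    """Hypothetical stand-in for vLLM's ModelWeightParameter: a Parameter
    that carries shard metadata and a loader callback as attributes, so
    the model loader can ask the parameter itself how to fill it."""

    def __new__(cls, data, input_dim, output_dim, weight_loader):
        param = super().__new__(cls, data, requires_grad=False)
        param.input_dim = input_dim      # dim sharded under row parallelism
        param.output_dim = output_dim    # dim sharded under column parallelism
        param.weight_loader = weight_loader
        return param


def default_weight_loader(param, loaded_weight):
    # Trivial loader: copy the checkpoint tensor into the parameter.
    param.data.copy_(loaded_weight)


class ToyLinearMethod:
    """Hypothetical linear method following the create_weights /
    process_weights_after_loading shape used in the diff."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, params_dtype,
                       **extra_weight_attrs):
        weight_loader = extra_weight_attrs.get("weight_loader",
                                               default_weight_loader)
        weight = ShardedWeightParameter(
            data=torch.empty(sum(output_partition_sizes),
                             input_size_per_partition,
                             dtype=params_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer):
        # Narrow back to a plain Parameter once loading is done, so the
        # loader-only subclass does not leak into inference-time code
        # (e.g. quantization, which only needs the raw tensor data).
        layer.weight = Parameter(layer.weight.data, requires_grad=False)


layer = Module()
method = ToyLinearMethod()
method.create_weights(layer,
                      input_size_per_partition=8,
                      output_partition_sizes=[4, 4],
                      params_dtype=torch.float32)
layer.weight.weight_loader(layer.weight, torch.randn(8, 8))
method.process_weights_after_loading(layer)
assert type(layer.weight) is Parameter
```

The one-line addition to `process_weights_after_loading` in the diff plays the same role as the narrowing step above: after loading, the `ModelWeightParameter` is replaced with a plain `torch.nn.Parameter` holding the same data, so `_quantize_weight` receives an ordinary tensor rather than the loader-specific subclass.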