From c0a41fdf345e7bcf5bff70a8c5f7f465689d8165 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Date: Wed, 4 Sep 2024 16:35:04 -0400
Subject: [PATCH] fnuz support for fbgemm fp8 (#169)

* fnuz support for fbgemm fp8
---
 vllm/config.py                                      |  3 ++-
 .../layers/quantization/fbgemm_fp8.py               | 16 ++++++++++++++--
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 854c36034e107..f967cdeb78a2d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -269,7 +269,8 @@ def _parse_quant_hf_config(self):
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = [
-            "awq", "gptq", "squeezellm", "fp8", "compressed-tensors"
+            "awq", "gptq", "squeezellm", "fp8", "compressed-tensors",
+            "fbgemm_fp8"
         ]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index e7c3859967c71..4aaa02e5e3972 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -15,9 +15,11 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, create_per_channel_scale_param)
+    apply_fp8_linear, create_per_channel_scale_param,
+    normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -119,8 +121,18 @@ def create_weights(
 
     def process_weights_after_loading(self, layer: Module) -> None:
         weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
 
+        if is_hip():
+            weight, weight_scale, input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=None)
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        layer.weight = Parameter(weight.t(), requires_grad=False)
         if self.quant_config.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
             # Activations not quantized for marlin.
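
Note on the conversion used above: normalize_e4m3fn_to_e4m3fnuz converts OCP float8_e4m3fn checkpoints to the float8_e4m3fnuz format used on ROCm. The sketch below illustrates what such a conversion involves; it is written under two assumptions about the formats (the e4m3fn bit pattern 0x80, i.e. -0.0, encodes NaN in e4m3fnuz, and the same bit pattern represents half the value in e4m3fnuz because the exponent bias differs by one), and the helper name with the _sketch suffix is illustrative, not a copy of vLLM's implementation.

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Illustrative e4m3fn -> e4m3fnuz conversion for ROCm (sketch only)."""
    assert weight.dtype == torch.float8_e4m3fn
    # Assumption: 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so remap those
    # elements to +0.0 before reinterpreting the bits.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # Assumption: for identical bits, an e4m3fnuz value is half of the e4m3fn
    # value, so scales are doubled to keep the dequantized weights unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale

In the patched process_weights_after_loading, the converted weight and the doubled weight_scale replace the layer's parameters before the weight is transposed, which is why the is_hip() branch runs ahead of the Parameter(weight.t(), ...) assignment.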