fnuz support for fbgemm fp8 (opendatahub-io#169)
* fnuz support for fbgemm fp8
gshtras authored Sep 4, 2024
1 parent 6d33657 commit c0a41fd
Showing 2 changed files with 16 additions and 3 deletions.
vllm/config.py (3 changes: 2 additions & 1 deletion)
@@ -269,7 +269,8 @@ def _parse_quant_hf_config(self):
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = [
-            "awq", "gptq", "squeezellm", "fp8", "compressed-tensors"
+            "awq", "gptq", "squeezellm", "fp8", "compressed-tensors",
+            "fbgemm_fp8"
         ]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
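With "fbgemm_fp8" added to rocm_supported_quantization, _verify_quantization no longer rejects FBGEMM-FP8 checkpoints when running on ROCm. A minimal usage sketch, assuming an FBGEMM-FP8 checkpoint such as Meta's FP8 Llama release (the model name below is only an example; the quantization method is read from the checkpoint's quantization_config rather than passed explicitly):

from vllm import LLM, SamplingParams

# Example FBGEMM-FP8 checkpoint; any checkpoint whose quantization_config
# declares "fbgemm_fp8" is handled the same way. On ROCm this now passes
# _verify_quantization() instead of raising an unsupported-quantization error.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct-FP8")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)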
vllm/model_executor/layers/quantization/fbgemm_fp8.py (16 changes: 14 additions & 2 deletions)
@@ -15,9 +15,11 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, create_per_channel_scale_param)
+    apply_fp8_linear, create_per_channel_scale_param,
+    normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -119,8 +121,18 @@ def create_weights(
 
     def process_weights_after_loading(self, layer: Module) -> None:
         weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
 
+        if is_hip():
+            weight, weight_scale, input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=None)
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        layer.weight = Parameter(weight.t(), requires_grad=False)
         if self.quant_config.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
             # Activations not quantized for marlin.
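For context, normalize_e4m3fn_to_e4m3fnuz converts weights stored in the OCP float8_e4m3fn format into float8_e4m3fnuz, the FP8 variant implemented by AMD MI300-series hardware. A minimal sketch of the conversion, assuming the usual relationship between the two formats (an illustration, not necessarily the exact vLLM implementation): the same bit pattern encodes half the value in e4m3fnuz because its exponent bias is 8 instead of 7, so the weight bits are reinterpreted while the scales are doubled, and the pattern 0x80 (negative zero in e4m3fn, NaN in e4m3fnuz) is cleared first.

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; clear it to +0.0
    # before reinterpreting the raw bits.
    bits = weight.view(torch.int8)
    bits[bits == -128] = 0
    fnuz_weight = bits.view(torch.float8_e4m3fnuz)
    # An identical bit pattern represents half the value in e4m3fnuz,
    # so double the scales to keep the dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return fnuz_weight, weight_scale, input_scale

In process_weights_after_loading above, the returned weight_scale replaces layer.weight_scale; layer.input_scale is only set when a scale is returned, and since input_scale is passed as None here, that branch does not fire.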
