diff --git a/auto_fp8/config.py b/auto_fp8/config.py
index d32e5a8..7f8dd95 100644
--- a/auto_fp8/config.py
+++ b/auto_fp8/config.py
@@ -2,6 +2,23 @@ class BaseQuantizeConfig:
+    """Configuration for model quantization.
+
+    Args:
+        quant_method: Type/precision of the quantization method to use.
+            At the moment, this is just "fp8", which specifically means
+            the fp8_e4m3 format in PyTorch.
+        activation_scheme: Choice of either "dynamic" or "static"
+            quantization of activations. If "static", calibration samples
+            are required during quantization to produce accurate per-tensor
+            scales for the activations of Linear modules.
+        ignore_patterns: List of patterns used to ignore layers. If a
+            string starts with "re:", everything after it is matched as a
+            Python regex via re.search() against each Linear layer name.
+            By default, "re:.*lm_head" is included to ignore the lm_head
+            Linear layer usually found at the end of decoder LLMs.
+    """
+
     def __init__(
         self,
         quant_method: str = "fp8",
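
For reference, here is a minimal sketch of how this config might be used, based only on the fields documented in the docstring above. The `should_ignore` helper is hypothetical, written purely to illustrate the "re:" matching semantics the docstring describes; the library's actual matching logic may differ (in particular, exact-match handling of plain strings is an assumption here).

```python
import re

from auto_fp8.config import BaseQuantizeConfig

# Static activation quantization: calibration samples will be needed
# during quantization to compute per-tensor activation scales for
# Linear modules.
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",                # fp8_e4m3 format in PyTorch
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head"],  # skip the final lm_head projection
)

# Hypothetical helper illustrating the documented "re:" semantics:
# everything after the "re:" prefix is applied with re.search() to each
# Linear layer name. Plain strings are assumed to require an exact match.
def should_ignore(layer_name: str, patterns: list[str]) -> bool:
    for pattern in patterns:
        if pattern.startswith("re:"):
            if re.search(pattern[len("re:"):], layer_name):
                return True
        elif pattern == layer_name:
            return True
    return False

assert should_ignore("model.lm_head", ["re:.*lm_head"])
assert not should_ignore("model.layers.0.mlp.down_proj", ["re:.*lm_head"])
```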