diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index f4f651f06..06792ff20 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -30,7 +30,8 @@ from ..nn_modules.qlinear.qlinear_qbits import QBitsQuantLinear, qbits_dtype from ..quantization import GPTQ, QuantizeConfig from ..quantization.config import (FORMAT, FORMAT_FIELD_JSON, META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL, - MIN_VERSION_WITH_V2, QUANTIZE_BLACK_LIST, AutoRoundQuantizeConfig) + MIN_VERSION_WITH_V2, QUANTIZE_BLACK_LIST, AutoRoundQuantizeConfig, META_FIELD_URI, + META_VALUE_URI, META_FIELD_DAMP_PERCENT, META_FIELD_DAMP_AUTO_INCREMENT) from ..utils.backend import BACKEND from ..utils.data import collate_data from ..utils.device import check_cuda @@ -639,6 +640,21 @@ def save_quantized( version=__version__, ) + self.quantize_config.meta_set( + key=META_FIELD_URI, + value=META_VALUE_URI, + ) + + self.quantize_config.meta_set( + key=META_FIELD_DAMP_PERCENT, + value=self.quantize_config.damp_percent + ) + + self.quantize_config.meta_set( + key=META_FIELD_DAMP_AUTO_INCREMENT, + value=self.quantize_config.damp_auto_increment + ) + # The config, quantize_config and model may be edited in place in save_quantized. config = copy.deepcopy(self.model.config) quantize_config = copy.deepcopy(self.quantize_config) diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index c9ce71598..193d13d78 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -33,6 +33,12 @@ META_QUANTIZER_GPTQMODEL = "gptqmodel" +META_FIELD_URI = "uri" +META_VALUE_URI = "https://github.com/modelcloud/gptqmodel" + +META_FIELD_DAMP_PERCENT = "damp_percent" +META_FIELD_DAMP_AUTO_INCREMENT = "damp_auto_increment" + # pkg names PKG_AUTO_ROUND = "auto-round" @@ -338,8 +344,6 @@ def to_dict(self): "static_groups": self.static_groups, "sym": self.sym, "lm_head": self.lm_head, - "damp_percent": self.damp_percent, - "damp_auto_increment": self.damp_auto_increment, "true_sequential": self.true_sequential, # TODO: deprecate? "model_name_or_path": self.model_name_or_path,