From 6a458f0173af404a25332c49c4fe9a0127743649 Mon Sep 17 00:00:00 2001
From: "Cheng, Penghui"
Date: Sat, 9 Mar 2024 08:04:37 +0800
Subject: [PATCH] weight only quantization (#1349)

* Update weight only quantization config

Signed-off-by: Cheng Penghui
---
 .../quantization/run_generation_gpu_woq.py  | 9 ++++++---
 .../llm/quantization/utils.py               | 2 +-
 .../transformers/modeling/modeling_auto.py  | 7 +++++--
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
index de5e25f06cb..15d4cae794d 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
@@ -176,10 +176,12 @@
         args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
             if user_model is None else user_model
     user_model = user_model.to(memory_format=torch.channels_last)
+    if quantization_config is None:
+        quantization_config = WeightOnlyQuantConfig.from_pretrained(args.model)
     if not args.disable_optimize_transformers:
         print("Optimize with IPEX...")
         user_model = ipex.optimize_transformers(
-            user_model.eval(), device=args.device, inplace=True, woq=(hasattr(user_model, "quantization_config")), dtype=torch_dtype)
+            user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
     else:
         print("Disabled optimization with IPEX...")
     # start
@@ -263,10 +265,12 @@
     user_model = AutoModelForCausalLM.from_pretrained(
         args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
             if user_model is None else user_model
+    if quantization_config is None:
+        quantization_config = WeightOnlyQuantConfig.from_pretrained(args.model)
     if not args.disable_optimize_transformers:
         print("Optimize with IPEX...")
         user_model = ipex.optimize_transformers(
-            user_model.eval(), device=args.device, inplace=True, woq=(hasattr(user_model, "quantization_config")), dtype=torch_dtype)
+            user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
     else:
         print("Disabled optimization with IPEX...")
     results = evaluate(
@@ -287,4 +291,3 @@
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
         else:
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))
-
diff --git a/intel_extension_for_transformers/llm/quantization/utils.py b/intel_extension_for_transformers/llm/quantization/utils.py
index 36c200a324b..ad1d81db593 100644
--- a/intel_extension_for_transformers/llm/quantization/utils.py
+++ b/intel_extension_for_transformers/llm/quantization/utils.py
@@ -135,7 +135,7 @@ def _replace_linear(
                 )
             elif device == "xpu" or device == torch.device("xpu"):
                 from intel_extension_for_pytorch.nn.utils._quantize_convert \
-                    import WeightOnlyLinear as ipex_linear  # pylint: disable=E0401
+                    import WeightOnlyQuantizedLinear as ipex_linear  # pylint: disable=E0401
                 model._modules[name] = ipex_linear(
                     in_features,
                     out_features,
diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
index c707b9868ea..9d60e639943 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -70,9 +70,12 @@ def convert_model_to_public(model):
-    from intel_extension_for_pytorch.nn.utils._quantize_convert import WeightOnlyLinear  # pylint: disable=E0401
+    # pylint: disable=E0401
+    from intel_extension_for_pytorch.nn.utils._quantize_convert import(
+        WeightOnlyQuantizedLinear
+    )
     for name, module in model.named_modules():
-        if isinstance(module, WeightOnlyLinear):
+        if isinstance(module, WeightOnlyQuantizedLinear):
             if module.weight_transposed:
                 module.qweight.data = module.qweight.t_().contiguous()
                 module.scales.data = module.scales.t_().contiguous()
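
For reference, a minimal sketch of the call pattern these hunks switch to on XPU: the weight-only quantization config is loaded explicitly and handed to ipex.optimize_transformers via quantization_config= instead of the removed woq= boolean. The checkpoint path, device string, dtype, and the import location of WeightOnlyQuantConfig below are illustrative assumptions, not values taken from the patch.

    import torch
    import intel_extension_for_pytorch as ipex
    from intel_extension_for_transformers.transformers import (
        AutoModelForCausalLM,
        WeightOnlyQuantConfig,
    )

    model_path = "./saved_woq_model"  # hypothetical path to an already-quantized checkpoint
    user_model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, device_map="xpu", torch_dtype=torch.float16
    )

    # Rebuild the quant config from the checkpoint when the caller did not supply one,
    # mirroring the `if quantization_config is None:` branches added in run_generation_gpu_woq.py.
    quantization_config = WeightOnlyQuantConfig.from_pretrained(model_path)

    user_model = ipex.optimize_transformers(
        user_model.eval(),
        device="xpu",
        inplace=True,
        quantization_config=quantization_config,  # replaces the old woq=... flag
        dtype=torch.float16,
    )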