Enhance woq model loading & support hf woq model loading
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
yuwenzho committed May 31, 2024
1 parent 855b988 commit 8f77f17
Showing 8 changed files with 529 additions and 31 deletions.
476 changes: 469 additions & 7 deletions neural_compressor/torch/algorithms/weight_only/save_load.py

Large diffs are not rendered by default.
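
The save_load.py diff carries the bulk of this change but is not rendered above. At a high level, the WOQ load() there now accepts a format argument and either restores an INC-saved checkpoint onto the caller's original model or builds a quantized model directly from a Hugging Face checkpoint. The following is only a sketch of that dispatch shape, not the actual implementation, with both branches left as placeholders:

def load(model_name_or_path, original_model=None, format="default", *hf_model_args, **hf_model_kwargs):
    """Sketch of the format dispatch added by this commit; real logic is in the unrendered diff."""
    if format == "huggingface":
        # Hypothetical branch: build a WOQ model directly from a Hugging Face
        # checkpoint (e.g. a GPTQ repo), forwarding hf_model_args/hf_model_kwargs.
        ...
    else:
        # Hypothetical branch: read qconfig.json next to the saved weights and
        # restore the quantized modules onto the caller's original_model.
        ...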

57 changes: 38 additions & 19 deletions neural_compressor/torch/quantization/load_entry.py
@@ -31,28 +31,47 @@
 }


-def load(output_dir="./saved_results", model=None):
-    from neural_compressor.common.base_config import ConfigRegistry
-
-    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json")
-    with open(qconfig_file_path, "r") as f:
-        per_op_qconfig = json.load(f)
-
-    if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
-        from neural_compressor.torch.algorithms.static_quant import load
-
-        return load(output_dir)
-    else:
-        config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
-        # select load function
-        config_object = config_mapping[next(iter(config_mapping))]
-        if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)):  # WOQ
-            from neural_compressor.torch.algorithms.weight_only.save_load import load
-
-            return load(output_dir)
-
-        model.qconfig = config_mapping
-        if isinstance(config_object, FP8Config):  # FP8
-            from neural_compressor.torch.algorithms.habana_fp8 import load
-
-            return load(model, output_dir)  # pylint: disable=E1121
+def load(model_name_or_path="./saved_results", model=None, format="default", *hf_model_args, **hf_model_kwargs):
+    """Load quantized model.
+
+    Args:
+        model_name_or_path (str, optional): local path where the quantized weights or model are saved,
+            or a Hugging Face model id. Defaults to "./saved_results".
+        model (torch.nn.Module, optional): original model. Required when loading an INC WOQ quantized model
+            or an FP8 model. Defaults to None.
+        format (str, optional): 'default' for loading an INC quantized model.
+            'huggingface' currently only loads Hugging Face WOQ causal language models. Defaults to "default".
+
+    Returns:
+        torch.nn.Module: quantized model
+    """
+    if format == "default":
+        from neural_compressor.common.base_config import ConfigRegistry
+        from neural_compressor.torch.algorithms.static_quant import load as static_quant_load
+        from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load
+        from neural_compressor.torch.algorithms.habana_fp8 import load as habana_fp8_load
+
+        qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json")
+        with open(qconfig_file_path, "r") as f:
+            per_op_qconfig = json.load(f)
+
+        if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
+            return static_quant_load(model_name_or_path)
+        else:
+            config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
+            # select load function
+            config_object = config_mapping[next(iter(config_mapping))]
+
+            if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)):  # WOQ
+                return woq_load(model_name_or_path, model=model, format=format)
+
+            model.qconfig = config_mapping
+            if isinstance(config_object, FP8Config):  # FP8
+                return habana_fp8_load(model, model_name_or_path)
+    elif format == "huggingface":
+        # currently only supports loading Hugging Face WOQ causal language models
+        from neural_compressor.torch.algorithms.weight_only.save_load import load as woq_load
+
+        return woq_load(model_name_or_path, format=format, *hf_model_args, **hf_model_kwargs)
+    else:
+        raise ValueError("`format` in load function can only be 'huggingface' or 'default', but got {}".format(format))
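
The new entry point therefore supports two call patterns, both exercised by the tests below. A short usage sketch (fp32_model stands in for the caller's original torch.nn.Module; the Hugging Face model id is the one used in the new test):

import copy
from neural_compressor.torch.quantization import load

# 1) INC-saved WOQ checkpoint: the original FP32 model must now be supplied.
woq_model = load("saved_results", model=copy.deepcopy(fp32_model))

# 2) Hugging Face WOQ causal LM, loaded directly from the Hub.
hf_woq_model = load("TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ", format="huggingface")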
2 changes: 1 addition & 1 deletion test/3x/torch/quantization/weight_only/test_autoround.py
@@ -148,7 +148,7 @@ def test_save_and_load(self):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.gptj))
         loaded_out = loaded_model(self.inp)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."

2 changes: 1 addition & 1 deletion test/3x/torch/quantization/weight_only/test_awq.py
@@ -131,7 +131,7 @@ def calib_func(model):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj))
         loaded_out = loaded_model(self.example_inputs)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
         assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed."
2 changes: 1 addition & 1 deletion test/3x/torch/quantization/weight_only/test_gptq.py
@@ -254,7 +254,7 @@ def test_save_and_load(self):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj))
         loaded_out = loaded_model(self.example_inputs)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
         assert isinstance(
17 changes: 17 additions & 0 deletions test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py
@@ -0,0 +1,17 @@
+import torch
+from transformers import AutoTokenizer
+from neural_compressor.torch.utils import accelerator
+
+device = accelerator.current_device_name()
+
+class TestHFModelLoad:
+    def setup_class(self):
+        self.model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ"
+        self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to(device)
+
+    def test_load_hf_woq_model(self):
+        from neural_compressor.torch.quantization import load
+
+        qmodel = load(self.model_name, format="huggingface")
+        output = qmodel(self.example_inputs)[0]
+        assert len(output) > 0, "Not loading the model correctly"
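
AutoTokenizer is imported in this new test but not yet used; a natural follow-up check is to decode a short generation from the loaded model. A hedged sketch, assuming the object returned by load() behaves like a standard transformers causal LM with a generate() method:

tokenizer = AutoTokenizer.from_pretrained("TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ")
inputs = tokenizer("def print_hello_world():", return_tensors="pt").to(device)
generated_ids = qmodel.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))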
2 changes: 1 addition & 1 deletion test/3x/torch/quantization/weight_only/test_rtn.py
@@ -290,7 +290,7 @@ def test_save_and_load(self):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.tiny_gptj))
         loaded_out = loaded_model(self.example_inputs)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
         assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed."
2 changes: 1 addition & 1 deletion test/3x/torch/quantization/weight_only/test_teq.py
@@ -141,7 +141,7 @@ def test_save_and_load(self):
         from neural_compressor.torch.quantization import load

         # loading compressed model
-        loaded_model = load("saved_results")
+        loaded_model = load("saved_results", model=copy.deepcopy(self.gptj))
         loaded_out = loaded_model(self.example_inputs)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."

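Every WOQ test above makes the same adjustment: for the default format, load() now receives the original FP32 module. For context, a hedged end-to-end sketch of the flow these tests exercise, assuming the 3.x prepare/convert API and the save() helper the framework attaches to converted models (the tiny checkpoint name is only illustrative):

import copy
from transformers import AutoModelForCausalLM
from neural_compressor.torch.quantization import RTNConfig, convert, load, prepare

fp32_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")

# Quantize with RTN (weight-only, no calibration needed) and save to disk.
q_model = convert(prepare(copy.deepcopy(fp32_model), quant_config=RTNConfig()))
q_model.save("saved_results")  # assumed helper attached to the converted model

# Reload: the original FP32 model is now required for the default format.
reloaded = load("saved_results", model=copy.deepcopy(fp32_model))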
