From b1ce77c3b93cfcca6d3a260f7572e8df13312ade Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Date: Sat, 21 Dec 2024 20:02:37 +0800
Subject: [PATCH 1/7] [FIX] import Qwen2VLProcessor error (#951)

* fix import error
* fix test_qwen2_vl.py
---
 gptqmodel/models/definitions/qwen2_vl.py | 4 ++--
 tests/models/test_qwen2_vl.py            | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/gptqmodel/models/definitions/qwen2_vl.py b/gptqmodel/models/definitions/qwen2_vl.py
index 3f0d64c5..000ec0dc 100644
--- a/gptqmodel/models/definitions/qwen2_vl.py
+++ b/gptqmodel/models/definitions/qwen2_vl.py
@@ -1,7 +1,7 @@
 from typing import Dict, Optional

 from PIL import Image
-from transformers import AutoModelForVision2Seq, Qwen2VLProcessor
+from transformers import AutoModelForVision2Seq, AutoProcessor

 from ..base import BaseGPTQModel
 from ...utils.calibration import batched
@@ -82,7 +82,7 @@ def prepare_dataset(
             calibration_dataset,
             batch_size: int = 1,
             tokenizer=None, ):
-        processor = Qwen2VLProcessor.from_pretrained(self.model_id_or_path)
+        processor = AutoProcessor.from_pretrained(self.model_id_or_path)
         calib_data = []
         for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset):
             text = processor.apply_chat_template(
diff --git a/tests/models/test_qwen2_vl.py b/tests/models/test_qwen2_vl.py
index 9c892cfc..8c0d0f9d 100644
--- a/tests/models/test_qwen2_vl.py
+++ b/tests/models/test_qwen2_vl.py
@@ -33,7 +33,7 @@ def test_qwen2_vl(self):
             messages, tokenize=False, add_generation_prompt=True
         )

-        image_inputs, video_inputs = Qwen2VLGPTQ.process_vision_info(messages)
+        image_inputs = Qwen2VLGPTQ.process_vision_info(messages)
         inputs = processor(
             text=[text],
             images=image_inputs,

From fdcf53bba3788485bb27a02ef6564727b349b4d2 Mon Sep 17 00:00:00 2001
From: CSY-ModelCloud
Date: Sun, 22 Dec 2024 08:20:34 +0800
Subject: [PATCH 2/7] [CI] move some tests to torch2 5 (#952)

* move to torch 2.5 test
* Update unit_tests.yml
---
 .github/workflows/unit_tests.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
index 80d25503..144baa1b 100644
--- a/.github/workflows/unit_tests.yml
+++ b/.github/workflows/unit_tests.yml
@@ -44,7 +44,7 @@ env:
   MAX_JOBS: 8
   RUNNER: 10.0.14.248
   TRANSFORMERS_DIFF_TESTS: "models/test_internlm,models/test_internlm2_5,models/test_xverse"
-  TORCH_2_5_TESTS: "test_q4_ipex.py,test_ipex_xpu.py,test_save_loaded_quantized_model,test_quant_formats,models/test_hymba"
+  TORCH_2_5_TESTS: "test_evalplus,test_perplexity,test_q4_ipex.py,test_ipex_xpu.py,test_save_loaded_quantized_model,test_quant_formats,models/test_hymba"
   IGNORED_TEST_FILES: "test_tgi.py,test_gptneox.py,models/test_mixtral"
   GPTQMODEL_FORCE_BUILD: 1
   repo: ${{ github.event.inputs.repo || github.repository }}

From 748a9c760c168f0aa6813651ed871745fadd36dd Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Date: Mon, 23 Dec 2024 12:30:36 +0800
Subject: [PATCH 3/7] [FIX] vl model test (#953)

* Use quant_override_files["preprocessor_config.json"] to process input data
* qwen_vl use sample size 1
* add debug log
* Revert "add debug log"

This reverts commit 105b9e692b31bacd31372df9beaeed9dc14a43cf.

* When calling OvisModel.generate(), you need to pass in max_new_tokens.
* cleanup
---
 gptqmodel/models/base.py                   |  2 +-
 gptqmodel/models/definitions/qwen2_vl.py   | 19 +++++++++++++++++--
 tests/models/ovis/image_to_test_dataset.py |  2 +-
 3 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index c5de0c65..2833f8d3 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -466,7 +466,7 @@ def store_input_hook(_, args, kwargs):
                     example[k] = move_to(v, cur_layer_device)
             try:
                 if is_ovis:
-                    self.generate(inputs=example.pop("input_ids"), **example)
+                    self.generate(inputs=example.pop("input_ids"), max_new_tokens=1024, **example)
                 else:
                     self.model(**example)
             except ValueError:
diff --git a/gptqmodel/models/definitions/qwen2_vl.py b/gptqmodel/models/definitions/qwen2_vl.py
index 000ec0dc..b5ffc270 100644
--- a/gptqmodel/models/definitions/qwen2_vl.py
+++ b/gptqmodel/models/definitions/qwen2_vl.py
@@ -1,7 +1,9 @@
+import os.path
+import shutil
 from typing import Dict, Optional

 from PIL import Image
-from transformers import AutoModelForVision2Seq, AutoProcessor
+from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer

 from ..base import BaseGPTQModel
 from ...utils.calibration import batched
@@ -82,7 +84,20 @@ def prepare_dataset(
             calibration_dataset,
             batch_size: int = 1,
             tokenizer=None, ):
-        processor = AutoProcessor.from_pretrained(self.model_id_or_path)
+        import tempfile
+        import json
+
+        if tokenizer is None:
+            tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            chat_template_file = os.path.join(self.model_id_or_path, "chat_template.json")
+            if os.path.exists(chat_template_file):
+                shutil.copyfile(chat_template_file, os.path.join(tmp_dir, "chat_template.json"))
+            tokenizer.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "preprocessor_config.json"), "w") as f:
+                f.write(json.dumps(self.quant_override_files["preprocessor_config.json"]))
+            processor = AutoProcessor.from_pretrained(tmp_dir)
         calib_data = []
         for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset):
             text = processor.apply_chat_template(
diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py
index 22645c4e..bc0eccec 100644
--- a/tests/models/ovis/image_to_test_dataset.py
+++ b/tests/models/ovis/image_to_test_dataset.py
@@ -47,6 +47,6 @@ def get_calib_dataset(model):
         return prepare_dataset(format_ovis_dataset, n_sample=20)

     if isinstance(model, Qwen2VLGPTQ):
-        return prepare_dataset(format_qwen2_vl_dataset, n_sample=20)
+        return prepare_dataset(format_qwen2_vl_dataset, n_sample=1)

     raise NotImplementedError(f"Unsupported MODEL: {model.__class__}")

From 6152c90d934520eb196a64a3f2143d49442b2c8e Mon Sep 17 00:00:00 2001
From: CSY-ModelCloud
Date: Mon, 23 Dec 2024 17:29:44 +0800
Subject: [PATCH 4/7] filter torch cuda arch < 6.0 (#955)

* filter arch < 6.0
* remove unused codes
---
 setup.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/setup.py b/setup.py
index c5435182..395232ba 100644
--- a/setup.py
+++ b/setup.py
@@ -17,6 +17,10 @@

 TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST")

+if TORCH_CUDA_ARCH_LIST:
+    arch_list = [arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 6.0]
+    os.environ["TORCH_CUDA_ARCH_LIST"] = " ".join(arch_list)
+
 version_vars = {}
 exec("exec(open('gptqmodel/version.py').read()); version=__version__", {}, version_vars)
 gptqmodel_version = version_vars['version']
@@ -109,6 +113,7 @@ def get_version_tag(is_cuda_release: bool = True) -> str:

 if got_cuda_between_v6_and_v8:
     FORCE_BUILD = True
+
 if BUILD_CUDA_EXT:
     if CUDA_RELEASE == "1":
         common_setup_kwargs["version"] += f"+{get_version_tag(True)}"

From 1d1d93e3b6263b53643e265521ee8c9cae8a12b7 Mon Sep 17 00:00:00 2001
From: CSY-ModelCloud
Date: Mon, 23 Dec 2024 17:46:42 +0800
Subject: [PATCH 5/7] print filtered arch warning (#957)

* print arch warning
* if list changed, print log
* move to one line
---
 setup.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 395232ba..6efe8a72 100644
--- a/setup.py
+++ b/setup.py
@@ -18,8 +18,10 @@
 TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST")

 if TORCH_CUDA_ARCH_LIST:
-    arch_list = [arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 6.0]
-    os.environ["TORCH_CUDA_ARCH_LIST"] = " ".join(arch_list)
+    arch_list = " ".join([arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 6.0 or print(f"we do not support this compute arch: {arch}, skipped.") is not None])
+    if arch_list != TORCH_CUDA_ARCH_LIST:
+        os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list
+        print(f"TORCH_CUDA_ARCH_LIST has been updated to '{arch_list}'")

 version_vars = {}
 exec("exec(open('gptqmodel/version.py').read()); version=__version__", {}, version_vars)
 gptqmodel_version = version_vars['version']

From 883a52a18360bf4af45a9e778228d0fd2c01fa38 Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Mon, 23 Dec 2024 17:55:12 +0800
Subject: [PATCH 6/7] prepare for 1.5.0 (#958)

* prepare for 1.5.0
* Update version.py
* Update README.md
---
 README.md            | 1 +
 gptqmodel/version.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index cc26a19c..ab166810 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@

 ## News
+* 12/23/2024 [1.5.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.5.0): Multi-modal (image-to-text) optimized quantization support added for Qwen 2-VL and Ovis 1.6-VL. Previously image-to-text model quantization was not using image calibration data and post-quant result was less than optimal. 1.5.0 is the first release to release a stable path for multi-modal quantization: note only text layers are quantized.
 * 12/19/2024 [1.4.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.5): Windows 11 support added/validated. Ovis VL model support with image dataset calibration. Fixed `dynamic` loading. Reduced quantization vram usage.
 * 12/15/2024 [1.4.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.2): MacOS `gpu` (Metal) and `cpu` (M+) support added/validated for inference and quantization. Cohere 2 model support added.
 * 12/13/2024 [1.4.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.1): Added Qwen2-VL model support. `mse` quantization control exposed in `QuantizeConfig`. Monkey patch `patch_vllm()` and `patch_hf()` api added to allow Transformers/Optimum/PEFT and vLLM to correctly loaded GPTQModel quantized models while upstream PRs are in pending status.
diff --git a/gptqmodel/version.py b/gptqmodel/version.py
index fa8ac7d6..5b601886 100644
--- a/gptqmodel/version.py
+++ b/gptqmodel/version.py
@@ -1 +1 @@
-__version__ = "1.4.6-dev"
+__version__ = "1.5.0"

From 21949f8cd3187a9dba1ee9e72c7315004476285f Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Date: Mon, 23 Dec 2024 18:07:59 +0800
Subject: [PATCH 7/7] [FIX] wrong filepath was used when model_id_or_path was hugging model id (#956)

* fix issue: the wrong filepath was used when the model_id_or_path was a hugging model id
* cleanup
* BaseModel removed "model_id_or_path"
---
 gptqmodel/models/base.py                 |  6 ++--
 gptqmodel/models/definitions/qwen2_vl.py |  4 +--
 gptqmodel/models/loader.py               | 45 ++++++++++++++++--------
 gptqmodel/models/writer.py               |  6 ++--
 tests/models/model_test.py               |  6 ++--
 5 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 2833f8d3..2845c339 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -101,7 +101,7 @@ def __init__(
         qlinear_kernel: nn.Module = None,
         load_quantized_model: bool = False,
         trust_remote_code: bool = False,
-        model_id_or_path: str = None,
+        model_local_path: str = None,
     ):
         super().__init__()

@@ -114,7 +114,7 @@ def __init__(
         # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
         self.qlinear_kernel = qlinear_kernel
         self.trust_remote_code = trust_remote_code
-        self.model_id_or_path = model_id_or_path
+        self.model_local_path = model_local_path

         # stores all per-layer quant stats such as avg loss and processing time
         self.quant_log = []
@@ -774,7 +774,7 @@ def save(
     ):
         extra_json_file_names = ["preprocessor_config.json", "chat_template.json"]
         for name in extra_json_file_names:
-            json_path = os.path.join(self.model_id_or_path, name)
+            json_path = os.path.join(self.model_local_path, name)
             if os.path.exists(json_path):
                 os.makedirs(save_dir, exist_ok=True)

diff --git a/gptqmodel/models/definitions/qwen2_vl.py b/gptqmodel/models/definitions/qwen2_vl.py
index b5ffc270..ec1a2fcd 100644
--- a/gptqmodel/models/definitions/qwen2_vl.py
+++ b/gptqmodel/models/definitions/qwen2_vl.py
@@ -88,10 +88,10 @@ def prepare_dataset(
         import json

         if tokenizer is None:
-            tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
+            tokenizer = AutoTokenizer.from_pretrained(self.model_local_path)

         with tempfile.TemporaryDirectory() as tmp_dir:
-            chat_template_file = os.path.join(self.model_id_or_path, "chat_template.json")
+            chat_template_file = os.path.join(self.model_local_path, "chat_template.json")
             if os.path.exists(chat_template_file):
                 shutil.copyfile(chat_template_file, os.path.join(tmp_dir, "chat_template.json"))
             tokenizer.save_pretrained(tmp_dir)
diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index 7b8aa501..d303e0b9 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from importlib.metadata import PackageNotFoundError, version
 from typing import Dict, List, Optional, Union

@@ -36,6 +37,7 @@
     verify_sharded_model_hashes,
 )
 from ._const import DEVICE, SUPPORTED_MODELS, normalize_device
+from huggingface_hub import snapshot_download

 logger = setup_logger()

@@ -73,7 +75,7 @@ def compare_versions(installed_version, required_version, operator):
         raise ValueError(f"Unsupported operator: {operator}")


-def check_versions(model_id_or_path: str, requirements: List[str]):
+def check_versions(model_class, requirements: List[str]):
     if requirements is None:
         return
     for req in requirements:
@@ -81,9 +83,17 @@
         try:
             installed_version = version(pkg)
             if not compare_versions(installed_version, version_required, operator):
-                raise ValueError(f"{model_id_or_path} requires version {req}, but current {pkg} version is {installed_version} ")
+                raise ValueError(f"{model_class} requires version {req}, but current {pkg} version is {installed_version} ")
         except PackageNotFoundError:
-            raise ValueError(f"{model_id_or_path} requires version {req}, but {pkg} not installed.")
+            raise ValueError(f"{model_class} requires version {req}, but {pkg} not installed.")
+
+
+def get_model_local_path(pretrained_model_id_or_path, **kwargs):
+    is_local = os.path.isdir(pretrained_model_id_or_path)
+    if is_local:
+        return pretrained_model_id_or_path
+    else:
+        return snapshot_download(pretrained_model_id_or_path, **kwargs)


 def ModelLoader(cls):
@@ -106,7 +116,9 @@ def from_pretrained(
                 f"{pretrained_model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
             )

-        check_versions(pretrained_model_id_or_path, cls.require_pkgs_version)
+        check_versions(cls, cls.require_pkgs_version)
+
+        model_local_path = get_model_local_path(pretrained_model_id_or_path, **model_init_kwargs)

         def skip(*args, **kwargs):
             pass
@@ -117,7 +129,7 @@ def skip(*args, **kwargs):

         model_init_kwargs["trust_remote_code"] = trust_remote_code

-        config = AutoConfig.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
+        config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs)

         if torch_dtype is None or torch_dtype == "auto":
             torch_dtype = auto_dtype_from_config(config)
@@ -130,7 +142,7 @@ def skip(*args, **kwargs):
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")

-        model = cls.loader.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
+        model = cls.loader.from_pretrained(model_local_path, **model_init_kwargs)

         model_config = model.config.to_dict()
         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions", "multimodal_max_length"]
@@ -149,7 +161,7 @@ def skip(*args, **kwargs):
             quantized=False,
             quantize_config=quantize_config,
             trust_remote_code=trust_remote_code,
-            model_id_or_path=pretrained_model_id_or_path
+            model_local_path=model_local_path,
         )

     cls.from_pretrained = from_pretrained
@@ -189,7 +201,9 @@ def from_quantized(
                 f"{model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
             )

-        check_versions(model_id_or_path, cls.require_pkgs_version)
+        check_versions(cls, cls.require_pkgs_version)
+
+        model_local_path = get_model_local_path(model_id_or_path, **kwargs)

         # Parameters related to loading from Hugging Face Hub
         cache_dir = kwargs.pop("cache_dir", None)
@@ -217,7 +231,7 @@
         # == step1: prepare configs and file names == #
         config: PretrainedConfig = AutoConfig.from_pretrained(
-            model_id_or_path,
+            model_local_path,
             trust_remote_code=trust_remote_code,
             **cached_file_kwargs,
         )
@@ -231,7 +245,7 @@
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")

-        quantize_config = QuantizeConfig.from_pretrained(model_id_or_path, **cached_file_kwargs, **kwargs)
+        quantize_config = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs)

         if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
             if quantize_config.format != FORMAT.GPTQ:
@@ -240,7 +254,7 @@
                 from ..utils.vllm import load_model_by_vllm, vllm_generate

                 model = load_model_by_vllm(
-                    model=model_id_or_path,
+                    model=model_local_path,
                     trust_remote_code=trust_remote_code,
                     **kwargs,
                 )
@@ -253,7 +267,7 @@
                 from ..utils.sglang import load_model_by_sglang, sglang_generate

                 model, hf_config = load_model_by_sglang(
-                    model=model_id_or_path,
+                    model=model_local_path,
                     trust_remote_code=trust_remote_code,
                     **kwargs,
                 )
@@ -264,6 +278,7 @@
                 quantized=True,
                 quantize_config=quantize_config,
                 qlinear_kernel=None,
+                model_local_path=model_local_path,
             )

         if quantize_config.format == FORMAT.MARLIN:
@@ -299,11 +314,11 @@

         extensions = [".safetensors"]

-        model_id_or_path = str(model_id_or_path)
+        model_local_path = str(model_local_path)

         # Retrieve (and if necessary download) the quantized checkpoint(s).
         is_sharded, resolved_archive_file, true_model_basename = get_checkpoints(
-            model_id_or_path=model_id_or_path,
+            model_id_or_path=model_local_path,
             extensions=extensions,
             possible_model_basenames=possible_model_basenames,
             **cached_file_kwargs,
@@ -529,7 +544,7 @@ def skip(*args, **kwargs):
             qlinear_kernel=qlinear_kernel,
             load_quantized_model=True,
             trust_remote_code=trust_remote_code,
-            model_id_or_path=model_id_or_path,
+            model_local_path=model_local_path,
         )

     cls.from_quantized = from_quantized
diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 50ae49ec..9de3f18a 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -84,7 +84,7 @@ def save_quantized(
                 w.writerows([[entry.get(QUANT_LOG_LAYER), entry.get(QUANT_LOG_MODULE), entry.get(QUANT_LOG_LOSS), entry.get(QUANT_LOG_DAMP), entry.get(QUANT_LOG_TIME)] for entry in self.quant_log])

-        pre_quantized_size_mb = get_model_files_size(self.model_id_or_path)
+        pre_quantized_size_mb = get_model_files_size(self.model_local_path)
         pre_quantized_size_gb = pre_quantized_size_mb / 1024

         quantizers = [f"{META_QUANTIZER_GPTQMODEL}:{__version__}"]
@@ -171,7 +171,7 @@ def save_quantized(
         else:
             model = self.get_model_with_quantize(
                 quantize_config=quantize_config,
-                model_id_or_path=self.model_id_or_path,
+                model_id_or_path=self.model_local_path,
             )
         model.to(CPU)

@@ -311,7 +311,7 @@ def save_quantized(

         # need to copy .py files for model/tokenizers not yet merged to HF transformers
         if self.trust_remote_code:
-            copy_py_files(save_dir, model_id_or_path=self.model_id_or_path)
+            copy_py_files(save_dir, model_id_or_path=self.model_local_path)

     cls.save_quantized = save_quantized

diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index e4f16eb9..10e2d44b 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -187,7 +187,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
         try:
             with tempfile.TemporaryDirectory() as tmp_dir:
                 if self.USE_VLLM:
-                    model_args = f"pretrained={model.model_id_or_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code},max_model_len={self.MODEL_MAX_LEN}"
+                    model_args = f"pretrained={model.model_local_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code},max_model_len={self.MODEL_MAX_LEN}"
                 else:
                     model_args = ""
                 results = lm_eval(
@@ -216,8 +216,8 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
                     if metric != 'alias' and 'stderr' not in metric
                 }
                 print(task_results)
-                if delete_quantized_model and model.model_id_or_path.startswith("/tmp") and os.path.exists(model.model_id_or_path):
-                    shutil.rmtree(model.model_id_or_path)
+                if delete_quantized_model and model.model_local_path.startswith("/tmp") and os.path.exists(model.model_local_path):
+                    shutil.rmtree(model.model_local_path)
                 return task_results
         except BaseException as e:
             if isinstance(e, torch.OutOfMemoryError):