Commit

Merge branch 'refs/heads/main' into auto-device-dtype-kernel-fix
LRL-ModelCloud committed Dec 23, 2024
2 parents 8a637aa + 21949f8 commit 1cd77d3
Showing 11 changed files with 71 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
@@ -44,7 +44,7 @@ env:
MAX_JOBS: 8
RUNNER: 10.0.14.248
TRANSFORMERS_DIFF_TESTS: "models/test_internlm,models/test_internlm2_5,models/test_xverse"
TORCH_2_5_TESTS: "test_q4_ipex.py,test_ipex_xpu.py,test_save_loaded_quantized_model,test_quant_formats,models/test_hymba"
TORCH_2_5_TESTS: "test_evalplus,test_perplexity,test_q4_ipex.py,test_ipex_xpu.py,test_save_loaded_quantized_model,test_quant_formats,models/test_hymba"
IGNORED_TEST_FILES: "test_tgi.py,test_gptneox.py,models/test_mixtral"
GPTQMODEL_FORCE_BUILD: 1
repo: ${{ github.event.inputs.repo || github.repository }}
1 change: 1 addition & 0 deletions README.md
@@ -9,6 +9,7 @@
</p>

## News
* 12/23/2024 [1.5.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.5.0): Multi-modal (image-to-text) optimized quantization support added for Qwen 2-VL and Ovis 1.6-VL. Previously, image-to-text model quantization did not use image calibration data, so the post-quant result was less than optimal. 1.5.0 is the first release to provide a stable path for multi-modal quantization; note that only text layers are quantized.
* 12/19/2024 [1.4.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.5): Windows 11 support added/validated. Ovis VL model support with image dataset calibration. Fixed `dynamic` loading. Reduced quantization vram usage.
* 12/15/2024 [1.4.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.2): MacOS `gpu` (Metal) and `cpu` (M+) support added/validated for inference and quantization. Cohere 2 model support added.
* 12/13/2024 [1.4.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.1): Added Qwen2-VL model support. `mse` quantization control exposed in `QuantizeConfig`. Monkey patch `patch_vllm()` and `patch_hf()` api added to allow Transformers/Optimum/PEFT and vLLM to correctly loaded GPTQModel quantized models while upstream PRs are in pending status.
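For context, here is a minimal sketch of the multi-modal quantization path described in the 1.5.0 entry above, assuming the `from_pretrained` / `quantize` / `save` methods that appear elsewhere in this diff; the hub id and the calibration-record schema are illustrative assumptions rather than code from this commit:

```python
from gptqmodel import GPTQModel, QuantizeConfig

# Illustrative only: model id and calibration schema are assumptions, not taken from this commit.
calibration_dataset = [
    [  # one chat-formatted sample: an image plus a short reference caption
        {"role": "user", "content": [
            {"type": "image", "image": "file:///data/sample.png"},
            {"type": "text", "text": "Describe this image."},
        ]},
        {"role": "assistant", "content": "A cat sitting on a windowsill."},
    ],
]

model = GPTQModel.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",                          # illustrative hub id
    quantize_config=QuantizeConfig(bits=4, group_size=128),
)
model.quantize(calibration_dataset)   # only the text layers are quantized
model.save("Qwen2-VL-2B-Instruct-gptq-4bit")
```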
8 changes: 4 additions & 4 deletions gptqmodel/models/base.py
@@ -97,7 +97,7 @@ def __init__(
qlinear_kernel: nn.Module = None,
load_quantized_model: bool = False,
trust_remote_code: bool = False,
model_id_or_path: str = None,
model_local_path: str = None,
):
super().__init__()

@@ -110,7 +110,7 @@ def __init__(
# compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
self.qlinear_kernel = qlinear_kernel
self.trust_remote_code = trust_remote_code
self.model_id_or_path = model_id_or_path
self.model_local_path = model_local_path
# stores all per-layer quant stats such as avg loss and processing time
self.quant_log = []

@@ -440,7 +440,7 @@ def store_input_hook(_, args, kwargs):
example[k] = move_to(v, cur_layer_device)
try:
if is_ovis:
self.generate(inputs=example.pop("input_ids"), **example)
self.generate(inputs=example.pop("input_ids"), max_new_tokens=1024, **example)
else:
self.model(**example)
except ValueError:
@@ -731,7 +731,7 @@ def save(
):
extra_json_file_names = ["preprocessor_config.json", "chat_template.json"]
for name in extra_json_file_names:
json_path = os.path.join(self.model_id_or_path, name)
json_path = os.path.join(self.model_local_path, name)
if os.path.exists(json_path):
os.makedirs(save_dir, exist_ok=True)

19 changes: 17 additions & 2 deletions gptqmodel/models/definitions/qwen2_vl.py
@@ -1,7 +1,9 @@
import os.path
import shutil
from typing import Dict, Optional
from PIL import Image

from transformers import AutoModelForVision2Seq, Qwen2VLProcessor
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer

from ..base import BaseGPTQModel
from ...utils.calibration import batched
@@ -82,7 +84,20 @@ def prepare_dataset(
calibration_dataset,
batch_size: int = 1,
tokenizer=None, ):
processor = Qwen2VLProcessor.from_pretrained(self.model_id_or_path)
import tempfile
import json

if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(self.model_local_path)

with tempfile.TemporaryDirectory() as tmp_dir:
chat_template_file = os.path.join(self.model_local_path, "chat_template.json")
if os.path.exists(chat_template_file):
shutil.copyfile(chat_template_file, os.path.join(tmp_dir, "chat_template.json"))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "preprocessor_config.json"), "w") as f:
f.write(json.dumps(self.quant_override_files["preprocessor_config.json"]))
processor = AutoProcessor.from_pretrained(tmp_dir)
calib_data = []
for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset):
text = processor.apply_chat_template(
48 changes: 32 additions & 16 deletions gptqmodel/models/loader.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional, Union

@@ -36,6 +37,7 @@
verify_sharded_model_hashes,
)
from ._const import DEVICE, SUPPORTED_MODELS, normalize_device
from huggingface_hub import snapshot_download


logger = setup_logger()
@@ -73,17 +75,25 @@ def compare_versions(installed_version, required_version, operator):
raise ValueError(f"Unsupported operator: {operator}")


def check_versions(model_id_or_path: str, requirements: List[str]):
def check_versions(model_class, requirements: List[str]):
if requirements is None:
return
for req in requirements:
pkg, operator, version_required = parse_requirement(req)
try:
installed_version = version(pkg)
if not compare_versions(installed_version, version_required, operator):
raise ValueError(f"{model_id_or_path} requires version {req}, but current {pkg} version is {installed_version} ")
raise ValueError(f"{model_class} requires version {req}, but current {pkg} version is {installed_version} ")
except PackageNotFoundError:
raise ValueError(f"{model_id_or_path} requires version {req}, but {pkg} not installed.")
raise ValueError(f"{model_class} requires version {req}, but {pkg} not installed.")


def get_model_local_path(pretrained_model_id_or_path, **kwargs):
is_local = os.path.isdir(pretrained_model_id_or_path)
if is_local:
return pretrained_model_id_or_path
else:
return snapshot_download(pretrained_model_id_or_path, **kwargs)


def ModelLoader(cls):
@@ -117,7 +127,9 @@ def from_pretrained(
f"{pretrained_model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
)

check_versions(pretrained_model_id_or_path, cls.require_pkgs_version)
check_versions(cls, cls.require_pkgs_version)

model_local_path = get_model_local_path(pretrained_model_id_or_path, **model_init_kwargs)

def skip(*args, **kwargs):
pass
@@ -126,7 +138,9 @@ def skip(*args, **kwargs):
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip

config = AutoConfig.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
model_init_kwargs["trust_remote_code"] = trust_remote_code

config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs)

# normalize and auto select quantization device is not passed
if quantize_config.device is None:
@@ -139,15 +153,14 @@ def skip(*args, **kwargs):
torch_dtype = auto_dtype(config=config, device=quantize_config.device, quant_inference=False)

# enforce some values despite user specified
model_init_kwargs["trust_remote_code"] = trust_remote_code
# non-quantized models are always loaded into cpu
model_init_kwargs["device_map"] = cpu_device_map
model_init_kwargs["torch_dtype"] = torch_dtype

if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")

model = cls.loader.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
model = cls.loader.from_pretrained(model_local_path, **model_init_kwargs)

model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions", "multimodal_max_length"]
@@ -166,7 +179,7 @@ def skip(*args, **kwargs):
quantized=False,
quantize_config=quantize_config,
trust_remote_code=trust_remote_code,
model_id_or_path=pretrained_model_id_or_path
model_local_path=model_local_path,
)

cls.from_pretrained = from_pretrained
@@ -207,7 +220,9 @@ def from_quantized(
f"{model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
)

check_versions(model_id_or_path, cls.require_pkgs_version)
check_versions(cls, cls.require_pkgs_version)

model_local_path = get_model_local_path(model_id_or_path, **kwargs)

# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
@@ -235,7 +250,7 @@ def from_quantized(

# == step1: prepare configs and file names == #
config: PretrainedConfig = AutoConfig.from_pretrained(
model_id_or_path,
model_local_path,
trust_remote_code=trust_remote_code,
**cached_file_kwargs,
)
@@ -247,7 +262,7 @@ def from_quantized(
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")

quantize_config = QuantizeConfig.from_pretrained(model_id_or_path, **cached_file_kwargs, **kwargs)
quantize_config = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs)

if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
if quantize_config.format != FORMAT.GPTQ:
@@ -256,7 +271,7 @@ def from_quantized(
from ..utils.vllm import load_model_by_vllm, vllm_generate

model = load_model_by_vllm(
model=model_id_or_path,
model=model_local_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
@@ -269,7 +284,7 @@ def from_quantized(
from ..utils.sglang import load_model_by_sglang, sglang_generate

model, hf_config = load_model_by_sglang(
model=model_id_or_path,
model=model_local_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
@@ -280,6 +295,7 @@ def from_quantized(
quantized=True,
quantize_config=quantize_config,
qlinear_kernel=None,
model_local_path=model_local_path,
)

if quantize_config.format == FORMAT.MARLIN:
@@ -315,11 +331,11 @@ def from_quantized(

extensions = [".safetensors"]

model_id_or_path = str(model_id_or_path)
model_local_path = str(model_local_path)

# Retrieve (and if necessary download) the quantized checkpoint(s).
is_sharded, resolved_archive_file, true_model_basename = get_checkpoints(
model_id_or_path=model_id_or_path,
model_id_or_path=model_local_path,
extensions=extensions,
possible_model_basenames=possible_model_basenames,
**cached_file_kwargs,
@@ -524,7 +540,7 @@ def skip(*args, **kwargs):
qlinear_kernel=qlinear_kernel,
load_quantized_model=True,
trust_remote_code=trust_remote_code,
model_id_or_path=model_id_or_path,
model_local_path=model_local_path,
)

cls.from_quantized = from_quantized
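For reference, the new `get_model_local_path` helper resolves whatever the caller passed into a filesystem path before any loading happens: a local directory is returned unchanged, while a hub id is materialized to a cached snapshot via `huggingface_hub.snapshot_download`. A small usage sketch (ids and paths are illustrative):

```python
# Usage sketch for the helper added in this diff (ids and paths are examples only).
from gptqmodel.models.loader import get_model_local_path

local = get_model_local_path("/models/my-model")             # existing directory, returned as-is
cached = get_model_local_path("Qwen/Qwen2-VL-2B-Instruct")   # hub id, resolved to a cached snapshot dir
print(cached)  # e.g. ~/.cache/huggingface/hub/models--Qwen--Qwen2-VL-2B-Instruct/snapshots/<rev>
```

Downstream, `AutoConfig.from_pretrained`, the model loader, and the writer all receive this resolved path, which is what the `model_id_or_path` to `model_local_path` rename reflects.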
6 changes: 3 additions & 3 deletions gptqmodel/models/writer.py
@@ -84,7 +84,7 @@ def save_quantized(
w.writerows([[entry.get(QUANT_LOG_LAYER), entry.get(QUANT_LOG_MODULE), entry.get(QUANT_LOG_LOSS),
entry.get(QUANT_LOG_DAMP), entry.get(QUANT_LOG_TIME)] for entry in self.quant_log])

pre_quantized_size_mb = get_model_files_size(self.model_id_or_path)
pre_quantized_size_mb = get_model_files_size(self.model_local_path)
pre_quantized_size_gb = pre_quantized_size_mb / 1024

quantizers = [f"{META_QUANTIZER_GPTQMODEL}:{__version__}"]
@@ -171,7 +171,7 @@ def save_quantized(
else:
model = self.get_model_with_quantize(
quantize_config=quantize_config,
model_id_or_path=self.model_id_or_path,
model_id_or_path=self.model_local_path,
)

model.to(CPU)
@@ -311,7 +311,7 @@ def save_quantized(

# need to copy .py files for model/tokenizers not yet merged to HF transformers
if self.trust_remote_code:
copy_py_files(save_dir, model_id_or_path=self.model_id_or_path)
copy_py_files(save_dir, model_id_or_path=self.model_local_path)

cls.save_quantized = save_quantized

2 changes: 1 addition & 1 deletion gptqmodel/version.py
@@ -1 +1 @@
__version__ = "1.4.6-dev"
__version__ = "1.5.0"
7 changes: 7 additions & 0 deletions setup.py
@@ -17,6 +17,12 @@

TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST")

if TORCH_CUDA_ARCH_LIST:
arch_list = " ".join([arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 6.0 or print(f"we do not support this compute arch: {arch}, skipped.") is not None])
if arch_list != TORCH_CUDA_ARCH_LIST:
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list
print(f"TORCH_CUDA_ARCH_LIST has been updated to '{arch_list}'")

version_vars = {}
exec("exec(open('gptqmodel/version.py').read()); version=__version__", {}, version_vars)
gptqmodel_version = version_vars['version']
@@ -109,6 +115,7 @@ def get_version_tag(is_cuda_release: bool = True) -> str:
if got_cuda_between_v6_and_v8:
FORCE_BUILD = True


if BUILD_CUDA_EXT:
if CUDA_RELEASE == "1":
common_setup_kwargs["version"] += f"+{get_version_tag(True)}"
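The arch filter added to setup.py above relies on `or print(...)` short-circuiting inside a list comprehension: `print()` returns `None`, so unsupported arches are reported and dropped in a single expression. A more explicit equivalent (a sketch, not the committed code):

```python
import os

# Behaviorally equivalent expansion of the one-liner (illustrative only).
TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", "7.5 8.0 5.2+PTX")  # example value
kept = []
for arch in TORCH_CUDA_ARCH_LIST.split():
    if float(arch.split('+')[0]) >= 6.0:  # keep compute capability 6.0 and newer
        kept.append(arch)
    else:
        # print() returns None, which is why the original list comprehension drops these entries
        print(f"we do not support this compute arch: {arch}, skipped.")
arch_list = " ".join(kept)  # e.g. "7.5 8.0"
```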
6 changes: 3 additions & 3 deletions tests/models/model_test.py
@@ -187,7 +187,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
try:
with tempfile.TemporaryDirectory() as tmp_dir:
if self.USE_VLLM:
model_args = f"pretrained={model.model_id_or_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code},max_model_len={self.MODEL_MAX_LEN}"
model_args = f"pretrained={model.model_local_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code},max_model_len={self.MODEL_MAX_LEN}"
else:
model_args = ""
results = lm_eval(
@@ -216,8 +216,8 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
if metric != 'alias' and 'stderr' not in metric
}
print(task_results)
if delete_quantized_model and model.model_id_or_path.startswith("/tmp") and os.path.exists(model.model_id_or_path):
shutil.rmtree(model.model_id_or_path)
if delete_quantized_model and model.model_local_path.startswith("/tmp") and os.path.exists(model.model_local_path):
shutil.rmtree(model.model_local_path)
return task_results
except BaseException as e:
if isinstance(e, torch.OutOfMemoryError):
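With `USE_VLLM` enabled, the `model_args` f-string above (now keyed on `model.model_local_path`) expands to a single comma-separated option string handed to lm-eval, presumably for its vLLM loader; the values below are illustrative only:

```python
# Illustrative expansion of the model_args f-string (paths and values are examples, not from this commit).
model_args = (
    "pretrained=/tmp/gptqmodel_test_abc123,"   # model.model_local_path after this change
    "dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,"
    "trust_remote_code=False,max_model_len=4096"
)
```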
2 changes: 1 addition & 1 deletion tests/models/ovis/image_to_test_dataset.py
@@ -47,6 +47,6 @@ def get_calib_dataset(model):
return prepare_dataset(format_ovis_dataset, n_sample=20)

if isinstance(model, Qwen2VLGPTQ):
return prepare_dataset(format_qwen2_vl_dataset, n_sample=20)
return prepare_dataset(format_qwen2_vl_dataset, n_sample=1)

raise NotImplementedError(f"Unsupported MODEL: {model.__class__}")
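`format_qwen2_vl_dataset` is not shown in this diff; based on how `prepare_dataset` in qwen2_vl.py feeds each batch to `processor.apply_chat_template`, one sample is assumed to be a chat-style message list roughly like the sketch below (hypothetical, not the committed implementation):

```python
# Assumed shape of one formatted Qwen2-VL calibration sample (hypothetical sketch).
def format_qwen2_vl_dataset(image, assistant):
    return [
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "generate a caption for this image"},
        ]},
        {"role": "assistant", "content": assistant},
    ]
```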
2 changes: 1 addition & 1 deletion tests/models/test_qwen2_vl.py
@@ -33,7 +33,7 @@ def test_qwen2_vl(self):
messages, tokenize=False, add_generation_prompt=True
)

image_inputs, video_inputs = Qwen2VLGPTQ.process_vision_info(messages)
image_inputs = Qwen2VLGPTQ.process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
Expand Down
