[FIX] wrong filepath was used when model_id_or_path was a Hugging Face model id (#956)

* fix issue: the wrong filepath was used when model_id_or_path was a Hugging Face model id

* cleanup

* BaseModel: removed "model_id_or_path" (replaced by "model_local_path")
ZX-ModelCloud authored Dec 23, 2024
1 parent 883a52a commit 21949f8
Showing 5 changed files with 41 additions and 26 deletions.
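
The core of the fix is resolving the caller's input (a local directory or a Hugging Face Hub model id) to a real local directory before any file paths are joined against it. A minimal sketch of that resolution step, mirroring the `get_model_local_path` helper added to loader.py below (the standalone function name here is illustrative):

```python
import os

from huggingface_hub import snapshot_download


def resolve_model_local_path(pretrained_model_id_or_path: str) -> str:
    # Already a local directory: use it as-is.
    if os.path.isdir(pretrained_model_id_or_path):
        return pretrained_model_id_or_path
    # Otherwise treat it as a Hub model id: download (or reuse) the cached
    # snapshot and return the local snapshot directory.
    return snapshot_download(pretrained_model_id_or_path)


# e.g. resolve_model_local_path("Qwen/Qwen2-VL-2B-Instruct") returns a path like
# ~/.cache/huggingface/hub/models--Qwen--Qwen2-VL-2B-Instruct/snapshots/<revision>
```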
6 changes: 3 additions & 3 deletions gptqmodel/models/base.py
@@ -101,7 +101,7 @@ def __init__(
qlinear_kernel: nn.Module = None,
load_quantized_model: bool = False,
trust_remote_code: bool = False,
- model_id_or_path: str = None,
+ model_local_path: str = None,
):
super().__init__()

@@ -114,7 +114,7 @@ def __init__(
# compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
self.qlinear_kernel = qlinear_kernel
self.trust_remote_code = trust_remote_code
- self.model_id_or_path = model_id_or_path
+ self.model_local_path = model_local_path
# stores all per-layer quant stats such as avg loss and processing time
self.quant_log = []

@@ -774,7 +774,7 @@ def save(
):
extra_json_file_names = ["preprocessor_config.json", "chat_template.json"]
for name in extra_json_file_names:
- json_path = os.path.join(self.model_id_or_path, name)
+ json_path = os.path.join(self.model_local_path, name)
if os.path.exists(json_path):
os.makedirs(save_dir, exist_ok=True)

4 changes: 2 additions & 2 deletions gptqmodel/models/definitions/qwen2_vl.py
@@ -88,10 +88,10 @@ def prepare_dataset(
import json

if tokenizer is None:
- tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
+ tokenizer = AutoTokenizer.from_pretrained(self.model_local_path)

with tempfile.TemporaryDirectory() as tmp_dir:
- chat_template_file = os.path.join(self.model_id_or_path, "chat_template.json")
+ chat_template_file = os.path.join(self.model_local_path, "chat_template.json")
if os.path.exists(chat_template_file):
shutil.copyfile(chat_template_file, os.path.join(tmp_dir, "chat_template.json"))
tokenizer.save_pretrained(tmp_dir)
45 changes: 30 additions & 15 deletions gptqmodel/models/loader.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+ import os
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional, Union

@@ -36,6 +37,7 @@
verify_sharded_model_hashes,
)
from ._const import DEVICE, SUPPORTED_MODELS, normalize_device
+ from huggingface_hub import snapshot_download


logger = setup_logger()
@@ -73,17 +75,25 @@ def compare_versions(installed_version, required_version, operator):
raise ValueError(f"Unsupported operator: {operator}")


- def check_versions(model_id_or_path: str, requirements: List[str]):
+ def check_versions(model_class, requirements: List[str]):
if requirements is None:
return
for req in requirements:
pkg, operator, version_required = parse_requirement(req)
try:
installed_version = version(pkg)
if not compare_versions(installed_version, version_required, operator):
raise ValueError(f"{model_id_or_path} requires version {req}, but current {pkg} version is {installed_version} ")
raise ValueError(f"{model_class} requires version {req}, but current {pkg} version is {installed_version} ")
except PackageNotFoundError:
raise ValueError(f"{model_id_or_path} requires version {req}, but {pkg} not installed.")
raise ValueError(f"{model_class} requires version {req}, but {pkg} not installed.")


+ def get_model_local_path(pretrained_model_id_or_path, **kwargs):
+     is_local = os.path.isdir(pretrained_model_id_or_path)
+     if is_local:
+         return pretrained_model_id_or_path
+     else:
+         return snapshot_download(pretrained_model_id_or_path, **kwargs)


def ModelLoader(cls):
@@ -106,7 +116,9 @@ def from_pretrained(
f"{pretrained_model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
)

- check_versions(pretrained_model_id_or_path, cls.require_pkgs_version)
+ check_versions(cls, cls.require_pkgs_version)
+
+ model_local_path = get_model_local_path(pretrained_model_id_or_path, **model_init_kwargs)

def skip(*args, **kwargs):
pass
@@ -117,7 +129,7 @@ def skip(*args, **kwargs):

model_init_kwargs["trust_remote_code"] = trust_remote_code

- config = AutoConfig.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
+ config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs)

if torch_dtype is None or torch_dtype == "auto":
torch_dtype = auto_dtype_from_config(config)
@@ -130,7 +142,7 @@ def skip(*args, **kwargs):
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")

- model = cls.loader.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
+ model = cls.loader.from_pretrained(model_local_path, **model_init_kwargs)

model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions", "multimodal_max_length"]
@@ -149,7 +161,7 @@ def skip(*args, **kwargs):
quantized=False,
quantize_config=quantize_config,
trust_remote_code=trust_remote_code,
- model_id_or_path=pretrained_model_id_or_path
+ model_local_path=model_local_path,
)

cls.from_pretrained = from_pretrained
@@ -189,7 +201,9 @@ def from_quantized(
f"{model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
)

- check_versions(model_id_or_path, cls.require_pkgs_version)
+ check_versions(cls, cls.require_pkgs_version)
+
+ model_local_path = get_model_local_path(model_id_or_path, **kwargs)

# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
@@ -217,7 +231,7 @@ def from_quantized(

# == step1: prepare configs and file names == #
config: PretrainedConfig = AutoConfig.from_pretrained(
- model_id_or_path,
+ model_local_path,
trust_remote_code=trust_remote_code,
**cached_file_kwargs,
)
@@ -231,7 +245,7 @@ def from_quantized(
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")

- quantize_config = QuantizeConfig.from_pretrained(model_id_or_path, **cached_file_kwargs, **kwargs)
+ quantize_config = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs)

if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
if quantize_config.format != FORMAT.GPTQ:
@@ -240,7 +254,7 @@ def from_quantized(
from ..utils.vllm import load_model_by_vllm, vllm_generate

model = load_model_by_vllm(
- model=model_id_or_path,
+ model=model_local_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
@@ -253,7 +267,7 @@ def from_quantized(
from ..utils.sglang import load_model_by_sglang, sglang_generate

model, hf_config = load_model_by_sglang(
- model=model_id_or_path,
+ model=model_local_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
@@ -264,6 +278,7 @@ def from_quantized(
quantized=True,
quantize_config=quantize_config,
qlinear_kernel=None,
+ model_local_path=model_local_path,
)

if quantize_config.format == FORMAT.MARLIN:
Expand Down Expand Up @@ -299,11 +314,11 @@ def from_quantized(

extensions = [".safetensors"]

- model_id_or_path = str(model_id_or_path)
+ model_local_path = str(model_local_path)

# Retrieve (and if necessary download) the quantized checkpoint(s).
is_sharded, resolved_archive_file, true_model_basename = get_checkpoints(
- model_id_or_path=model_id_or_path,
+ model_id_or_path=model_local_path,
extensions=extensions,
possible_model_basenames=possible_model_basenames,
**cached_file_kwargs,
@@ -529,7 +544,7 @@ def skip(*args, **kwargs):
qlinear_kernel=qlinear_kernel,
load_quantized_model=True,
trust_remote_code=trust_remote_code,
- model_id_or_path=model_id_or_path,
+ model_local_path=model_local_path,
)

cls.from_quantized = from_quantized
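With the loader changes above, model.model_local_path always refers to a local directory (the path that was passed in, or the downloaded Hub snapshot), so the path joins in base.py, writer.py, and qwen2_vl.py work whether the model was loaded by Hub id or by local path. A hypothetical usage sketch (the model id below is a placeholder, not a real repository):

```python
from gptqmodel import GPTQModel

# Loading by Hub model id: the loader resolves the id to a local snapshot
# directory before reading configs, weights, or extra json files.
model = GPTQModel.from_quantized("ModelCloud/some-quantized-model")  # placeholder id

# model_local_path is a filesystem directory, not the Hub id that was passed in.
print(model.model_local_path)
```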
6 changes: 3 additions & 3 deletions gptqmodel/models/writer.py
@@ -84,7 +84,7 @@ def save_quantized(
w.writerows([[entry.get(QUANT_LOG_LAYER), entry.get(QUANT_LOG_MODULE), entry.get(QUANT_LOG_LOSS),
entry.get(QUANT_LOG_DAMP), entry.get(QUANT_LOG_TIME)] for entry in self.quant_log])

- pre_quantized_size_mb = get_model_files_size(self.model_id_or_path)
+ pre_quantized_size_mb = get_model_files_size(self.model_local_path)
pre_quantized_size_gb = pre_quantized_size_mb / 1024

quantizers = [f"{META_QUANTIZER_GPTQMODEL}:{__version__}"]
@@ -171,7 +171,7 @@ def save_quantized(
else:
model = self.get_model_with_quantize(
quantize_config=quantize_config,
- model_id_or_path=self.model_id_or_path,
+ model_id_or_path=self.model_local_path,
)

model.to(CPU)
@@ -311,7 +311,7 @@ def save_quantized(

# need to copy .py files for model/tokenizers not yet merged to HF transformers
if self.trust_remote_code:
- copy_py_files(save_dir, model_id_or_path=self.model_id_or_path)
+ copy_py_files(save_dir, model_id_or_path=self.model_local_path)

cls.save_quantized = save_quantized

6 changes: 3 additions & 3 deletions tests/models/model_test.py
@@ -187,7 +187,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
try:
with tempfile.TemporaryDirectory() as tmp_dir:
if self.USE_VLLM:
model_args = f"pretrained={model.model_id_or_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code},max_model_len={self.MODEL_MAX_LEN}"
model_args = f"pretrained={model.model_local_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code},max_model_len={self.MODEL_MAX_LEN}"
else:
model_args = ""
results = lm_eval(
@@ -216,8 +216,8 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
if metric != 'alias' and 'stderr' not in metric
}
print(task_results)
- if delete_quantized_model and model.model_id_or_path.startswith("/tmp") and os.path.exists(model.model_id_or_path):
-     shutil.rmtree(model.model_id_or_path)
+ if delete_quantized_model and model.model_local_path.startswith("/tmp") and os.path.exists(model.model_local_path):
+     shutil.rmtree(model.model_local_path)
return task_results
except BaseException as e:
if isinstance(e, torch.OutOfMemoryError):
