[FIX] wrong filepath was used when model_id_or_path was a Hugging Face model id #956

Merged
merged 3 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
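The root cause, as a minimal sketch (the hub id below is only an illustrative example): when `model_id_or_path` held a Hugging Face Hub id rather than a local directory, every `os.path.join` on it produced a path that does not exist on disk, so files such as `chat_template.json` were silently skipped and loaders were handed a non-path string.

```python
import os

# Illustrative hub id; before this PR it was stored directly as model_id_or_path.
model_id_or_path = "Qwen/Qwen2-VL-2B-Instruct"

# Joining a hub id with a filename yields a relative filesystem path that
# never exists locally, so existence checks on it always fail.
json_path = os.path.join(model_id_or_path, "chat_template.json")
print(json_path)                  # Qwen/Qwen2-VL-2B-Instruct/chat_template.json
print(os.path.exists(json_path))  # False unless that relative directory happens to exist
```

The diff below instead resolves the id to a local directory once and threads the result through as `model_local_path` (base model, loader, writer, and tests).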
6 changes: 3 additions & 3 deletions gptqmodel/models/base.py
@@ -101,7 +101,7 @@ def __init__(
         qlinear_kernel: nn.Module = None,
         load_quantized_model: bool = False,
         trust_remote_code: bool = False,
-        model_id_or_path: str = None,
+        model_local_path: str = None,
     ):
         super().__init__()

@@ -114,7 +114,7 @@ def __init__(
         # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
         self.qlinear_kernel = qlinear_kernel
         self.trust_remote_code = trust_remote_code
-        self.model_id_or_path = model_id_or_path
+        self.model_local_path = model_local_path
         # stores all per-layer quant stats such as avg loss and processing time
         self.quant_log = []

@@ -774,7 +774,7 @@ def save(
     ):
         extra_json_file_names = ["preprocessor_config.json", "chat_template.json"]
         for name in extra_json_file_names:
-            json_path = os.path.join(self.model_id_or_path, name)
+            json_path = os.path.join(self.model_local_path, name)
             if os.path.exists(json_path):
                 os.makedirs(save_dir, exist_ok=True)
4 changes: 2 additions & 2 deletions gptqmodel/models/definitions/qwen2_vl.py
@@ -88,10 +88,10 @@ def prepare_dataset(
         import json

         if tokenizer is None:
-            tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
+            tokenizer = AutoTokenizer.from_pretrained(self.model_local_path)

         with tempfile.TemporaryDirectory() as tmp_dir:
-            chat_template_file = os.path.join(self.model_id_or_path, "chat_template.json")
+            chat_template_file = os.path.join(self.model_local_path, "chat_template.json")
             if os.path.exists(chat_template_file):
                 shutil.copyfile(chat_template_file, os.path.join(tmp_dir, "chat_template.json"))
             tokenizer.save_pretrained(tmp_dir)
45 changes: 30 additions & 15 deletions gptqmodel/models/loader.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from importlib.metadata import PackageNotFoundError, version
 from typing import Dict, List, Optional, Union

@@ -36,6 +37,7 @@
     verify_sharded_model_hashes,
 )
 from ._const import DEVICE, SUPPORTED_MODELS, normalize_device
+from huggingface_hub import snapshot_download


 logger = setup_logger()
@@ -73,17 +75,25 @@ def compare_versions(installed_version, required_version, operator):
         raise ValueError(f"Unsupported operator: {operator}")


-def check_versions(model_id_or_path: str, requirements: List[str]):
+def check_versions(model_class, requirements: List[str]):
     if requirements is None:
         return
     for req in requirements:
         pkg, operator, version_required = parse_requirement(req)
         try:
             installed_version = version(pkg)
             if not compare_versions(installed_version, version_required, operator):
-                raise ValueError(f"{model_id_or_path} requires version {req}, but current {pkg} version is {installed_version} ")
+                raise ValueError(f"{model_class} requires version {req}, but current {pkg} version is {installed_version} ")
         except PackageNotFoundError:
-            raise ValueError(f"{model_id_or_path} requires version {req}, but {pkg} not installed.")
+            raise ValueError(f"{model_class} requires version {req}, but {pkg} not installed.")


+def get_model_local_path(pretrained_model_id_or_path, **kwargs):
+    is_local = os.path.isdir(pretrained_model_id_or_path)
+    if is_local:
+        return pretrained_model_id_or_path
+    else:
+        return snapshot_download(pretrained_model_id_or_path, **kwargs)
+
+
 def ModelLoader(cls):
@@ -106,7 +116,9 @@ def from_pretrained(
                 f"{pretrained_model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
             )

-        check_versions(pretrained_model_id_or_path, cls.require_pkgs_version)
+        check_versions(cls, cls.require_pkgs_version)
+
+        model_local_path = get_model_local_path(pretrained_model_id_or_path, **model_init_kwargs)

         def skip(*args, **kwargs):
             pass
@@ -117,7 +129,7 @@ def skip(*args, **kwargs):

         model_init_kwargs["trust_remote_code"] = trust_remote_code

-        config = AutoConfig.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
+        config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs)

         if torch_dtype is None or torch_dtype == "auto":
             torch_dtype = auto_dtype_from_config(config)
@@ -130,7 +142,7 @@ def skip(*args, **kwargs):
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")

-        model = cls.loader.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
+        model = cls.loader.from_pretrained(model_local_path, **model_init_kwargs)

         model_config = model.config.to_dict()
         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions", "multimodal_max_length"]
@@ -149,7 +161,7 @@ def skip(*args, **kwargs):
             quantized=False,
             quantize_config=quantize_config,
             trust_remote_code=trust_remote_code,
-            model_id_or_path=pretrained_model_id_or_path
+            model_local_path=model_local_path,
         )

     cls.from_pretrained = from_pretrained
@@ -189,7 +201,9 @@ def from_quantized(
                 f"{model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
             )

-        check_versions(model_id_or_path, cls.require_pkgs_version)
+        check_versions(cls, cls.require_pkgs_version)
+
+        model_local_path = get_model_local_path(model_id_or_path, **kwargs)

         # Parameters related to loading from Hugging Face Hub
         cache_dir = kwargs.pop("cache_dir", None)
@@ -217,7 +231,7 @@ def from_quantized(

         # == step1: prepare configs and file names == #
         config: PretrainedConfig = AutoConfig.from_pretrained(
-            model_id_or_path,
+            model_local_path,
             trust_remote_code=trust_remote_code,
             **cached_file_kwargs,
         )
@@ -231,7 +245,7 @@ def from_quantized(
         if config.model_type not in SUPPORTED_MODELS:
             raise TypeError(f"{config.model_type} isn't supported yet.")

-        quantize_config = QuantizeConfig.from_pretrained(model_id_or_path, **cached_file_kwargs, **kwargs)
+        quantize_config = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs)

         if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
             if quantize_config.format != FORMAT.GPTQ:
@@ -240,7 +254,7 @@ def from_quantized(
                 from ..utils.vllm import load_model_by_vllm, vllm_generate

                 model = load_model_by_vllm(
-                    model=model_id_or_path,
+                    model=model_local_path,
                     trust_remote_code=trust_remote_code,
                     **kwargs,
                 )
@@ -253,7 +267,7 @@ def from_quantized(
                 from ..utils.sglang import load_model_by_sglang, sglang_generate

                 model, hf_config = load_model_by_sglang(
-                    model=model_id_or_path,
+                    model=model_local_path,
                     trust_remote_code=trust_remote_code,
                     **kwargs,
                 )
@@ -264,6 +278,7 @@ def from_quantized(
                 quantized=True,
                 quantize_config=quantize_config,
                 qlinear_kernel=None,
+                model_local_path=model_local_path,
             )

         if quantize_config.format == FORMAT.MARLIN:
@@ -299,11 +314,11 @@ def from_quantized(

         extensions = [".safetensors"]

-        model_id_or_path = str(model_id_or_path)
+        model_local_path = str(model_local_path)

         # Retrieve (and if necessary download) the quantized checkpoint(s).
         is_sharded, resolved_archive_file, true_model_basename = get_checkpoints(
-            model_id_or_path=model_id_or_path,
+            model_id_or_path=model_local_path,
             extensions=extensions,
             possible_model_basenames=possible_model_basenames,
             **cached_file_kwargs,
@@ -529,7 +544,7 @@ def skip(*args, **kwargs):
             qlinear_kernel=qlinear_kernel,
             load_quantized_model=True,
             trust_remote_code=trust_remote_code,
-            model_id_or_path=model_id_or_path,
+            model_local_path=model_local_path,
         )

     cls.from_quantized = from_quantized
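The new `get_model_local_path` helper carries the fix: a local directory passes through untouched, while a hub id is resolved to a cached snapshot directory, so `AutoConfig`, `QuantizeConfig`, the vLLM/SGLang loaders, and checkpoint discovery all receive a real path. A hedged usage sketch (the hub id is illustrative, and `snapshot_download` needs network access on first use):

```python
import os
from huggingface_hub import snapshot_download


def get_model_local_path(pretrained_model_id_or_path, **kwargs):
    # An existing local directory is returned unchanged; anything else is
    # treated as a Hugging Face Hub id and resolved to a cached snapshot.
    if os.path.isdir(pretrained_model_id_or_path):
        return pretrained_model_id_or_path
    return snapshot_download(pretrained_model_id_or_path, **kwargs)


local_path = get_model_local_path("Qwen/Qwen2-VL-2B-Instruct")
# Resolves to something like ~/.cache/huggingface/hub/models--Qwen--Qwen2-VL-2B-Instruct/snapshots/<rev>
print(os.path.exists(os.path.join(local_path, "config.json")))
```

Resolving once at the top of `from_pretrained` / `from_quantized` is what keeps the rest of the call chain unchanged: downstream code keeps doing plain `os.path.join` and `from_pretrained(path)` calls without caring whether the caller originally passed an id or a directory.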
6 changes: 3 additions & 3 deletions gptqmodel/models/writer.py
@@ -84,7 +84,7 @@ def save_quantized(
                 w.writerows([[entry.get(QUANT_LOG_LAYER), entry.get(QUANT_LOG_MODULE), entry.get(QUANT_LOG_LOSS),
                               entry.get(QUANT_LOG_DAMP), entry.get(QUANT_LOG_TIME)] for entry in self.quant_log])

-        pre_quantized_size_mb = get_model_files_size(self.model_id_or_path)
+        pre_quantized_size_mb = get_model_files_size(self.model_local_path)
         pre_quantized_size_gb = pre_quantized_size_mb / 1024

         quantizers = [f"{META_QUANTIZER_GPTQMODEL}:{__version__}"]
@@ -171,7 +171,7 @@ def save_quantized(
             else:
                 model = self.get_model_with_quantize(
                     quantize_config=quantize_config,
-                    model_id_or_path=self.model_id_or_path,
+                    model_id_or_path=self.model_local_path,
                 )

                 model.to(CPU)
@@ -311,7 +311,7 @@ def save_quantized(

         # need to copy .py files for model/tokenizers not yet merged to HF transformers
         if self.trust_remote_code:
-            copy_py_files(save_dir, model_id_or_path=self.model_id_or_path)
+            copy_py_files(save_dir, model_id_or_path=self.model_local_path)

     cls.save_quantized = save_quantized
6 changes: 3 additions & 3 deletions tests/models/model_test.py
@@ -187,7 +187,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
         try:
             with tempfile.TemporaryDirectory() as tmp_dir:
                 if self.USE_VLLM:
-                    model_args = f"pretrained={model.model_id_or_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code},max_model_len={self.MODEL_MAX_LEN}"
+                    model_args = f"pretrained={model.model_local_path},dtype=auto,gpu_memory_utilization=0.8,tensor_parallel_size=1,trust_remote_code={trust_remote_code},max_model_len={self.MODEL_MAX_LEN}"
                 else:
                     model_args = ""
                 results = lm_eval(
@@ -216,8 +216,8 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
                     if metric != 'alias' and 'stderr' not in metric
                 }
                 print(task_results)
-                if delete_quantized_model and model.model_id_or_path.startswith("/tmp") and os.path.exists(model.model_id_or_path):
-                    shutil.rmtree(model.model_id_or_path)
+                if delete_quantized_model and model.model_local_path.startswith("/tmp") and os.path.exists(model.model_local_path):
+                    shutil.rmtree(model.model_local_path)
                 return task_results
         except BaseException as e:
             if isinstance(e, torch.OutOfMemoryError):