Add QuantizeConfig.device and use. (#950)
* normalize device + device_map

* normalize device + device_map + dtype in from_pretrained()

* disallow passing device/device_map to from_pretrained(); add `device` to QuantizeConfig

* if the user passes device/device_map and QuantizeConfig.device is None, use the passed value; otherwise use QuantizeConfig.device; fall back to auto-selection

* auto-device logic should not be here

* reduce reliance on accelerate

* remove bad device override

* fix undefined `dev`

* cleanup

* device is already checked in select_quant_linear

* fix marlin post_init

---------

Co-authored-by: LRL-ModelCloud <lrl@lbx.dev>
Qubitium and LRL-ModelCloud authored Dec 24, 2024
1 parent a235687 commit 55f9d72
Showing 13 changed files with 105 additions and 176 deletions.
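
Taken together, the commit wires device selection through QuantizeConfig. Below is a minimal usage sketch, not code from the commit itself: class, method, and keyword names follow the diff that follows, while the model id and calibration text are illustrative placeholders.

# Minimal sketch (not from the commit): names follow the diff below; the model
# id and calibration text are placeholders.
from gptqmodel import GPTQModel, QuantizeConfig

# Pin the quantization device on the config. Leaving `device` unset falls back
# to auto-selection; also passing device/device_map to load() now raises.
quant_config = QuantizeConfig(bits=4, group_size=128, device="cuda:0")

model = GPTQModel.load(
    "facebook/opt-125m",          # placeholder model id
    quantize_config=quant_config,
)

# During quantization, layers are moved to quantize_config.device.
# (Raw-string calibration data is assumed here; pre-tokenized dicts also work.)
model.quantize(["gptqmodel is an easy-to-use model quantization library."])
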
9 changes: 3 additions & 6 deletions gptqmodel/models/auto.py
@@ -168,10 +168,6 @@ def load(
is_quantized = True
break

# TODO fix me...unify device + device_map auto logic
if not device and not device_map or device_map == "auto":
device = get_best_device(backend=backend)

if is_quantized:
return cls.from_quantized(
model_id_or_path=model_id_or_path,
@@ -186,6 +182,8 @@ def load(
return cls.from_pretrained(
model_id_or_path=model_id_or_path,
quantize_config=quantize_config,
device_map=device_map,
device=device,
trust_remote_code=trust_remote_code,
**kwargs,
)
@@ -230,14 +228,13 @@ def from_quantized(
**kwargs,
) -> BaseGPTQModel:
model_type = check_and_get_model_type(model_id_or_path, trust_remote_code)
quant_func = MODEL_MAP[model_type].from_quantized

if backend == BACKEND.AUTO:
if not torch.cuda.is_available() and HAS_IPEX:
logger.warning("No cuda found, use IPEX backend")
backend = BACKEND.IPEX

return quant_func(
return MODEL_MAP[model_type].from_quantized(
model_id_or_path=model_id_or_path,
device_map=device_map,
device=device,
65 changes: 11 additions & 54 deletions gptqmodel/models/base.py
@@ -6,10 +6,8 @@
import shutil
from typing import Dict, List, Optional, Union, Any

import accelerate
import torch
import torch.nn as nn
from accelerate.hooks import remove_hook_from_module
from packaging import version
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils

@@ -25,12 +23,10 @@
find_layers,
get_device,
get_module_by_name_prefix,
get_module_by_name_suffix,
get_moe_layer_modules,
move_to,
nested_move_to,
pack_model,
simple_dispatch_model,
MODALITY,
)
from ..utils.progress import ProgressBar
@@ -106,7 +102,7 @@ def __init__(
super().__init__()

self.model = model
self._quantized = quantized
self.quantized = quantized
self.load_quantized_model = load_quantized_model
self.quantize_config = quantize_config
self.config = self.model.config
@@ -122,14 +118,6 @@ def __init__(
if self.require_monkeypatch:
self.monkey_patch()

@property
def quantized(self):
return self._quantized

@property
def hf_device_map(self):
return getattr(self.model, "hf_device_map", None)

def prepare_dataset(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
@@ -213,10 +201,6 @@ def quantize(
f"Unsupported quantization operation for quant method: {self.quantize_config.quant_method}"
)

# TODO FIX ME! not best device, but if user pass device/device_map
# use it! else, best-device
best_device = get_best_device(backend)

if backend == BACKEND.IPEX:
self.quantize_config.format = FORMAT.IPEX

@@ -254,7 +238,7 @@
desc_act=self.quantize_config.desc_act,
sym=self.quantize_config.sym,
backend=backend,
device=DEVICE(best_device.type),
device=DEVICE(self.quantize_config.device),
pack=True,
format=self.quantize_config.format,
)
@@ -271,16 +255,6 @@
if BITBLAS_AVAILABLE is False:
raise ValueError(BITBLAS_INSTALL_HINT)


device_map = self.hf_device_map
if device_map:
for name, device in device_map.items():
if device == "cpu" and best_device != CPU:
logger.info(f"truly offloading {name} to cpu with hook.")
module = get_module_by_name_suffix(self.model, name)
remove_hook_from_module(module, recurse=True)
accelerate.cpu_offload_with_hook(module, best_device)

calibration_dataset = self.prepare_dataset(calibration_dataset, batch_size, tokenizer,)

# Calculate the average length of the average input_ids
@@ -389,7 +363,7 @@ def collate_batch(batch):
)

self.model = model
self._quantized = True
self.quantized = True
return

forward_pass_use_cache = self.model.config.use_cache if hasattr(self.model.config, "use_cache") else False
@@ -437,7 +411,7 @@ def store_input_hook(_, args, kwargs):
raise ValueError

# move layer to target device
layers[0] = layers[0].to(best_device)
layers[0] = layers[0].to(self.quantize_config.device)

ori_outside_layer_module_devices = {}
for module_name in self.base_modules:
@@ -528,8 +502,8 @@ def store_input_hook(_, args, kwargs):
gpu_memorys.append(gpu_memory)
cpu_memorys.append(cpu_memory)

if get_device(layer) == CPU and best_device != CPU:
move_to(layer, best_device)
if get_device(layer) == CPU and self.quantize_config.device != CPU:
move_to(layer, self.quantize_config.device)

cur_layer_device = get_device(layer)
full = find_layers(layer)
@@ -722,42 +696,25 @@ def tmp(_, inp, out):
parallel_packing=self.quantize_config.parallel_packing,
)

if device_map:
self.model = remove_hook_from_module(self.model, recurse=True)
self.model = simple_dispatch_model(self.model, device_map)
self.model.config.use_cache = forward_pass_use_cache

self._quantized = True

self.quantized = True
torch_empty_cache()

return self.quant_log

@property
def device(self):
if not self.hf_device_map:
if hasattr(self.model, "device"):
return self.model.device
elif hasattr(self.model, "llm_engine"):
return self.model.llm_engine.device_config.device_type
else:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = [d for d in self.hf_device_map.values() if d not in {"disk"}][0]
return torch.device(device)

def to(self, device: Union[str, torch.device]):
if hasattr(self.model, "to"):
self.model.to(device)
self.model = self.model.to(device)
return self
else:
logger.warning(f"{self.model.__class__.__name__} does not support the to() method")
return self
raise f"{self.model.__class__.__name__} does not support the to() method"

def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def generate(self, **kwargs):
with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
with torch.inference_mode():
return self.model.generate(**kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
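
The base.py changes key layer placement during quantization off quantize_config.device instead of an internally computed "best device", and drop the accelerate offload hooks. A simplified, self-contained sketch of that placement rule follows; it is not the library's actual helper (the real loop uses get_device/move_to over transformer blocks).

import torch

CPU = torch.device("cpu")

def move_layer_for_quantization(layer: torch.nn.Module, quant_device: torch.device) -> torch.nn.Module:
    # Simplified stand-in for the placement loop in quantize(): a block left on
    # CPU is moved to quantize_config.device; no accelerate offload hooks remain.
    current = next(layer.parameters()).device
    if current == CPU and quant_device != CPU:
        layer = layer.to(quant_device)
    return layer

# Example: place a stand-in block on the configured device.
target = torch.device("cuda:0") if torch.cuda.is_available() else CPU
block = move_layer_for_quantization(torch.nn.Linear(16, 16), target)
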
66 changes: 31 additions & 35 deletions gptqmodel/models/loader.py
@@ -17,15 +17,15 @@
from ..quantization import QuantizeConfig
from ..quantization.config import FORMAT, FORMAT_FIELD_JSON, MIN_VERSION_WITH_V2
from ..utils.backend import BACKEND
from ..utils.importer import select_device, select_quant_linear
from ..utils.importer import auto_select_device, select_quant_linear, normalize_device_device_map
from ..utils.logger import setup_logger
from ..utils.marlin import (
_validate_marlin_compatibility,
_validate_marlin_device_support,
prepare_model_for_marlin_load,
)
from ..utils.model import (
auto_dtype_from_config,
auto_dtype,
convert_gptq_v1_to_v2_format,
find_layers,
get_checkpoints,
@@ -104,9 +104,20 @@ def from_pretrained(
quantize_config: QuantizeConfig,
trust_remote_code: bool = False,
torch_dtype: [str | torch.dtype] = "auto",
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
device: Optional[Union[str, int]] = None,
**model_init_kwargs,
):
"""load un-quantized pretrained model to cpu"""
# non-quantized models are always loaded into cpu
cpu_device_map = {"": "cpu"}

if quantize_config is None or not isinstance(quantize_config, QuantizeConfig):
raise AttributeError("`quantize_config` must be passed and be an instance of QuantizeConfig.")

if quantize_config.device is not None:
if device is not None or device_map is not None:
raise AttributeError("Passing device and device_map is not allowed when QuantizeConfig.device is set. Non-quantized model is always loaded as cpu. Please set QuantizeConfig.device for accelerator used in quantization or do not set for auto-selection.")

if quantize_config.desc_act not in cls.supports_desc_act:
raise ValueError(f"{cls} only supports desc_act={cls.supports_desc_act}, "
f"but quantize_config.desc_act is {quantize_config.desc_act}.")
@@ -131,12 +142,19 @@ def skip(*args, **kwargs):

config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs)

if torch_dtype is None or torch_dtype == "auto":
torch_dtype = auto_dtype_from_config(config)
elif not isinstance(torch_dtype, torch.dtype):
raise ValueError(f"torch_dtype value of `{torch_dtype}` is not a torch.dtype instance.")
# normalize and auto select quantization device if not passed
if quantize_config.device is None:
quantize_config.device = auto_select_device(None, None)
else:
quantize_config.device = normalize_device(quantize_config.device)

if torch_dtype is None or torch_dtype == "auto" or not isinstance(torch_dtype, torch.dtype):
# TODO FIX ME for `dynamic`, non-quantized modules should be in native type
torch_dtype = auto_dtype(config=config, device=quantize_config.device, quant_inference=False)

# enforce some values despite user specified
# non-quantized models are always loaded into cpu
model_init_kwargs["device_map"] = cpu_device_map
model_init_kwargs["torch_dtype"] = torch_dtype

if config.model_type not in SUPPORTED_MODELS:
@@ -178,11 +196,12 @@ def from_quantized(
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
):
if device is not None:
device = normalize_device(device)
# normalize device + device_map into a single device
device = normalize_device_device_map(device, device_map)

# TODO need to normalize backend and others in a unified api
device = select_device(device, device_map, backend)
device = auto_select_device(device, backend)
device_map = {"":device}

if backend == BACKEND.VLLM:
import os
@@ -236,11 +255,9 @@
**cached_file_kwargs,
)

if torch_dtype is None or torch_dtype == "auto":
if torch_dtype is None or torch_dtype == "auto" or not isinstance(torch_dtype, torch.dtype) :
# TODO FIX ME for `dynamic`, non-quantized modules should be in native type
torch_dtype = auto_dtype_from_config(config=config, device=device, device_map=device_map, quant_inference=True)
elif not isinstance(torch_dtype, torch.dtype):
raise ValueError(f"torch_dtype value of `{torch_dtype}` is not a torch.dtype instance.")
torch_dtype = auto_dtype(config=config, device=device, quant_inference=True)

if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
@@ -396,27 +413,6 @@ def skip(*args, **kwargs):
quantize_config.runtime_format = FORMAT.IPEX
model.tie_weights()

# == step3: load checkpoint and dispatch == #
if isinstance(device_map, str) and device_map not in [
"auto",
"balanced",
"balanced_low_0",
"sequential",
]:
raise ValueError(
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
"'sequential'."
)

if not isinstance(device_map, dict):
if device is not None:
device_map = {"": 0 if device in [DEVICE.CUDA, DEVICE.XPU, DEVICE.MPS] else DEVICE.CPU}
else:
device_map = accelerate.infer_auto_device_map(
model,
no_split_module_classes=[cls.layer_type] if isinstance(cls.layer_type, str) else cls.layer_type,
)

load_checkpoint_in_model = False
# compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
if quantize_config.format == FORMAT.GPTQ and backend != BACKEND.IPEX:
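
In loader.py, from_quantized() now collapses device and device_map into a single device, auto-selects one if nothing was passed, and re-expands it into a one-entry device_map. The following is a rough stand-in for that precedence only; the real helpers are normalize_device_device_map() and auto_select_device(), which also account for backend, XPU, and MPS.

import torch

def resolve_load_device(device=None, device_map=None):
    # Rough stand-in for normalize_device_device_map() + auto_select_device().
    # 1) collapse an explicit device_map into a single device
    if device is None and isinstance(device_map, dict) and device_map:
        device = next(iter(device_map.values()))
    elif device is None and isinstance(device_map, str) and device_map != "auto":
        device = device_map
    # 2) auto-select when nothing usable was passed
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # 3) the loader then receives a single-entry device_map
    return device, {"": device}

print(resolve_load_device())                        # e.g. ('cuda', {'': 'cuda'})
print(resolve_load_device(device_map={"": "cpu"}))  # ('cpu', {'': 'cpu'})
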
11 changes: 5 additions & 6 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -1,10 +1,9 @@
import sys
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...models._const import DEVICE, PLATFORM, normalize_device
from ...models._const import DEVICE, PLATFORM


class BaseQuantLinear(nn.Module):
@@ -158,11 +157,11 @@ def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynami
return True, None

@classmethod
def validate_device(cls, device: str|DEVICE|int|torch.device):
dev = normalize_device(device)
def validate_device(cls, device: DEVICE):
assert isinstance(device, DEVICE)

if dev not in cls.SUPPORTS_DEVICES:
raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTS_DEVICES}`: actual device = `{dev}`")
if device not in cls.SUPPORTS_DEVICES:
raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTS_DEVICES}`: actual device = `{device}`")

# override me
def post_init(self):
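
With the qlinear change, validate_device() expects an already-normalized DEVICE enum member; the per-backend post_init() calls to it are removed in the files below because the check now happens when the quant linear is selected. A small stand-in for the tightened contract follows (the DEVICE enum here is a local placeholder, not the library's gptqmodel.models._const.DEVICE).

from enum import Enum

class DEVICE(str, Enum):  # local placeholder, not gptqmodel.models._const.DEVICE
    CPU = "cpu"
    CUDA = "cuda"
    XPU = "xpu"

class DemoQuantLinear:
    SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.XPU]

    @classmethod
    def validate_device(cls, device: DEVICE):
        # Callers must pass a normalized DEVICE member; strings and torch.device
        # values are normalized earlier, when the quant linear is selected.
        assert isinstance(device, DEVICE)
        if device not in cls.SUPPORTS_DEVICES:
            raise NotImplementedError(f"{cls} only supports {cls.SUPPORTS_DEVICES}: actual device = {device}")

DemoQuantLinear.validate_device(DEVICE.CUDA)  # passes
# DemoQuantLinear.validate_device("cuda")     # would fail the isinstance assert
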
1 change: 0 additions & 1 deletion gptqmodel/nn_modules/qlinear/bitblas.py
@@ -261,7 +261,6 @@ def reset_parameters(self):
self.q_params = None

def post_init(self):
self.validate_device(self.qweight.device.type)
# eliminate runtime overhead like exllama state
param_list = [self.qweight, self.scales, self.zeros]
if self.bitblas_matmul.config.with_bias:
3 changes: 0 additions & 3 deletions gptqmodel/nn_modules/qlinear/exllama.py
@@ -122,9 +122,6 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
return cls._validate(**args)

def post_init(self):
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded
if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures:
self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)
3 changes: 0 additions & 3 deletions gptqmodel/nn_modules/qlinear/exllamav2.py
@@ -189,9 +189,6 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
return cls._validate(**args)

def post_init(self, temp_dq):
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded
if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures:
self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)
2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/ipex.py
@@ -142,7 +142,7 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
return cls._validate(**args)

def post_init(self):
self.validate_device(self.qweight.device.type)
pass

def init_ipex_linear(self, x: torch.Tensor):
if not self.training and HAS_IPEX and not x.requires_grad:
2 changes: 0 additions & 2 deletions gptqmodel/nn_modules/qlinear/marlin.py
@@ -294,8 +294,6 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:

def post_init(self):
device = self.qweight.device
self.validate_device(device.type)

# Allocate marlin workspace
self.workspace = marlin_make_workspace(
self.outfeatures, device)
(Diffs for the remaining 4 changed files are not shown.)
