
Commit

gptqmodel
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Nov 29, 2024
1 parent 1cce05b commit 9bb7694
Showing 3 changed files with 42 additions and 38 deletions.
74 changes: 36 additions & 38 deletions optimum/gptq/quantizer.py
@@ -12,22 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import json
 import os
-import importlib
 from enum import Enum
 from logging import getLogger
+from packaging import version
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-from packaging import version
 from torch import nn
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer
 from transformers.pytorch_utils import Conv1D
 from transformers.utils.quantization_config import QuantizationMethod
 
-from ..utils import is_accelerate_available, is_auto_gptq_available
+from ..utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available
 from ..utils.modeling_utils import recurse_getattr
 from .constants import GPTQ_CONFIG
 from .data import get_dataset, prepare_dataset
@@ -43,9 +43,16 @@

 if is_auto_gptq_available():
     from auto_gptq import exllama_set_max_input_length
-    from auto_gptq.modeling._utils import autogptq_post_init
+    from auto_gptq.modeling._utils import autogptq_post_init as gptq_post_init
     from auto_gptq.quantization import GPTQ
-    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
+    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as select_quant_linear
+
+
+if is_gptqmodel_available():
+    from gptqmodel import exllama_set_max_input_length
+    from gptqmodel.quantization import GPTQ
+    from gptqmodel.utils.importer import select_quant_linear
+    from gptqmodel.utils.model import gptqmodel_post_init as gptq_post_init
 
 logger = getLogger(__name__)

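Note: both import branches bind the same two names, `gptq_post_init` and `select_quant_linear`, so the rest of the quantizer never has to branch on which backend is installed. A minimal sketch of how the shared names are consumed downstream (the argument values here are illustrative, not taken from this diff):

    # Either backend provides these callables once the guarded imports above ran.
    QuantLinear = select_quant_linear(bits=4, group_size=128, desc_act=False, sym=True)
    model = gptq_post_init(model, use_act_order=False)  # backend-specific post-initialization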
@@ -255,13 +262,8 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str = ""):
             name (`str`, defaults to `""`):
                 To keep track of the name of the current module
         """
-        QuantLinear = dynamically_import_QuantLinear(
-            use_triton=False,
-            desc_act=self.desc_act,
-            group_size=self.group_size,
-            bits=self.bits,
-            disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
-            disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
+        QuantLinear = select_quant_linear(
+            bits=self.bits, group_size=self.group_size, desc_act=self.desc_act, sym=self.sym
         )
         if isinstance(module, QuantLinear):
             return
@@ -281,20 +283,16 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str = ""):
                     in_features = layer.weight.shape[0]
                     out_features = layer.weight.shape[1]
                 bias = layer.bias is not None
-                if not (self.desc_act) or self.group_size == -1:
-                    new_layer = QuantLinear(
-                        self.bits,
-                        self.group_size,
-                        in_features,
-                        out_features,
-                        bias,
-                        use_cuda_fp16=self.use_cuda_fp16,
-                        weight_dtype=layer.weight.dtype,
-                    )
-                else:
-                    new_layer = QuantLinear(
-                        self.bits, self.group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
-                    )
+                new_layer = QuantLinear(
+                    self.bits,
+                    self.group_size,
+                    self.desc_act,
+                    self.sym,
+                    in_features,
+                    out_features,
+                    bias,
+                    weight_dtype=layer.weight.dtype,
+                )
                 new_layer.device = device
                 setattr(module, attr, new_layer.to(device))
         for name1, child in module.named_children():
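Note: with `desc_act` and `sym` accepted as plain constructor arguments, the old two-branch construction collapses into a single call. A hedged sketch of the new call shape for a plain nn.Linear, assuming a 4-bit, group-size-128 config (the values are illustrative; the concrete QuantLinear class is whatever `select_quant_linear` returns):

    from torch import nn

    QuantLinear = select_quant_linear(bits=4, group_size=128, desc_act=False, sym=True)
    layer = nn.Linear(768, 768, bias=True)
    # Positional order follows the diff: bits, group_size, desc_act, sym,
    # then in_features, out_features, bias.
    new_layer = QuantLinear(
        4,      # bits
        128,    # group_size
        False,  # desc_act
        True,   # sym
        layer.in_features,
        layer.out_features,
        layer.bias is not None,
        weight_dtype=layer.weight.dtype,
    )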
@@ -320,10 +318,15 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
             `nn.Module`: The quantized model
         """
 
-        if not is_auto_gptq_available():
-            raise RuntimeError("auto-gptq is required in order to perform quantzation : `pip install auto-gptq`")
+        if not is_auto_gptq_available() and not is_gptqmodel_available():
+            raise RuntimeError(
+                "auto-gptq or gptqmodel is required in order to perform quantzation : `pip install auto-gptq`"
+            )
 
-        gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+        gptq_supports_cpu = (
+            is_auto_gptq_available()
+            and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+        ) or is_gptqmodel_available()
         if not gptq_supports_cpu and not torch.cuda.is_available():
             raise RuntimeError("No GPU found. A GPU is needed to quantize model.")

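Note: the reworked `gptq_supports_cpu` check short-circuits so that `importlib.metadata.version("auto-gptq")` is only queried when auto-gptq is actually installed, while gptqmodel is treated as CPU-capable unconditionally. The same logic as a standalone sketch (the helper name is hypothetical):

    import importlib.metadata

    from packaging import version

    def gptq_backend_supports_cpu(auto_gptq_installed: bool, gptqmodel_installed: bool) -> bool:
        # gptqmodel always supports CPU; auto-gptq only in versions newer than 0.4.2.
        if gptqmodel_installed:
            return True
        if auto_gptq_installed:
            return version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
        return False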
@@ -520,7 +523,7 @@ def tmp(_, input, output):
                     h.remove()
                 for name in subset_name_list:
                     logger.info(f"Quantizing {name} in block {i + 1}/{len(blocks)}...")
-                    scale, zero, g_idx = gptq[name].fasterquant(
+                    scale, zero, g_idx, _, _, _ = gptq[name].fasterquant(
                         percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act
                     )
                     quantizers[f"{self.block_name_to_quantize}.{i}.{name}"] = (
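Note: `fasterquant` now returns six values rather than three; only `scale`, `zero`, and `g_idx` are consumed here, so the trailing diagnostics are discarded. A starred unpacking (sketch only) would tolerate either return arity:

    scale, zero, g_idx, *_ = gptq[name].fasterquant(
        percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act
    )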
@@ -608,7 +611,7 @@ class StoreAttr(object):

         model.quantize_config = StoreAttr()
         model.quantize_config.desc_act = self.desc_act
-        model = autogptq_post_init(model, use_act_order=self.desc_act)
+        model = gptq_post_init(model, use_act_order=self.desc_act)
         if (
             self.desc_act
             and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE)
@@ -631,13 +634,8 @@ def pack_model(
             quantizers (`Dict[str,Tuple]`):
                 A mapping of the layer name and the data needed to pack the layer
         """
-        QuantLinear = dynamically_import_QuantLinear(
-            use_triton=False,
-            desc_act=self.desc_act,
-            group_size=self.group_size,
-            bits=self.bits,
-            disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
-            disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
+        QuantLinear = select_quant_linear(
+            bits=self.bits, group_size=self.group_size, desc_act=self.desc_act, sym=self.sym
         )
         logger.info("Packing model...")
         layers = get_layers(model)
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -37,6 +37,7 @@
     is_auto_gptq_available,
     is_datasets_available,
     is_diffusers_available,
+    is_gptqmodel_available,
     is_onnx_available,
     is_onnxruntime_available,
     is_pydantic_available,
5 changes: 5 additions & 0 deletions optimum/utils/import_utils.py
@@ -67,6 +67,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _accelerate_available = _is_package_available("accelerate")
 _diffusers_available = _is_package_available("diffusers")
 _auto_gptq_available = _is_package_available("auto_gptq")
+_gptqmodel_available = _is_package_available("gptqmodel")
 _timm_available = _is_package_available("timm")
 _sentence_transformers_available = _is_package_available("sentence_transformers")
 _datasets_available = _is_package_available("datasets")
@@ -147,6 +148,10 @@ def is_auto_gptq_available():
             )
 
 
+def is_gptqmodel_available():
+    return _gptqmodel_available
+
+
 @contextmanager
 def check_if_pytorch_greater(target_version: str, message: str):
     r"""
