Add auto_gptq to lmdeploy lite #2372

Merged · 5 commits · Aug 28, 2024
Changes from 3 commits
36 changes: 36 additions & 0 deletions lmdeploy/cli/lite.py
@@ -48,6 +48,34 @@ def add_parser_auto_awq():
            default=128,
            help='Group size for weight quantization statistics')

    @staticmethod
    def add_parser_auto_gptq():
        """Add parser for auto_gptq command."""
        parser = SubCliLite.subparsers.add_parser(
            'auto_gptq',
            formatter_class=DefaultsAndTypesHelpFormatter,
            description=SubCliLite.auto_gptq.__doc__,
            help=SubCliLite.auto_gptq.__doc__)
        parser.set_defaults(run=SubCliLite.auto_gptq)
        parser.add_argument('model',
                            type=str,
                            help='The path of model in hf format')
        ArgumentHelper.revision(parser)
        ArgumentHelper.work_dir(parser)
        ArgumentHelper.calib_dataset(parser)
        ArgumentHelper.calib_samples(parser)
        ArgumentHelper.calib_seqlen(parser)
        ArgumentHelper.calib_batchsize(parser)
        parser.add_argument('--w-bits',
                            type=int,
                            default=4,
                            help='Bit number for weight quantization')
        parser.add_argument(
            '--w-group-size',
            type=int,
            default=128,
            help='Group size for weight quantization statistics')

    @staticmethod
    def add_parser_calibrate():
        """Add parser for calibrate command."""
@@ -97,6 +125,13 @@ def auto_awq(args):
        kwargs = convert_args(args)
        auto_awq(**kwargs)

    @staticmethod
    def auto_gptq(args):
        """Perform weight quantization using GPTQ algorithm."""
        from lmdeploy.lite.apis.gptq import auto_gptq
        kwargs = convert_args(args)
        auto_gptq(**kwargs)

    @staticmethod
    def calibrate(args):
        """Perform calibration on a given dataset."""
@@ -115,5 +150,6 @@ def smooth_quant(args):
    def add_parsers():
        """Add all parsers."""
        SubCliLite.add_parser_auto_awq()
        SubCliLite.add_parser_auto_gptq()
        SubCliLite.add_parser_calibrate()
        SubCliLite.add_parser_smooth_quant()
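
With the parser above registered, the new entry point should be reachable from the command line as lmdeploy lite auto_gptq. A minimal invocation sketch follows; the positional model path and the --w-bits/--w-group-size flags come from this diff, while the spellings of the ArgumentHelper-provided options (e.g. --work-dir) and the model path itself are assumptions, not verified here:

# Sketch only: quantize an HF checkpoint via the new subcommand.
# The model path and the ArgumentHelper-provided flag names are assumed.
lmdeploy lite auto_gptq ./internlm2-chat-7b \
    --work-dir ./internlm2-chat-7b-gptq \
    --w-bits 4 \
    --w-group-size 128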
104 changes: 104 additions & 0 deletions lmdeploy/lite/apis/gptq.py
@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch
from transformers import AutoTokenizer

from lmdeploy.lite.utils.calib_dataloader import get_calib_loaders


def auto_gptq(model: str,
              work_dir: str = './work_dir',
              w_bits: int = 4,
              w_group_size: int = 128,
              calib_dataset: str = 'ptb',
              calib_samples: int = 128,
              calib_seqlen: int = 2048,
              batch_size: int = 1,
              revision: str = None):
"""Perform weight quantization using AWQ algorithm.

Args:
model (str): The path of model in hf format.
work_dir (str): The working directory to save results.
calib_dataset (str): The calibration dataset name.
calib_samples (int): The number of samples for calibration.
batch_size (int): The batch size for running the calib samples.
Low GPU mem requires small batch_size. Large batch_size
reduces the calibration time while costs more VRAM.
calib_seqlen (int): The sequence length for calibration.
w_bits (int): Bit number for weight quantization.
w_group_size (int): Group size for weight quantization statistics.
search_scale (bool): Whether search scale ratio. Default to False,
which means only smooth quant with 0.5 ratio will be applied.
device (str): Device type of running.
revision (str): The specific model version to use. It can be a
branch name, a tag name, or a commit id. If unspecified,
will use the default version.
"""
    try:
        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
    except Exception:
        raise ImportError('To use auto_gptq, please install auto-gptq by '
                          'pip install auto-gptq')
    logging.basicConfig(
        format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    # support internlm2
    from auto_gptq.modeling import GPTQ_CAUSAL_LM_MODEL_MAP
    from auto_gptq.modeling._const import SUPPORTED_MODELS

    from ..modeling.internlm2_gptq import InternLM2GPTQForCausalLM
    SUPPORTED_MODELS.append('internlm2')
    GPTQ_CAUSAL_LM_MODEL_MAP.update(dict(internlm2=InternLM2GPTQForCausalLM))

    pretrained_model_dir = model
    quantized_model_dir = work_dir

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
                                              trust_remote_code=True)
    print('Loading calibrate dataset ...')
    calib_loader, _ = get_calib_loaders(calib_dataset,
                                        tokenizer,
                                        nsamples=calib_samples,
                                        seqlen=calib_seqlen)
    all_data = [
        data if isinstance(data, torch.Tensor) else data[0]
        for data in calib_loader
    ]
    attention_mask = [1] * calib_seqlen
    examples = [
        dict(input_ids=data.flatten().tolist(), attention_mask=attention_mask)
        for data in all_data
    ]

    quantize_config = BaseQuantizeConfig(
        bits=w_bits,  # quantization bit-width, 4-bit by default
        group_size=w_group_size,  # it is recommended to set the value to 128
        desc_act=False,  # lmdeploy only supports False
        sym=True,  # lmdeploy only supports True
    )

    # load un-quantized model; by default the model is loaded into CPU memory
    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir,
                                                quantize_config,
                                                revision=revision,
                                                trust_remote_code=True)

    # quantize the model; `examples` must be a list of dicts whose keys
    # are only "input_ids" and "attention_mask"
    model.quantize(examples, batch_size=batch_size)

    # save quantized model
    model.save_quantized(quantized_model_dir)

    tokenizer.save_pretrained(quantized_model_dir)


if __name__ == '__main__':
    import fire

    fire.Fire(auto_gptq)
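
For reference, the same flow can also be driven from Python by importing the new API directly. A minimal sketch using the defaults defined above; the model id and output directory are placeholders, and auto-gptq must be installed (pip install auto-gptq):

# Sketch only: quantize a model with the new lmdeploy.lite API.
# The model path and work_dir below are placeholders.
from lmdeploy.lite.apis.gptq import auto_gptq

auto_gptq(model='internlm/internlm2-chat-7b',
          work_dir='./internlm2-chat-7b-gptq',
          w_bits=4,
          w_group_size=128,
          calib_dataset='ptb',
          calib_samples=128,
          calib_seqlen=2048,
          batch_size=1)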
1 change: 1 addition & 0 deletions lmdeploy/lite/modeling/__init__.py
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
14 changes: 14 additions & 0 deletions lmdeploy/lite/modeling/internlm2_gptq.py
@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from auto_gptq.modeling import BaseGPTQForCausalLM


class InternLM2GPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = 'InternLM2DecoderLayer'
    layers_block_name = 'model.layers'
    outside_layer_modules = ['model.tok_embeddings', 'model.norm']
    inside_layer_modules = [
        ['attention.wqkv'],
        ['attention.wo'],
        ['feed_forward.w3', 'feed_forward.w1'],
        ['feed_forward.w2'],
    ]
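
This class maps InternLM2's module names onto auto_gptq's quantization schedule: roughly, outside_layer_modules lists the modules outside the decoder blocks (token embeddings and the final norm) that are left unquantized, and each inner list of inside_layer_modules groups the linear sub-modules quantized together within one InternLM2DecoderLayer. A loading sketch, assuming auto_gptq's standard from_quantized API and that the same internlm2 registration done in gptq.py must be repeated before loading; the checkpoint path is a placeholder:

# Sketch only: reload a quantized InternLM2 checkpoint with auto_gptq.
# Mirrors the registration performed in lmdeploy/lite/apis/gptq.py;
# the checkpoint path is a placeholder.
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.modeling import GPTQ_CAUSAL_LM_MODEL_MAP
from auto_gptq.modeling._const import SUPPORTED_MODELS

from lmdeploy.lite.modeling.internlm2_gptq import InternLM2GPTQForCausalLM

SUPPORTED_MODELS.append('internlm2')
GPTQ_CAUSAL_LM_MODEL_MAP.update(dict(internlm2=InternLM2GPTQForCausalLM))

model = AutoGPTQForCausalLM.from_quantized('./internlm2-chat-7b-gptq',
                                           device='cuda:0',
                                           trust_remote_code=True)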