Merge pull request vllm-project#5 from ri938/more_improvements_awq
More improvements awq
ri938 authored Aug 16, 2023
2 parents a3ac858 + db4db0c commit 73db30f
Showing 5 changed files with 128 additions and 83 deletions.
77 changes: 62 additions & 15 deletions vllm/awq_quantization/qmodule.py
@@ -1,33 +1,41 @@
# adapted from llm-awq: https://github.com/mit-han-lab/llm-awq

import math
import torch
import torch.nn as nn

try:
import awq_inference_engine # with CUDA kernels
except ImportError as ex:
msg = "Unable to import awq_inference_engine: run setup.py to install CUDA kernels"
raise ImportError(msg)
raise ImportError(
"Unable to import awq_inference_engine: run setup.py"
" to install AWQ CUDA kernels")


class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)

def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


class WQLinear(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
def __init__(
self,
w_bit,
group_size,
in_features,
out_features,
bias,
dev
):
super().__init__()

if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")

self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
@@ -37,23 +45,62 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0

self.register_buffer('qweight', torch.empty((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
self.register_buffer('qzeros', torch.empty((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
self.register_buffer('scales', torch.empty((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
qweight_buffer = torch.empty(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev
)
self.register_buffer("qweight", qweight_buffer)

qzeros_buffer = torch.empty(
(
in_features // self.group_size,
out_features // (32 // self.w_bit)
),
dtype=torch.int32,
device=dev
)
self.register_buffer("qzeros", qzeros_buffer)

scales_buffer = torch.empty(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev
)
self.register_buffer("scales", scales_buffer)

if bias:
self.register_buffer('bias', torch.empty((out_features), dtype=torch.float16, device=dev))
bias_buffer = torch.empty(
(out_features),
dtype=torch.float16,
device=dev
)
self.register_buffer("bias", bias_buffer)
else:
self.bias = None

@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)

out = awq_inference_engine.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
8
)

out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
str_repr = "in_features={}, out_features={}, " \
"bias={}, w_bit={}, group_size={}"
return str_repr.format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size
)
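
For context on the buffer shapes registered above: with 4-bit weights, eight quantized values are packed into each int32, which is where the out_features // (32 // w_bit) factor comes from. A minimal sketch with hypothetical dimensions (not taken from the commit):

import torch

w_bit = 4
group_size = 128
in_features, out_features = 4096, 4096
pack_factor = 32 // w_bit  # 8 four-bit values per int32

# Shapes mirror the register_buffer calls in WQLinear above.
qweight = torch.empty((in_features, out_features // pack_factor), dtype=torch.int32)
qzeros = torch.empty((in_features // group_size, out_features // pack_factor), dtype=torch.int32)
scales = torch.empty((in_features // group_size, out_features), dtype=torch.float16)

print(qweight.shape)  # torch.Size([4096, 512])
print(qzeros.shape)   # torch.Size([32, 512])
print(scales.shape)   # torch.Size([32, 4096])
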
5 changes: 3 additions & 2 deletions vllm/config.py
@@ -33,7 +33,7 @@ def __init__(
self._verify()

def _verify(self) -> None:
allowed_methods = ['awq']
allowed_methods = ["awq"]
if self.method not in allowed_methods:
raise ValueError(
f"Unknown quantization method ({self.method})"
@@ -118,7 +118,8 @@ def verify_with_parallel_config(
f"({pipeline_parallel_size}).")

if self.quantization_config and tensor_parallel_size > 1:
raise NotImplementedError("Quantization does not currently support tensor parallelism")
raise NotImplementedError(
"Quantization does not currently support tensor parallelism")

def get_hidden_size(self) -> int:
return self.hf_config.hidden_size
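
To illustrate the validation path shown above and exercised from create_engine_configs in the next file: QuantizationConfig is constructed with a single method string, and _verify only accepts "awq". A hedged sketch; the import path and error handling are assumptions based on this diff, not a confirmed API:

from vllm.config import QuantizationConfig

quant_config = QuantizationConfig("awq")   # the only allowed method here

try:
    QuantizationConfig("gptq")             # hypothetical unsupported method
except ValueError as err:
    print(err)                             # reports the unknown quantization method
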
6 changes: 5 additions & 1 deletion vllm/engine/arg_utils.py
@@ -152,7 +152,11 @@ def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
quantization_config = QuantizationConfig(self.quantization) if self.quantization else None
if self.quantization is not None:
quantization_config = QuantizationConfig(self.quantization)
else:
quantization_config = None

model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.use_np_weights,
5 changes: 4 additions & 1 deletion vllm/model_executor/model_loader.py
@@ -51,7 +51,10 @@ def get_model(model_config: ModelConfig) -> nn.Module:
# The weights will be initialized as empty tensors.

if _supports_quantization(model_class):
model = model_class(model_config.hf_config, model_config.quantization_config)
model = model_class(
model_config.hf_config,
model_config.quantization_config
)
else:
model = model_class(model_config.hf_config)

118 changes: 54 additions & 64 deletions vllm/model_executor/models/llama.py
@@ -164,7 +164,7 @@ def __init__(
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
assert tp_size == 1, 'quantization does not support TP'
assert tp_size == 1, "quantization does not support TP"
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
@@ -178,7 +178,7 @@ def __init__(

self.qkv_proj = get_quantized_layer(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
self.q_size + 2 * self.kv_size,
quant_config
)
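
The qkv_proj output width above changes from (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim to self.q_size + 2 * self.kv_size. Assuming q_size = num_heads * head_dim and kv_size = num_kv_heads * head_dim (defined in the elided part of this __init__), the two expressions agree whenever tp_size == 1, which the assert above enforces for quantized models. A small worked check with hypothetical 7B-style numbers:

total_num_heads = total_num_kv_heads = 32
head_dim = 128
tp_size = 1  # enforced by the assert above when quantization is enabled

num_heads = total_num_heads // tp_size
num_kv_heads = total_num_kv_heads // tp_size
q_size = num_heads * head_dim        # 4096
kv_size = num_kv_heads * head_dim    # 4096

old_width = (total_num_heads + 2 * total_num_kv_heads) * head_dim
new_width = q_size + 2 * kv_size
assert old_width == new_width == 12288
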

@@ -220,8 +220,17 @@ def __init__(
quant_config: QuantizationConfig
):
super().__init__()
self.gate_up_proj = get_quantized_layer(hidden_size, 2 * intermediate_size, quant_config)
self.down_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)

self.gate_up_proj = get_quantized_layer(
hidden_size,
2 * intermediate_size, quant_config
)

self.down_proj = get_quantized_layer(
intermediate_size,
hidden_size,
quant_config
)

if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -313,9 +322,12 @@ def __init__(self, config: LlamaConfig, quant_config: QuantizationConfig):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)

self.layers = nn.ModuleList([
LlamaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers)
LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])

self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
@@ -414,82 +426,60 @@ def load_weights(self,
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

is_quantized = self.quant_config is not None and self.quant_config.method is not None
is_quantized = self.quant_config is not None

# merge linear layers
if not is_quantized:
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]

if not is_quantized:
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
else:
# TODO: this is specific to AWQ
if "qweight" in name or "qzeros" in name:
adjustment = 32 / self.quant_config.bits
shard_size = int(shard_size // adjustment)
offset = int(offset // adjustment)
param_slice = param.data[:, offset:offset + shard_size]

assert param_slice.shape == loaded_weight.shape

param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue

is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]

is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
if not is_quantized:
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
else:
# TODO: improve this block of code (not DRY, hacky, specific to AWQ)
is_attention_weight = False
for stride_id, (weight_name, shard_size, offset) in enumerate(attention_weight_specs):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]

# TODO: this is specific to AWQ (should be more general)
if 'qweight' in name or 'qzeros' in name:
shard_size = int(shard_size // (32 / self.quant_config.bits))
offset = int(offset // (32 / self.quant_config.bits))

param_slice = param.data[:, offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape

param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue

is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
else:
shard_size = param.shape[1] // 2

start, end = shard_size * stride_id, shard_size * (stride_id + 1)
start = shard_size * stride_id
end = shard_size * (stride_id + 1)
param_slice = param.data[:, start:end]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue

assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue

param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
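
A note on the shard_size // (32 / self.quant_config.bits) adjustment in load_weights above: qweight and qzeros pack eight 4-bit values into each int32 column, so column counts and offsets into those tensors shrink by that factor, while unpacked tensors such as scales keep their original sizes. A minimal sketch with hypothetical sizes:

bits = 4
pack_factor = 32 // bits               # 8 quantized values per int32 column

shard_size, offset = 4096, 4096        # hypothetical "real" column range [4096, 8192)

packed_shard = int(shard_size // pack_factor)   # 512 packed columns to copy
packed_offset = int(offset // pack_factor)      # slice starts at packed column 512

# scales is unpacked float16, so it would keep shard_size and offset unchanged.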