
Commit db4db0c
improve the quant weight loaded code
ri938 committed Aug 16, 2023
1 parent fbaf889 commit db4db0c
Showing 1 changed file with 36 additions and 62 deletions.
vllm/model_executor/models/llama.py (98 changes: 36 additions & 62 deletions)
@@ -426,86 +426,60 @@ def load_weights(self,
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
-            # merge linear layers
-            if self.quant_config is not None:
-                # TODO: improve this block of code
-                is_attention_weight = False
-                for stride_id, weight_spec in enumerate(attention_weight_specs):
-                    weight_name, shard_size, offset = weight_spec
-
-                    if weight_name not in name:
-                        continue
-
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-                    # TODO: this is specific to AWQ (should be more general)
-                    if "qweight" in name or "qzeros" in name:
-                        adjustment = 32 / self.quant_config.bits
-                        shard_size = int(shard_size // adjustment)
-                        offset = int(offset // adjustment)
-
-                    param_slice = param.data[:, offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
-
-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
-                    continue
-
-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
-
-                    shard_size = param.shape[1] // 2
-
-                    start = shard_size * stride_id
-                    end = shard_size * (stride_id + 1)
-
-                    param_slice = param.data[:, start:end]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
-            else:
-                is_attention_weight = False
-                for weight_name, shard_size, offset in attention_weight_specs:
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank:shard_size *
-                        (tensor_model_parallel_rank + 1)]
-                    param_slice = param.data[offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
-
-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
-                    continue
-
-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                    shard_size = param.shape[0] // 2
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank:shard_size *
-                        (tensor_model_parallel_rank + 1)]
-                    param_slice = param.data[shard_size * stride_id:shard_size *
-                                             (stride_id + 1)]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
+            is_quantized = self.quant_config is not None
+
+            is_attention_weight = False
+            for weight_name, shard_size, offset in attention_weight_specs:
+                if weight_name not in name:
+                    continue
+                param = state_dict[name.replace(weight_name, "qkv_proj")]
+
+                if not is_quantized:
+                    loaded_weight = loaded_weight[
+                        shard_size * tensor_model_parallel_rank:shard_size *
+                        (tensor_model_parallel_rank + 1)]
+                    param_slice = param.data[offset:offset + shard_size]
+                else:
+                    # TODO: this is specific to AWQ
+                    if "qweight" in name or "qzeros" in name:
+                        adjustment = 32 / self.quant_config.bits
+                        shard_size = int(shard_size // adjustment)
+                        offset = int(offset // adjustment)
+
+                    param_slice = param.data[:, offset:offset + shard_size]
+
+                assert param_slice.shape == loaded_weight.shape
+
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
+                    continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                if not is_quantized:
+                    shard_size = param.shape[0] // 2
+                    loaded_weight = loaded_weight[
+                        shard_size * tensor_model_parallel_rank:shard_size *
+                        (tensor_model_parallel_rank + 1)]
+                    param_slice = param.data[shard_size * stride_id:shard_size *
+                                             (stride_id + 1)]
+                else:
+                    shard_size = param.shape[1] // 2
+
+                    start = shard_size * stride_id
+                    end = shard_size * (stride_id + 1)
+
+                    param_slice = param.data[:, start:end]
+
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
+                break
+            if is_gate_up_weight:
+                continue
 
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
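
A note on the arithmetic in the quantized branch above: AWQ packs 32 // bits low-bit values into each 32-bit word, so a "qweight" or "qzeros" tensor has fewer physical columns than logical output features, and the shard size and offset must be rescaled before slicing along dim 1. The sketch below is illustrative only, not part of the commit; the helper name and the 4-bit default are assumptions, and only the arithmetic mirrors the diff.

    # Illustrative sketch, not vLLM code: map logical output-feature bounds
    # to packed int32 column bounds, as the quantized branch does.
    def packed_bounds(shard_size: int, offset: int, bits: int = 4) -> tuple[int, int]:
        pack_factor = 32 // bits          # e.g. 8 int4 values per int32 word
        return shard_size // pack_factor, offset // pack_factor

    # A shard covering logical output columns [4096, 8192) of a 4-bit
    # checkpoint occupies packed columns [512, 1024) of the qweight tensor.
    size, start = packed_bounds(shard_size=4096, offset=4096)
    print(size, start)  # 512 512, i.e. the slice qweight[:, 512:1024]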
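Similarly, the unquantized gate/up path interleaves two checkpoint tensors into one fused parameter: each tensor is row-sharded for the current tensor-parallel rank, then copied into its half of gate_up_proj. A toy reconstruction under assumed names and sizes (none of them from the commit):

    import torch

    # Toy sketch of the unquantized gate/up merge in the diff above.
    tp_rank = 0                                        # this rank's index
    intermediate, hidden = 8, 4                        # assumed toy sizes
    gate_up = torch.empty(2 * intermediate // 2, hidden)   # fused param, TP size 2

    for stride_id, full_weight in enumerate(
            [torch.zeros(intermediate, hidden),        # stands in for gate_proj
             torch.ones(intermediate, hidden)]):       # stands in for up_proj
        shard_size = gate_up.shape[0] // 2             # rows per fused half
        shard = full_weight[shard_size * tp_rank:shard_size * (tp_rank + 1)]
        gate_up[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(shard)

    # gate_up now holds this rank's gate rows followed by its up rows.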
