Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Refactor GGUF parameters packing and forwarding #8859

Merged
merged 7 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/models/decoder_only/language/test_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

# FIXME: Move this to confest
MODELS = [
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")),
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
("meta-llama/Llama-3.2-1B-Instruct",
hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
("meta-llama/Llama-3.2-1B-Instruct",
hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved
("Qwen/Qwen2-1.5B-Instruct",
hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
Expand Down
76 changes: 32 additions & 44 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,17 +440,23 @@ def weight_loader(self,
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return

if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size

loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)

ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return

param_data = param.data
output_dim = getattr(param, "output_dim", None)
Expand Down Expand Up @@ -515,18 +521,6 @@ def weight_loader(self,
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id

if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape

input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
Expand Down Expand Up @@ -783,17 +777,23 @@ def weight_loader(self,
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return

if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size

loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)

param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return

param_data = param.data
output_dim = getattr(param, "output_dim", None)
Expand Down Expand Up @@ -883,18 +883,6 @@ def weight_loader(self,
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id)

if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape

input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
Expand Down
36 changes: 25 additions & 11 deletions vllm/model_executor/layers/quantization/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,16 @@ def create_weights(self, layer: torch.nn.Module,
output_size_per_partition = sum(output_partition_sizes)

tensor_shape = (output_size_per_partition, input_size_per_partition)
qweight = UninitializedParameter(requires_grad=False)
qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"shard_size": {},
"data_container": [],
"shard_id": [],
"shard_id_map": {},
})
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qweight", qweight)
Expand All @@ -116,21 +117,17 @@ def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
shard_size = getattr(layer.qweight, "shard_size", None)
shard_id = getattr(layer.qweight, "shard_id", None)

if shard_id and shard_size:
result = []
offset = 0
if shard_id:
# dequantize shard weights respectively
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
qweight = layer.qweight.unbind(0)
result = []
for id in shard_id:
shard_weight = layer.qweight[
offset:offset +
shard_size[id][0], :shard_size[id][1]].contiguous()
q_idx = layer.qweight.shard_id_map[id]
qweight_type = layer.qweight_type.shard_weight_type[id]
result.append(_fuse_mul_mat(x, shard_weight, qweight_type))
offset += shard_size[id][0]
result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
out = torch.cat(result, axis=1)
else:
qweight = layer.qweight
Expand Down Expand Up @@ -162,3 +159,20 @@ def embedding(self, layer: torch.nn.Module,
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0])
return dequant.view(*x.shape, hidden_size)


class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter
data_container: List[torch.Tensor]

def materialize_nested(self) -> Parameter:
nested_data = torch.nested.nested_tensor(self.data_container,
device=self.device,
dtype=torch.uint8)
self.data_container.clear()
param = torch.Tensor._make_subclass(self.cls_to_become,
nested_data,
require_grad=False)
for k, v in self.__dict__.items():
setattr(param, k, v)
return param
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def __init__(
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.lm_head = self.model.embed_tokens
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
Expand Down
Loading