llama: add support for small granite models
It works only for the small models (3b and 8b).

The convert-hf-to-gguf.py script uses the vocabulary size of the
granite models (49152) to detect them and set the correct configuration.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
giuseppe committed May 24, 2024
1 parent 0211330 commit 431fde0
Showing 2 changed files with 29 additions and 5 deletions.
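For reference, a minimal standalone sketch of the detection heuristic the commit message describes; the function and constant names here are hypothetical, since in the diff below the check is inlined in the LlamaModel methods:

    GRANITE_VOCAB_SIZE = 49152  # vocabulary length shared by the granite checkpoints

    def is_granite(hparams: dict) -> bool:
        # config.json may omit vocab_size; fall back to the llama default of 32000
        return hparams.get("vocab_size", 32000) == GRANITE_VOCAB_SIZE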
23 changes: 19 additions & 4 deletions convert-hf-to-gguf.py
@@ -1315,6 +1315,19 @@ def set_gguf_parameters(self):
                 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                 self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
 
+        # Apply to granite small models only
+        if self.hparams.get("vocab_size", 32000) == 49152:
+            self.gguf_writer.add_add_bos_token(False)
+            self.gguf_writer.add_rope_type(gguf.RopeType.NEOX)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                if "add_prefix_space" in tokenizer_config_json:
+                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
         if n_head_kv is not None and n_head != n_head_kv:
@@ -1329,10 +1342,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
 
-        if name.endswith("q_proj.weight"):
-            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-        if name.endswith("k_proj.weight"):
-            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+        # Skip for granite models
+        if self.hparams.get("vocab_size", 32000) != 49152:
+            if name.endswith("q_proj.weight"):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+            if name.endswith("k_proj.weight"):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
 
         # process the experts separately
         if name.find("block_sparse_moe.experts") != -1:
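For context, LlamaModel.permute is only truncated in the hunk above; its body in convert-hf-to-gguf.py is roughly the following. It un-interleaves the rope dimensions of the HF q/k projection weights into the layout llama.cpp's default rope expects; granite checkpoints skip it, consistent with the NEOX rope type written in set_gguf_parameters above:

    @staticmethod
    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        # split each head into (2, head_dim // 2), swap the axes, flatten back:
        # turns interleaved rope pairs into two contiguous halves per head
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))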
11 changes: 10 additions & 1 deletion llama.cpp
@@ -4001,7 +4001,9 @@ static void llm_load_hparams(
                 switch (hparams.n_layer) {
                     case 22: model.type = e_model::MODEL_1B; break;
                     case 26: model.type = e_model::MODEL_3B; break;
-                    case 32: model.type = hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B; break;
+                    // granite uses a vocab with len 49152
+                    case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break;
+                    case 36: model.type = e_model::MODEL_8B; break; // granite
                     case 40: model.type = e_model::MODEL_13B; break;
                     case 48: model.type = e_model::MODEL_34B; break;
                     case 60: model.type = e_model::MODEL_30B; break;
@@ -4271,6 +4273,8 @@ static void llm_load_hparams(
                     case 30: model.type = e_model::MODEL_3B; break;
                     case 32: model.type = e_model::MODEL_7B; break;
                     case 40: model.type = e_model::MODEL_15B; break;
+                    case 52: model.type = e_model::MODEL_20B; break; // granite
+                    case 88: model.type = e_model::MODEL_34B; break; // granite
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -4521,6 +4525,11 @@ static void llm_load_vocab(
     } else {
         if (tokenizer_model == "gpt2") {
             vocab.type = LLAMA_VOCAB_TYPE_BPE;
+
+            const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
+            if (add_space_prefix_keyidx != -1) {
+                vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
+            }
         } else {
             LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str());
             LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
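To confirm the new flag survives conversion, the KV pair written by add_add_space_prefix (stored under the key tokenizer.ggml.add_space_prefix, which the llama.cpp hunk above reads back via LLM_KV_TOKENIZER_ADD_PREFIX) can be inspected with the gguf Python package from the repository. A minimal sketch, assuming a converted file named granite-3b.gguf:

    from gguf import GGUFReader

    reader = GGUFReader("granite-3b.gguf")
    field = reader.fields.get("tokenizer.ggml.add_space_prefix")
    if field is not None:
        # scalar boolean KVs are stored as a single byte in the field's parts
        print(bool(field.parts[field.data[0]][0]))

Conversion itself follows the usual flow, e.g. python convert-hf-to-gguf.py /path/to/granite-3b --outfile granite-3b.gguf (the input path is a placeholder), after which the model loads like any other llama-architecture GGUF.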
