llama: allow to override model rope type
Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
giuseppe committed May 24, 2024
1 parent 120f7bf commit 0211330
Showing 4 changed files with 42 additions and 3 deletions.
9 changes: 9 additions & 0 deletions gguf-py/gguf/constants.py
@@ -57,6 +57,7 @@ class Attention:
         CAUSAL = "{arch}.attention.causal"
 
     class Rope:
+        TYPE = "{arch}.rope.type"
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
         FREQ_BASE = "{arch}.rope.freq_base"
         SCALING_TYPE = "{arch}.rope.scaling.type"
@@ -806,6 +807,13 @@ class TokenType(IntEnum):
     BYTE = 6
 
 
+class RopeType(Enum):
+    NONE = 'none'
+    NORM = 'norm'
+    NEOX = 'neox'
+    GLM = 'glm'
+
+
 class RopeScalingType(Enum):
     NONE = 'none'
     LINEAR = 'linear'
@@ -998,6 +1006,7 @@ def get_type(val: Any) -> GGUFValueType:
 KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
 
 # RoPE
+KEY_ROPE_TYPE = Keys.Rope.TYPE
 KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
 KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE
 KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -427,6 +427,9 @@ def add_rope_dimension_count(self, count: int) -> None:
     def add_rope_freq_base(self, value: float) -> None:
         self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
 
+    def add_rope_type(self, value: RopeType) -> None:
+        self.add_string(Keys.Rope.TYPE.format(arch=self.arch), value.value)
+
     def add_rope_scaling_type(self, value: RopeScalingType) -> None:
         self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
 
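Taken together, the gguf-py changes let a conversion or patching script record the RoPE variant directly in a model's metadata. A minimal sketch of the writer side (hypothetical output path; it produces a metadata-only file purely to show the new call, and assumes the usual GGUFWriter write sequence):

from gguf import GGUFWriter
from gguf.constants import RopeType

# Hypothetical output path; a real conversion script would also add tensors.
writer = GGUFWriter("rope-override.gguf", arch="llama")
writer.add_architecture()                 # "general.architecture" = "llama"
writer.add_rope_type(RopeType.NEOX)       # new key: "llama.rope.type" = "neox"
writer.add_rope_dimension_count(128)      # existing RoPE keys work as before

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()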
31 changes: 29 additions & 2 deletions llama.cpp
@@ -297,6 +297,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
     LLM_KV_ATTENTION_CAUSAL,
 
+    LLM_KV_ROPE_TYPE,
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,
@@ -375,6 +376,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
     { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
 
+    { LLM_KV_ROPE_TYPE, "%s.rope.type" },
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
     { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
@@ -1129,12 +1131,29 @@ struct LLM_TN {
 // gguf helpers
 //
 
+static const std::map<enum llama_rope_type, const char *> LLAMA_ROPE_TYPES = {
+    { LLAMA_ROPE_TYPE_NONE, "none" },
+    { LLAMA_ROPE_TYPE_NORM, "norm" },
+    { LLAMA_ROPE_TYPE_NEOX, "neox" },
+    { LLAMA_ROPE_TYPE_GLM, "glm" },
+};
+
 static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" },
 };
 
+static enum llama_rope_type llama_rope_type_from_string(const std::string & name) {
+    for (const auto & kv : LLAMA_ROPE_TYPES) {
+        if (kv.second == name) {
+            return (enum llama_rope_type) kv.first;
+        }
+    }
+
+    return LLAMA_ROPE_TYPE_NONE;
+}
+
 static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
@@ -4394,7 +4413,15 @@ static void llm_load_hparams(
         hparams.use_alibi = true;
     }
 
-    hparams.rope_type = llama_rope_type(&model);
+    hparams.rope_type = llama_default_rope_type(&model);
+
+    const auto kv = LLM_KV(model.arch);
+    const int rope_type_keyidx = gguf_find_key(ctx, kv(LLM_KV_ROPE_TYPE).c_str());
+    if (rope_type_keyidx != -1) {
+        std::string rope_type("none");
+        ml.get_key(LLM_KV_ROPE_TYPE, rope_type);
+        hparams.rope_type = llama_rope_type_from_string(rope_type);
+    }
 }
 
 // TODO: This should probably be in llama.h
@@ -16216,7 +16243,7 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }
 
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+enum llama_rope_type llama_default_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
2 changes: 1 addition & 1 deletion llama.h
@@ -422,7 +422,7 @@ extern "C" {
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type llama_default_rope_type (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
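In the public C API the only visible change is the rename: llama_rope_type() becomes llama_default_rope_type() and reports the architecture's built-in default, while the value actually used at runtime is hparams.rope_type, which the loader may now override from the <arch>.rope.type metadata key. A minimal sketch of calling the renamed function (hypothetical model path, error handling kept short):

#include "llama.h"

#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    // Hypothetical path; any GGUF model will do for this query.
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // Reports the per-architecture default, not the (possibly overridden)
    // value stored in the model's metadata.
    const enum llama_rope_type rt = llama_default_rope_type(model);
    printf("default rope type: %d\n", (int) rt);

    llama_free_model(model);
    llama_backend_free();
    return 0;
}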
