Skip to content

Commit

Permalink
llama : implement YaRN RoPE scaling (#2268)
Browse files Browse the repository at this point in the history
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
  • Loading branch information
cebtenzzre and jquesnelle committed Nov 1, 2023
1 parent c43c2da commit 898aeca
Show file tree
Hide file tree
Showing 15 changed files with 764 additions and 258 deletions.
79 changes: 66 additions & 13 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--rope-scaling") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { invalid_param = true; break; }
} else if (arg == "--rope-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = 1.0f/std::stof(argv[i]);
} else if (arg == "--yarn-orig-ctx") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_orig_ctx = std::stoi(argv[i]);
} else if (arg == "--yarn-ext-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_ext_factor = std::stof(argv[i]);
} else if (arg == "--yarn-attn-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_attn_factor = std::stof(argv[i]);
} else if (arg == "--yarn-beta-fast") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_fast = std::stof(argv[i]);
} else if (arg == "--yarn-beta-slow") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
Expand Down Expand Up @@ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --cfg-negative-prompt-file FNAME\n");
printf(" negative prompt file to use for guidance. (default: empty)\n");
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n");
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n");
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
Expand Down Expand Up @@ -826,17 +873,23 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();

cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
cparams.yarn_ext_factor = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;

return cparams;
}
Expand Down
7 changes: 7 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define LOG_NO_FILE_LINE_FUNCTION
#include "log.h"

#include <cmath>
#include <string>
#include <vector>
#include <random>
Expand Down Expand Up @@ -54,6 +55,12 @@ struct gpt_params {
int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = NAN; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f;// YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;

// // sampling parameters
struct llama_sampling_params sparams;
Expand Down
3 changes: 2 additions & 1 deletion convert-baichuan-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def parse_args() -> argparse.Namespace:
if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
if "type" in hparams["rope_scaling"]:
if hparams["rope_scaling"]["type"] == "linear":
gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])


# TOKENIZATION
Expand Down
97 changes: 48 additions & 49 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,11 @@ class Params:
n_head_kv: int
f_norm_eps: float

rope_scaling_type: gguf.RopeScalingType | None = None
f_rope_freq_base: float | None = None
f_rope_scale: float | None = None
n_orig_ctx: int | None = None
rope_finetuned: bool | None = None

ftype: GGMLFileType | None = None

Expand Down Expand Up @@ -198,20 +201,20 @@ def guessed(model: LazyModel) -> Params:
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))

n_vocab = config["vocab_size"]
n_embd = config["hidden_size"]
n_layer = config["num_hidden_layers"]
n_ff = config["intermediate_size"]
n_head = config["num_attention_heads"]
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
f_norm_eps = config["rms_norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None

rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
rope_scaling = config.get("rope_scaling")
if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
f_rope_scale = config["rope_scaling"].get("factor")
else:
f_rope_scale = None

if rope_scaling is not None and (typ := rope_scaling.get("type")):
rope_factor = rope_scaling.get("factor")
f_rope_scale = rope_factor
if typ == "linear":
rope_scaling_type = gguf.RopeScalingType.LINEAR
elif typ == "yarn":
rope_scaling_type = gguf.RopeScalingType.YARN
n_orig_ctx = rope_scaling['original_max_position_embeddings']
rope_finetuned = rope_scaling['finetuned']
else:
raise NotImplementedError(f'Unknown rope scaling type: {typ}')

if "max_sequence_length" in config:
n_ctx = config["max_sequence_length"]
Expand All @@ -222,16 +225,19 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")

return Params(
n_vocab = n_vocab,
n_embd = n_embd,
n_layer = n_layer,
n_ctx = n_ctx,
n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps,
f_rope_freq_base = f_rope_freq_base,
f_rope_scale = f_rope_scale,
n_vocab = config["vocab_size"],
n_embd = config["hidden_size"],
n_layer = config["num_hidden_layers"],
n_ctx = n_ctx,
n_ff = config["intermediate_size"],
n_head = (n_head := config["num_attention_heads"]),
n_head_kv = config.get("num_key_value_heads", n_head),
f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = config.get("rope_theta"),
rope_scaling_type = rope_scaling_type,
f_rope_scale = f_rope_scale,
n_orig_ctx = n_orig_ctx,
rope_finetuned = rope_finetuned,
)

# LLaMA v2 70B params.json
Expand All @@ -240,17 +246,8 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))

n_vocab = config["vocab_size"] if "vocab_size" in config else -1
n_embd = config["dim"]
n_layer = config["n_layers"]
n_ff = -1
n_head = config["n_heads"]
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
f_norm_eps = config["norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None

# hack to determine LLaMA v1 vs v2 vs CodeLlama
if f_rope_freq_base == 1000000:
if config.get("rope_theta") == 1000000:
# CodeLlama
n_ctx = 16384
elif config["norm_eps"] == 1e-05:
Expand All @@ -260,22 +257,16 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
# LLaMA v1
n_ctx = 2048

if n_vocab == -1:
n_vocab = model["tok_embeddings.weight"].shape[0]

if n_ff == -1:
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]

return Params(
n_vocab = n_vocab,
n_embd = n_embd,
n_layer = n_layer,
n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
n_embd = config["dim"],
n_layer = config["n_layers"],
n_ctx = n_ctx,
n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps,
f_rope_freq_base = f_rope_freq_base,
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
n_head = (n_head := config["n_heads"]),
n_head_kv = config.get("n_kv_heads", n_head),
f_norm_eps = config["norm_eps"],
f_rope_freq_base = config.get("rope_theta"),
)

@staticmethod
Expand Down Expand Up @@ -831,8 +822,16 @@ def add_meta_arch(self, params: Params) -> None:
if params.f_rope_freq_base is not None:
self.gguf.add_rope_freq_base(params.f_rope_freq_base)

if params.f_rope_scale is not None:
self.gguf.add_rope_scale_linear(params.f_rope_scale)
if params.rope_scaling_type:
assert params.f_rope_scale is not None
self.gguf.add_rope_scaling_type(params.rope_scaling_type)
self.gguf.add_rope_scaling_factor(params.f_rope_scale)

if params.n_orig_ctx is not None:
self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)

if params.rope_finetuned is not None:
self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)

if params.ftype is not None:
self.gguf.add_file_type(params.ftype)
Expand Down
5 changes: 3 additions & 2 deletions examples/finetune/finetune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int rope_mode = 0;

return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f
);
};

set_name(tokens_input, "tokens_input");
Expand Down
59 changes: 55 additions & 4 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1755,12 +1755,18 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("options:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
if (llama_mlock_supported())
Expand Down Expand Up @@ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_ctx = std::stoi(argv[i]);
}
else if (arg == "--rope-scaling")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { invalid_param = true; break; }
}
else if (arg == "--rope-freq-base")
{
if (++i >= argc)
Expand All @@ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.rope_freq_scale = std::stof(argv[i]);
}
else if (arg == "--yarn-ext-factor")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_ext_factor = std::stof(argv[i]);
}
else if (arg == "--yarn-attn-factor")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_attn_factor = std::stof(argv[i]);
}
else if (arg == "--yarn-beta-fast")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_fast = std::stof(argv[i]);
}
else if (arg == "--yarn-beta-slow")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
}
else if (arg == "--memory-f32" || arg == "--memory_f32")
{
params.memory_f16 = false;
Expand Down
6 changes: 3 additions & 3 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,9 @@ static struct ggml_tensor * llama_build_train_graphs(
// not capturing these, to silcence warnings
const int rope_mode = 0;

return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
return ggml_rope_custom(
ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};

set_name(tokens_input, "tokens_input");
Expand Down
Loading

0 comments on commit 898aeca

Please sign in to comment.