Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama: implement YaRN RoPE scaling #2268

Merged
merged 36 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
8dec38c
llama: implement NTK-By-Parts (NTKv2) RoPE scaling
cebtenzzre Jul 18, 2023
6aeb46b
CUDA implementation
cebtenzzre Jul 19, 2023
9348aa4
Metal implementation
cebtenzzre Jul 21, 2023
a30ae20
implement new YaRN algorithm
cebtenzzre Sep 5, 2023
b5ced4f
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Sep 5, 2023
826269a
ggml : increase GGML_MAX_OP_PARAMS
cebtenzzre Sep 5, 2023
cf731d5
YaRN : avoid NaN if unused betas are zero
cebtenzzre Sep 5, 2023
dcb058c
YaRN : fix missing parameter in CUDA impl
cebtenzzre Sep 5, 2023
281b26e
convert : reduce unnecessary variables in Params
cebtenzzre Sep 6, 2023
a06c729
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Sep 21, 2023
dc26a0d
llama : simplify use of context params
cebtenzzre Sep 21, 2023
904d4ed
llama : store YaRN parameters in GGUF
cebtenzzre Sep 14, 2023
56abb9a
fix convert scripts
cebtenzzre Sep 21, 2023
43eaf06
llama : fix C compatibility
cebtenzzre Sep 21, 2023
fe788c4
don't hardcode max_pos_emb
cebtenzzre Sep 21, 2023
e0b120c
address review comments
cebtenzzre Sep 21, 2023
19bb74e
restore backwards compatiblity with *.rope.scale_linear
cebtenzzre Sep 21, 2023
4d5fe73
better option descriptions in help
cebtenzzre Sep 21, 2023
7466415
gguf : store scaling type as a string instead of an int
cebtenzzre Oct 7, 2023
4f4e948
improve printing of YaRN parameters
cebtenzzre Oct 7, 2023
5d7a3a5
allow forcing ext_factor to zero if scaling type is YaRN
cebtenzzre Oct 7, 2023
9bd050f
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Oct 7, 2023
babf0e0
fix rope_cuda parameter order
cebtenzzre Oct 8, 2023
0050e1e
default n_yarn_orig_ctx to n_ctx_train
cebtenzzre Oct 8, 2023
09c3102
fix uninitialized cparams
cebtenzzre Oct 8, 2023
57c3442
make printed param formatting more consistent
cebtenzzre Oct 8, 2023
a20b3e6
fix missing import
cebtenzzre Oct 11, 2023
9ef91b1
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Oct 13, 2023
9ae10b3
Fix YaRN inverted scaling and add "rope.scaling.type" to GGUF (#1)
jquesnelle Oct 20, 2023
14cf93b
fix YaRN ramp, make mscale conditional, add --yarn-orig-ctx (#2)
jquesnelle Oct 20, 2023
237f1e7
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Oct 22, 2023
bc8395d
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Oct 23, 2023
4d5ed83
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Oct 24, 2023
9fc8238
fix loading rope.scaling.original_context_length from GGUF (#3)
jquesnelle Oct 30, 2023
15f26ef
implement YaRN for GPT-NeoX RoPE
cebtenzzre Nov 1, 2023
081f738
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Nov 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 66 additions & 13 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--rope-scaling") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { invalid_param = true; break; }
} else if (arg == "--rope-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = 1.0f/std::stof(argv[i]);
} else if (arg == "--yarn-orig-ctx") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_orig_ctx = std::stoi(argv[i]);
} else if (arg == "--yarn-ext-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_ext_factor = std::stof(argv[i]);
} else if (arg == "--yarn-attn-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_attn_factor = std::stof(argv[i]);
} else if (arg == "--yarn-beta-fast") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_fast = std::stof(argv[i]);
} else if (arg == "--yarn-beta-slow") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
Expand Down Expand Up @@ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --cfg-negative-prompt-file FNAME\n");
printf(" negative prompt file to use for guidance. (default: empty)\n");
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n");
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n");
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
Expand Down Expand Up @@ -826,17 +873,23 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();

cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
cparams.yarn_ext_factor = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;

return cparams;
}
Expand Down
7 changes: 7 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define LOG_NO_FILE_LINE_FUNCTION
#include "log.h"

#include <cmath>
#include <string>
#include <vector>
#include <random>
Expand Down Expand Up @@ -54,6 +55,12 @@ struct gpt_params {
int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = NAN; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f;// YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;

// // sampling parameters
struct llama_sampling_params sparams;
Expand Down
3 changes: 2 additions & 1 deletion convert-baichuan-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def parse_args() -> argparse.Namespace:
if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
if "type" in hparams["rope_scaling"]:
if hparams["rope_scaling"]["type"] == "linear":
gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])


# TOKENIZATION
Expand Down
97 changes: 48 additions & 49 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,11 @@ class Params:
n_head_kv: int
f_norm_eps: float

rope_scaling_type: gguf.RopeScalingType | None = None
f_rope_freq_base: float | None = None
f_rope_scale: float | None = None
n_orig_ctx: int | None = None
rope_finetuned: bool | None = None

ftype: GGMLFileType | None = None

Expand Down Expand Up @@ -198,20 +201,20 @@ def guessed(model: LazyModel) -> Params:
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))

n_vocab = config["vocab_size"]
n_embd = config["hidden_size"]
n_layer = config["num_hidden_layers"]
n_ff = config["intermediate_size"]
n_head = config["num_attention_heads"]
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
f_norm_eps = config["rms_norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None

rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
rope_scaling = config.get("rope_scaling")
if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
f_rope_scale = config["rope_scaling"].get("factor")
else:
f_rope_scale = None

if rope_scaling is not None and (typ := rope_scaling.get("type")):
rope_factor = rope_scaling.get("factor")
f_rope_scale = rope_factor
if typ == "linear":
rope_scaling_type = gguf.RopeScalingType.LINEAR
elif typ == "yarn":
rope_scaling_type = gguf.RopeScalingType.YARN
n_orig_ctx = rope_scaling['original_max_position_embeddings']
rope_finetuned = rope_scaling['finetuned']
Green-Sky marked this conversation as resolved.
Show resolved Hide resolved
else:
raise NotImplementedError(f'Unknown rope scaling type: {typ}')

if "max_sequence_length" in config:
n_ctx = config["max_sequence_length"]
Expand All @@ -222,16 +225,19 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")

return Params(
n_vocab = n_vocab,
n_embd = n_embd,
n_layer = n_layer,
n_ctx = n_ctx,
n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps,
f_rope_freq_base = f_rope_freq_base,
f_rope_scale = f_rope_scale,
n_vocab = config["vocab_size"],
n_embd = config["hidden_size"],
n_layer = config["num_hidden_layers"],
n_ctx = n_ctx,
n_ff = config["intermediate_size"],
n_head = (n_head := config["num_attention_heads"]),
n_head_kv = config.get("num_key_value_heads", n_head),
f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = config.get("rope_theta"),
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
rope_scaling_type = rope_scaling_type,
f_rope_scale = f_rope_scale,
n_orig_ctx = n_orig_ctx,
rope_finetuned = rope_finetuned,
)

# LLaMA v2 70B params.json
Expand All @@ -240,17 +246,8 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))

n_vocab = config["vocab_size"] if "vocab_size" in config else -1
n_embd = config["dim"]
n_layer = config["n_layers"]
n_ff = -1
n_head = config["n_heads"]
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
f_norm_eps = config["norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None

# hack to determine LLaMA v1 vs v2 vs CodeLlama
if f_rope_freq_base == 1000000:
if config.get("rope_theta") == 1000000:
# CodeLlama
n_ctx = 16384
elif config["norm_eps"] == 1e-05:
Expand All @@ -260,22 +257,16 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
# LLaMA v1
n_ctx = 2048

if n_vocab == -1:
n_vocab = model["tok_embeddings.weight"].shape[0]

if n_ff == -1:
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]

return Params(
n_vocab = n_vocab,
n_embd = n_embd,
n_layer = n_layer,
n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
n_embd = config["dim"],
n_layer = config["n_layers"],
n_ctx = n_ctx,
n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps,
f_rope_freq_base = f_rope_freq_base,
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
n_head = (n_head := config["n_heads"]),
n_head_kv = config.get("n_kv_heads", n_head),
f_norm_eps = config["norm_eps"],
f_rope_freq_base = config.get("rope_theta"),
)

@staticmethod
Expand Down Expand Up @@ -831,8 +822,16 @@ def add_meta_arch(self, params: Params) -> None:
if params.f_rope_freq_base is not None:
self.gguf.add_rope_freq_base(params.f_rope_freq_base)

if params.f_rope_scale is not None:
self.gguf.add_rope_scale_linear(params.f_rope_scale)
if params.rope_scaling_type:
assert params.f_rope_scale is not None
self.gguf.add_rope_scaling_type(params.rope_scaling_type)
self.gguf.add_rope_scaling_factor(params.f_rope_scale)

if params.n_orig_ctx is not None:
self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)

if params.rope_finetuned is not None:
self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)

if params.ftype is not None:
self.gguf.add_file_type(params.ftype)
Expand Down
5 changes: 3 additions & 2 deletions examples/finetune/finetune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int rope_mode = 0;

return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f
);
};

set_name(tokens_input, "tokens_input");
Expand Down
59 changes: 55 additions & 4 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1755,12 +1755,18 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("options:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
if (llama_mlock_supported())
Expand Down Expand Up @@ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_ctx = std::stoi(argv[i]);
}
else if (arg == "--rope-scaling")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { invalid_param = true; break; }
}
else if (arg == "--rope-freq-base")
{
if (++i >= argc)
Expand All @@ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.rope_freq_scale = std::stof(argv[i]);
}
else if (arg == "--yarn-ext-factor")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_ext_factor = std::stof(argv[i]);
}
else if (arg == "--yarn-attn-factor")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_attn_factor = std::stof(argv[i]);
}
else if (arg == "--yarn-beta-fast")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_fast = std::stof(argv[i]);
}
else if (arg == "--yarn-beta-slow")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
}
else if (arg == "--memory-f32" || arg == "--memory_f32")
{
params.memory_f16 = false;
Expand Down
6 changes: 3 additions & 3 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,9 @@ static struct ggml_tensor * llama_build_train_graphs(
// not capturing these, to silcence warnings
const int rope_mode = 0;

return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
return ggml_rope_custom(
ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};

set_name(tokens_input, "tokens_input");
Expand Down
Loading