From f64dea0821bccb3a4e486bfe7c827231a33466bc Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Thu, 25 Apr 2024 15:55:34 +0900
Subject: [PATCH 01/16] added implementation of DRY sampler

---
 common/sampling.cpp | 17 ++++++++++++++++-
 common/sampling.h   |  4 ++++
 llama.h             | 12 ++++++++++++
 3 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index f2466550168a7..e9fdd10a43932 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
+    // repetition penalties
     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
-
     const bool    penalize_nl    = params.penalize_nl;
 
+    // DRY sampler parameters
+    const float dry_multiplier = params.dry_multiplier;
+    const float dry_base = params.dry_base;
+    const int dry_allowed_length = params.dry_allowed_length;
+
     auto & prev = ctx_sampling->prev;
     auto & cur  = ctx_sampling->cur;
 
@@ -309,10 +314,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (penalty_tokens_used_size) {
         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
 
+        // repetition penalties
         llama_sample_repetition_penalties(ctx_main, &cur_p,
                 penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                 penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
 
+        // DRY penalties (multiplier > 0 means enabled)
+        if(dry_multiplier > 0.0f) {
+            llama_sample_dry(ctx_main, &cur_p,
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
+                params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
+        }
+
+
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
                 if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
diff --git a/common/sampling.h b/common/sampling.h
index cf7081e3674f1..bfc338ef70f1e 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -41,6 +41,9 @@ typedef struct llama_sampling_params {
     float       mirostat_eta      = 0.10f; // learning rate
     bool        penalize_nl       = false; // consider newlines as a repeatable token
     uint32_t    seed              = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    float       dry_multiplier    = 0.0f;  // 0.0f = disabled, recommended value: 0.8f
+    float       dry_base          = 1.75f;
+    int         dry_allowed_length = 2;
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -61,6 +64,7 @@ typedef struct llama_sampling_params {
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
 
     std::vector<llama_token> penalty_prompt_tokens;
+    std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
    bool                     use_penalty_prompt_tokens = false;
 } llama_sampling_params;
diff --git a/llama.h b/llama.h
index 0eb2a1e9ab0a2..0c6b86c16323c 100644
--- a/llama.h
+++ b/llama.h
@@ -924,6 +924,18 @@ extern "C" {
             float   p,
            size_t   min_keep);
 
+    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    LLAMA_API void llama_sample_dry(
+            struct llama_context * ctx,
+            llama_token_data_array * candidates,
+            const llama_token * last_tokens,
+            int last_tokens_size,
+            float dry_base,
+            float dry_multiplier,
+            int dry_allowed_length,
+            const llama_token * seq_breakers,
+            int seq_breakers_size);
+
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
             struct llama_context * ctx,

From aea4ad0296a450195b46d82b286ea075fe8c89f3 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Thu, 25 Apr 2024 15:57:54 +0900
Subject: [PATCH 02/16] fixed editor config check

---
 common/sampling.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index e9fdd10a43932..ad6dba83da48b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -326,7 +326,6 @@ static llama_token_data_array llama_sampling_prepare_impl(
                 penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
                 params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
         }
-
 
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {

From 4d603e3520b8cba69b99a27c9f9b0d77e0e36439 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Thu, 25 Apr 2024 15:58:59 +0900
Subject: [PATCH 03/16] added DRY implementation

---
 llama.cpp | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 84 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index 3a84b4916bd30..bb5aff46f800d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13233,6 +13233,90 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
+    // sanity check
+    GGML_ASSERT(last_tokens_size > 0);
+
+    // get the last token
+    auto last_token = last_tokens[last_tokens_size - 1];
+
+    // if last token is part of the sequence breakers, skip whole sampler
+    if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
+        return;
+    }
+
+    // create an unordered map of "next tokens" <-> max match length
+    std::unordered_map<llama_token, size_t> match_lengths;
+
+    // loop through each previous token (exclude the last token)
+    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
+        // skip if the compare token if it's not the same as the last token
+        if(last_tokens[i] != last_token) {
+            continue;
+        }
+
+        // get the next token (i + 1 is always less than last_tokens_size)
+        auto next_token = last_tokens[i + 1];
+
+        // try to extend the match backwards (match length starts a 1 because last token is already matched)
+        size_t match_length = 1;
+
+        // loop through the previous tokens
+        for(;; match_length++) {
+            // if we have reached the start of our last tokens, break
+            if(i < match_length) break;
+
+            // compare token starts at our prev index, going backwards by match length
+            auto compare_token = last_tokens[i - match_length];
+
+            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+
+            // if compare token is part of the sequence breakers, break out of the match
+            if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
+                break;
+
+            // break out of the match if any tokens don't match
+            if(compare_token != head_token)
+                break;
+        }
+
+        // Check if the next token exists in the map
+        auto it = match_lengths.find(next_token);
+
+        if (it == match_lengths.end()) {
+            // Key does not exist, insert the new value
+            match_lengths[next_token] = match_length;
+        } else {
+            // Key exists, update it with the max of the new value or the existing value
+            it->second = std::max(it->second, match_length);
+        }
+    }
+
+    // apply penalties
+    for (const auto& pair : match_lengths) {
+        auto next_token = pair.first;
+        auto match_length = pair.second;
+
+        // if the match length is greater than our allowed length in config, we apply penalities
+        if(match_length > dry_allowed_length) {
+
+            // find our next token in the candidates->data
+            size_t i = 0;
+            for (; i < candidates->size; ++i) {
+                if (candidates->data[i].id == next_token) {
+                    // calculate the penalty
+                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
+
+                    // apply the dry penalty
+                    candidates->data[i].logit -= penalty;
+                    break;
+                }
+            }
+        }
+    }
+}
+
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;

From 75beda2a843eabdc98882e9035e7d0ff056badb0 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Apr 2024 10:01:50 +0900
Subject: [PATCH 04/16] fixed various issues with sampler pointed out by
 original creator

---
 common/sampling.cpp | 10 ++++-----
 common/sampling.h   |  4 ++--
 llama.cpp           | 50 +++++++++++++++++++++++++--------------------
 3 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index ad6dba83da48b..92cd76e1d4108 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -275,9 +275,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
     const bool    penalize_nl    = params.penalize_nl;
 
     // DRY sampler parameters
-    const float dry_multiplier = params.dry_multiplier;
-    const float dry_base = params.dry_base;
-    const int dry_allowed_length = params.dry_allowed_length;
+    const float    dry_multiplier     = params.dry_multiplier;
+    const float    dry_base           = params.dry_base;
+    const uint32_t dry_allowed_length = params.dry_allowed_length;
 
     auto & prev = ctx_sampling->prev;
     auto & cur  = ctx_sampling->cur;
@@ -320,11 +320,11 @@ static llama_token_data_array llama_sampling_prepare_impl(
                 penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
 
         // DRY penalties (multiplier > 0 means enabled)
-        if(dry_multiplier > 0.0f) {
+        if (dry_multiplier > 0.0f) {
             llama_sample_dry(ctx_main, &cur_p,
                 penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                 penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
-                params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
+                params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
         }
 
         if (!penalize_nl) {
diff --git a/common/sampling.h b/common/sampling.h
index bfc338ef70f1e..09df8edc3b514 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -43,7 +43,7 @@ typedef struct llama_sampling_params {
     uint32_t    seed              = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
     float       dry_multiplier    = 0.0f;  // 0.0f = disabled, recommended value: 0.8f
     float       dry_base          = 1.75f;
-    int         dry_allowed_length = 2;
+    uint32_t    dry_allowed_length = 2;
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -64,7 +64,7 @@ typedef struct llama_sampling_params {
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
 
     std::vector<llama_token> penalty_prompt_tokens;
-    std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
+    std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler
    bool                     use_penalty_prompt_tokens = false;
 } llama_sampling_params;
diff --git a/llama.cpp b/llama.cpp
index bb5aff46f800d..709baee6302b9 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13233,15 +13233,15 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
-void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
-    // sanity check
-    GGML_ASSERT(last_tokens_size > 0);
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) {
+    // skip dry sampler if we don't have a previous token
+    if (last_tokens_size < 1) return;
 
     // get the last token
     auto last_token = last_tokens[last_tokens_size - 1];
 
     // if last token is part of the sequence breakers, skip whole sampler
-    if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
+    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
         return;
     }
 
@@ -13250,21 +13250,26 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
     // loop through each previous token (exclude the last token)
     for (size_t i = 0; i < last_tokens_size - 1; ++i) {
-        // skip if the compare token if it's not the same as the last token
-        if(last_tokens[i] != last_token) {
+        // skip if the compare token is not the same as the last token
+        if (last_tokens[i] != last_token) {
             continue;
         }
 
         // get the next token (i + 1 is always less than last_tokens_size)
         auto next_token = last_tokens[i + 1];
 
-        // try to extend the match backwards (match length starts a 1 because last token is already matched)
+        // if next token is part of the sequence breakers, skip
+        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
+            continue;
+        }
+
+        // try to extend the match backwards (match length starts at 1 because last token is already matched)
         size_t match_length = 1;
 
         // loop through the previous tokens
-        for(;; match_length++) {
+        for (;; match_length++) {
             // if we have reached the start of our last tokens, break
-            if(i < match_length) break;
+            if (i < match_length) break;
 
             // compare token starts at our prev index, going backwards by match length
             auto compare_token = last_tokens[i - match_length];
@@ -13272,13 +13277,15 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
             // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
             auto head_token = last_tokens[last_tokens_size - 1 - match_length];
 
-            // if compare token is part of the sequence breakers, break out of the match
-            if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
+            // break out of the match if any tokens don't match
+            if (compare_token != head_token) {
                 break;
+            }
 
-            // break out of the match if any tokens don't match
-            if(compare_token != head_token)
+            // if compare token is part of the sequence breakers, break out of the match
+            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
                 break;
+            }
         }
 
         // Check if the next token exists in the map
@@ -13298,12 +13305,11 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
         auto next_token = pair.first;
         auto match_length = pair.second;
 
-        // if the match length is greater than our allowed length in config, we apply penalities
-        if(match_length > dry_allowed_length) {
+        // if the match length is greater than or equal to our allowed length in config, we apply penalities
+        if (match_length >= dry_allowed_length) {
 
             // find our next token in the candidates->data
-            size_t i = 0;
-            for (; i < candidates->size; ++i) {
+            for (size_t i = 0; i < candidates->size; ++i) {
                 if (candidates->data[i].id == next_token) {
                     // calculate the penalty
                     float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
@@ -13444,7 +13450,7 @@ void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * c
     const int64_t t_start_sample_us = ggml_time_us();
 
     // no need to do anything if there is only one (or zero) candidates
-    if(candidates_p->size <= 1) {
+    if (candidates_p->size <= 1) {
         return;
     }
 
@@ -13678,7 +13684,7 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
     t_start_sample_us = ggml_time_us();
 
     // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+    size_t X_idx = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
         return candidate.id == X;
     }));
     float observed_surprise = -log2f(candidates->data[X_idx].p);
@@ -13700,7 +13706,7 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
     llama_sample_softmax(ctx, candidates);
 
     // Truncate the words with surprise values greater than mu
-    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+    candidates->size = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
         return -log2f(candidate.p) > *mu;
     }));
 
@@ -13720,7 +13726,7 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
     t_start_sample_us = ggml_time_us();
 
     // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+    size_t X_idx = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
         return candidate.id == X;
     }));
     float observed_surprise = -log2f(candidates->data[X_idx].p);
@@ -15770,7 +15776,7 @@ uint64_t llama_model_n_params(const struct llama_model * model) {
 }
 
 struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
-    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
+    auto it = std::find_if (model->tensors_by_name.begin(), model->tensors_by_name.end(),
             [name](const std::pair<std::string, struct ggml_tensor *> & it) {
                 return it.first == name;
             });

From 85dadac483df1ba5071362f8284599fb6172047b Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Apr 2024 10:20:17 +0900
Subject: [PATCH 05/16] added parameter for DRY penalty range, separate from
 the original repetition penalty range

---
 common/sampling.cpp | 51 ++++++++++++++++++++++++++-------------------
 common/sampling.h   |  3 ++-
 2 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 92cd76e1d4108..f400aa7fb136d 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -278,6 +278,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     const float    dry_multiplier     = params.dry_multiplier;
     const float    dry_base           = params.dry_base;
     const uint32_t dry_allowed_length = params.dry_allowed_length;
+    const uint32_t dry_penalty_last_n = params.dry_penalty_last_n;
 
     auto & prev = ctx_sampling->prev;
     auto & cur  = ctx_sampling->cur;
@@ -308,35 +309,41 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
-    // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
 
-        // repetition penalties
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        // DRY penalties (multiplier > 0 means enabled)
-        if (dry_multiplier > 0.0f) {
-            llama_sample_dry(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
-                params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
-        }
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
+    // apply repetition penalties
+    {
+        const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+        if (penalty_tokens_used_size) {
+            const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+            // repetition penalties
+            llama_sample_repetition_penalties(ctx_main, &cur_p,
+                    penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                    penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+            if (!penalize_nl) {
+                for (size_t idx = 0; idx < cur_p.size; idx++) {
+                    if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+                        cur_p.data[idx].logit = nl_logit;
+                        break;
+                    }
                 }
             }
         }
     }
 
+    // apply DRY penalties
+    {
+        const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
+        if (penalty_tokens_used_size) {
+            llama_sample_dry(ctx_main, &cur_p,
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
+                params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
+        }
+    }
+
     // apply grammar checks before sampling logic
     if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
diff --git a/common/sampling.h b/common/sampling.h
index 09df8edc3b514..4ad726c89aa5a 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -41,9 +41,10 @@ typedef struct llama_sampling_params {
     float       mirostat_eta      = 0.10f; // learning rate
     bool        penalize_nl       = false; // consider newlines as a repeatable token
     uint32_t    seed              = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
-    float       dry_multiplier    = 0.0f;  // 0.0f = disabled, recommended value: 0.8f
+    float       dry_multiplier    = 0.0f; // 0.0f = disabled, recommended value: 0.8f
     float       dry_base          = 1.75f;
     uint32_t    dry_allowed_length = 2;
+    uint32_t    dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,

From 793e1e221b176b6cc9f4c96695ccf64225b7ea50 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Apr 2024 10:22:58 +0900
Subject: [PATCH 06/16] updated header def for dry sampler to match
 implementation

---
 llama.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llama.h b/llama.h
index 0c6b86c16323c..cb58daea096e4 100644
--- a/llama.h
+++ b/llama.h
@@ -933,8 +933,8 @@ extern "C" {
             float dry_base,
             float dry_multiplier,
             int dry_allowed_length,
-            const llama_token * seq_breakers,
-            int seq_breakers_size);
+            const llama_token * dry_seq_breakers,
+            int dry_seq_breakers_size);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(

From 3caec6bb41f9a5b40e432a44fd2acfd55fa01a85 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Apr 2024 10:25:25 +0900
Subject: [PATCH 07/16] removed unused llama_context in dry sampler

---
 common/sampling.cpp | 2 +-
 llama.cpp           | 2 +-
 llama.h             | 1 -
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index f400aa7fb136d..a197011de2c3c 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -337,7 +337,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     {
         const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
         if (penalty_tokens_used_size) {
-            llama_sample_dry(ctx_main, &cur_p,
+            llama_sample_dry(&cur_p,
                 penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                 penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
                 params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
diff --git a/llama.cpp b/llama.cpp
index 709baee6302b9..a4cb1f284a693 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13233,7 +13233,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
-void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) {
+void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) {
     // skip dry sampler if we don't have a previous token
     if (last_tokens_size < 1) return;
 
diff --git a/llama.h b/llama.h
index cb58daea096e4..774c1b222d43e 100644
--- a/llama.h
+++ b/llama.h
@@ -926,7 +926,6 @@ extern "C" {
 
     /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
     LLAMA_API void llama_sample_dry(
-            struct llama_context * ctx,
             llama_token_data_array * candidates,
             const llama_token * last_tokens,
             int last_tokens_size,

From 49e078f79d38da31cdef7f8c4eb80ce25ade3ecf Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Apr 2024 10:58:26 +0900
Subject: [PATCH 08/16] changed array size parameters to size_t

---
 llama.cpp | 2 +-
 llama.h   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index a4cb1f284a693..d9e87f34f6cf1 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13233,7 +13233,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
-void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) {
+void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
     // skip dry sampler if we don't have a previous token
     if (last_tokens_size < 1) return;
 
diff --git a/llama.h b/llama.h
index 774c1b222d43e..fbba5daf8f729 100644
--- a/llama.h
+++ b/llama.h
@@ -928,12 +928,12 @@ extern "C" {
     LLAMA_API void llama_sample_dry(
             llama_token_data_array * candidates,
             const llama_token * last_tokens,
-            int last_tokens_size,
+            size_t last_tokens_size,
             float dry_base,
             float dry_multiplier,
             int dry_allowed_length,
             const llama_token * dry_seq_breakers,
-            int dry_seq_breakers_size);
+            size_t dry_seq_breakers_size);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(

From 802ddd78bf9cbc12336f95ae6e9af321974462d3 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Jul 2024 19:41:47 +0900
Subject: [PATCH 09/16] added sample_dry_impl

---
 src/llama-sampling.cpp | 90 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 90 insertions(+)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 8910f6d6542e9..d41218c702360 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -232,6 +232,96 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
     }
 }
 
+void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
+    // skip dry sampler if we don't have a previous token
+    if (last_tokens_size < 1) return;
+
+    // get the last token
+    auto last_token = last_tokens[last_tokens_size - 1];
+
+    // if last token is part of the sequence breakers, skip whole sampler
+    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
+        return;
+    }
+
+    // create an unordered map of "next tokens" <-> max match length
+    std::unordered_map<llama_token, size_t> match_lengths;
+
+    // loop through each previous token (exclude the last token)
+    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
+        // skip if the compare token is not the same as the last token
+        if (last_tokens[i] != last_token) {
+            continue;
+        }
+
+        // get the next token (i + 1 is always less than last_tokens_size)
+        auto next_token = last_tokens[i + 1];
+
+        // if next token is part of the sequence breakers, skip
+        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
+            continue;
+        }
+
+        // try to extend the match backwards (match length starts at 1 because last token is already matched)
+        size_t match_length = 1;
+
+        // loop through the previous tokens
+        for (;; match_length++) {
+            // if we have reached the start of our last tokens, break
+            if (i < match_length) break;
+
+            // compare token starts at our prev index, going backwards by match length
+            auto compare_token = last_tokens[i - match_length];
+
+            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+
+            // break out of the match if any tokens don't match
+            if (compare_token != head_token) {
+                break;
+            }
+
+            // if compare token is part of the sequence breakers, break out of the match
+            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
+                break;
+            }
+        }
+
+        // Check if the next token exists in the map
+        auto it = match_lengths.find(next_token);
+
+        if (it == match_lengths.end()) {
+            // Key does not exist, insert the new value
+            match_lengths[next_token] = match_length;
+        } else {
+            // Key exists, update it with the max of the new value or the existing value
+            it->second = std::max(it->second, match_length);
+        }
+    }
+
+    // apply penalties
+    for (const auto& pair : match_lengths) {
+        auto next_token = pair.first;
+        auto match_length = pair.second;
+
+        // if the match length is greater than or equal to our allowed length in config, we apply penalities
+        if (match_length >= dry_allowed_length) {
+
+            // find our next token in the candidates->data
+            for (size_t i = 0; i < candidates->size; ++i) {
+                if (candidates->data[i].id == next_token) {
+                    // calculate the penalty
+                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
+
+                    // apply the dry penalty
+                    candidates->data[i].logit -= penalty;
+                    break;
+                }
+            }
+        }
+    }
+}
+
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;

From 12bfa7820caf31809a2dfa6509a90acc214ad67c Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Jul 2024 19:44:23 +0900
Subject: [PATCH 10/16] added llama_sample_dry_impl in header

---
 src/llama-sampling.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index f7f8e3ef706bc..578c472438709 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -28,6 +28,7 @@ void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_
 void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
 void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_dry_impl      (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size);
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
 void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);

From 0229fc82558bbcc7211c319a0f327402a97d4349 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Jul 2024 20:12:46 +0900
Subject: [PATCH 11/16] added final new line for editor config check

---
 src/llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 4ffcd9d6af85e..86b5e4c052413 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -19161,4 +19161,4 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
     (void) user_data;
     fputs(text, stderr);
     fflush(stderr);
-}
\ No newline at end of file
+}

From 236da599d4e9c010265caf9e30cf6b904e6f88fe Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Jul 2024 20:25:56 +0900
Subject: [PATCH 12/16] fixed int/size_t comparison

---
 src/llama-sampling.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index d41218c702360..375717accbd5d 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -305,7 +305,7 @@ void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_toke
         auto match_length = pair.second;
 
         // if the match length is greater than or equal to our allowed length in config, we apply penalities
-        if (match_length >= dry_allowed_length) {
+        if (match_length >= (size_t)dry_allowed_length) {
 
             // find our next token in the candidates->data
             for (size_t i = 0; i < candidates->size; ++i) {

From e862defaa99e1d25b8085146c10cee5caada5aca Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 29 Jul 2024 20:53:42 +0900
Subject: [PATCH 13/16] use int32_t for dry_penalty_last_n due to negative
 value needed as config

---
 common/sampling.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/sampling.h b/common/sampling.h
index c38d921bf307a..80c2568cf2b41 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -44,7 +44,7 @@ typedef struct llama_sampling_params {
     float       dry_multiplier    = 0.0f; // 0.0f = disabled, recommended value: 0.8f
     float       dry_base          = 1.75f;
     uint32_t    dry_allowed_length = 2;
-    uint32_t    dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
+    int32_t     dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,

From 9105cf435bb0badbf60f7ea8b3437151e2404f9c Mon Sep 17 00:00:00 2001
From: wwoodsTM
Date: Mon, 5 Aug 2024 00:03:38 -0600
Subject: [PATCH 14/16] Add DRY sampling parameters to gpt_params and
 server_context

---
 common/common.cpp          | 25 ++++++++++++
 examples/server/server.cpp | 78 ++++++++++++++++++++++++++------------
 pr-6839.diff               |  0
 3 files changed, 79 insertions(+), 24 deletions(-)
 create mode 100644 pr-6839.diff

diff --git a/common/common.cpp b/common/common.cpp
index 60c7eac75c613..4cf094b7516cb 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -555,6 +555,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.penalty_present = std::stof(argv[i]);
         return true;
     }
+    if (arg == "--dry-multiplier") {
+        CHECK_ARG
+        sparams.dry_multiplier = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-base") {
+        CHECK_ARG
+        sparams.dry_base = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-allowed-length") {
+        CHECK_ARG
+        sparams.dry_allowed_length = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-penalty-last-n") {
+        CHECK_ARG
+        sparams.dry_penalty_last_n = std::stoi(argv[i]);
+        return true;
+    }
     if (arg == "--dynatemp-range") {
         CHECK_ARG
         sparams.dynatemp_range = std::stof(argv[i]);
@@ -1471,6 +1491,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",           "       --repeat-penalty N",     "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
     options.push_back({ "*",           "       --presence-penalty N",   "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
     options.push_back({ "*",           "       --frequency-penalty N",  "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq });
+    options.push_back({ "*",           "       --dry-multiplier N",     "DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)sparams.dry_multiplier });
+    options.push_back({ "*",           "       --dry-base N",           "DRY sampling base (default: %.1f)", (double)sparams.dry_base });
+    options.push_back({ "*",           "       --dry-allowed-length N", "DRY sampling allowed length (default: %d)", sparams.dry_allowed_length });
+    options.push_back({ "*",           "       --dry-penalty-last-n N", "DRY sampling penalty last n tokens (-1 = context size, default: %d)", sparams.dry_penalty_last_n });
+
     options.push_back({ "*",           "       --dynatemp-range N",     "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range });
     options.push_back({ "*",           "       --dynatemp-exp N",       "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent });
     options.push_back({ "*",           "       --mirostat N",           "use Mirostat sampling.\n"
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7813a2957d6bc..4b2654db9b767 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -898,30 +898,55 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
-        slot.params.stream           = json_value(data, "stream",            false);
-        slot.params.cache_prompt     = json_value(data, "cache_prompt",      false);
-        slot.params.n_predict        = json_value(data, "n_predict",         default_params.n_predict);
-        slot.sparams.top_k           = json_value(data, "top_k",             default_sparams.top_k);
-        slot.sparams.top_p           = json_value(data, "top_p",             default_sparams.top_p);
-        slot.sparams.min_p           = json_value(data, "min_p",             default_sparams.min_p);
-        slot.sparams.tfs_z           = json_value(data, "tfs_z",             default_sparams.tfs_z);
-        slot.sparams.typical_p       = json_value(data, "typical_p",         default_sparams.typical_p);
-        slot.sparams.temp            = json_value(data, "temperature",       default_sparams.temp);
-        slot.sparams.dynatemp_range  = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
-        slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
-        slot.sparams.penalty_last_n  = json_value(data, "repeat_last_n",     default_sparams.penalty_last_n);
-        slot.sparams.penalty_repeat  = json_value(data, "repeat_penalty",    default_sparams.penalty_repeat);
-        slot.sparams.penalty_freq    = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
-        slot.sparams.penalty_present = json_value(data, "presence_penalty",  default_sparams.penalty_present);
-        slot.sparams.mirostat        = json_value(data, "mirostat",          default_sparams.mirostat);
-        slot.sparams.mirostat_tau    = json_value(data, "mirostat_tau",      default_sparams.mirostat_tau);
-        slot.sparams.mirostat_eta    = json_value(data, "mirostat_eta",      default_sparams.mirostat_eta);
-        slot.sparams.penalize_nl     = json_value(data, "penalize_nl",       default_sparams.penalize_nl);
-        slot.params.n_keep           = json_value(data, "n_keep",            slot.params.n_keep);
-        slot.params.n_discard        = json_value(data, "n_discard",         default_params.n_discard);
-        slot.sparams.seed            = json_value(data, "seed",              default_sparams.seed);
-        slot.sparams.n_probs         = json_value(data, "n_probs",           default_sparams.n_probs);
-        slot.sparams.min_keep        = json_value(data, "min_keep",          default_sparams.min_keep);
+        slot.params.stream             = json_value(data, "stream",             false);
+        slot.params.cache_prompt       = json_value(data, "cache_prompt",       false);
+        slot.params.n_predict          = json_value(data, "n_predict",          default_params.n_predict);
+        slot.sparams.top_k             = json_value(data, "top_k",              default_sparams.top_k);
+        slot.sparams.top_p             = json_value(data, "top_p",              default_sparams.top_p);
+        slot.sparams.min_p             = json_value(data, "min_p",              default_sparams.min_p);
+        slot.sparams.tfs_z             = json_value(data, "tfs_z",              default_sparams.tfs_z);
+        slot.sparams.typical_p         = json_value(data, "typical_p",          default_sparams.typical_p);
+        slot.sparams.temp              = json_value(data, "temperature",        default_sparams.temp);
+        slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",     default_sparams.dynatemp_range);
+        slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent",  default_sparams.dynatemp_exponent);
+        slot.sparams.penalty_last_n    = json_value(data, "repeat_last_n",      default_sparams.penalty_last_n);
+        slot.sparams.penalty_repeat    = json_value(data, "repeat_penalty",     default_sparams.penalty_repeat);
+        slot.sparams.penalty_freq      = json_value(data, "frequency_penalty",  default_sparams.penalty_freq);
+        slot.sparams.penalty_present   = json_value(data, "presence_penalty",   default_sparams.penalty_present);
+        slot.sparams.dry_multiplier    = json_value(data, "dry_multiplier",     default_sparams.dry_multiplier);
+        slot.sparams.dry_base          = json_value(data, "dry_base",           default_sparams.dry_base);
+        slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+        slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
+        slot.sparams.mirostat          = json_value(data, "mirostat",           default_sparams.mirostat);
+        slot.sparams.mirostat_tau      = json_value(data, "mirostat_tau",       default_sparams.mirostat_tau);
+        slot.sparams.mirostat_eta      = json_value(data, "mirostat_eta",       default_sparams.mirostat_eta);
+        slot.sparams.penalize_nl       = json_value(data, "penalize_nl",        default_sparams.penalize_nl);
+        slot.params.n_keep             = json_value(data, "n_keep",             slot.params.n_keep);
+        slot.params.n_discard          = json_value(data, "n_discard",          default_params.n_discard);
+        slot.sparams.seed              = json_value(data, "seed",               default_sparams.seed);
+        slot.sparams.n_probs           = json_value(data, "n_probs",            default_sparams.n_probs);
+        slot.sparams.min_keep          = json_value(data, "min_keep",           default_sparams.min_keep);
+
+        // sequence breakers for DRY
+        {
+            auto dry_seq_breakers = data.find("dry_seq_breakers");
+            if (dry_seq_breakers != data.end()) {
+                try {
+                    if (dry_seq_breakers->is_array()) {
+                        slot.sparams.dry_seq_breakers = dry_seq_breakers->get<std::vector<std::string>>();
+                    } else if (dry_seq_breakers->is_string()) {
+                        slot.sparams.dry_seq_breakers = json::parse(dry_seq_breakers->get<std::string>()).get<std::vector<std::string>>();
+                    } else {
+                        send_error(task, "\"dry_seq_breakers\": Expected an array of strings or a JSON-encoded array of strings.", ERROR_TYPE_INVALID_REQUEST);
+                        return false;
+                    }
+                } catch (const std::exception & e) {
+                    send_error(task, std::string("\"dry_seq_breakers\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
+                    return false;
+                }
+            }
+        }
+
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@@ -1339,6 +1364,11 @@ struct server_context {
             {"frequency_penalty",         slot.sparams.penalty_freq},
             {"penalty_prompt_tokens",     slot.sparams.penalty_prompt_tokens},
             {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
+            {"dry_multiplier",            slot.sparams.dry_multiplier},
+            {"dry_base",                  slot.sparams.dry_base},
+            {"dry_allowed_length",        slot.sparams.dry_allowed_length},
+            {"dry_penalty_last_n",        slot.sparams.dry_penalty_last_n},
+            {"dry_seq_breakers",          slot.sparams.dry_seq_breakers},
             {"mirostat",                  slot.sparams.mirostat},
             {"mirostat_tau",              slot.sparams.mirostat_tau},
             {"mirostat_eta",              slot.sparams.mirostat_eta},
diff --git a/pr-6839.diff b/pr-6839.diff
new file mode 100644
index 0000000000000..e69de29bb2d1d

From 20dc562f45434b105b3d167830e927c054600e41 Mon Sep 17 00:00:00 2001
From: wwoodsTM <104587230+wwoodsTM@users.noreply.github.com>
Date: Mon, 5 Aug 2024 00:41:26 -0600
Subject: [PATCH 15/16] Delete pr-6839.diff

---
 pr-6839.diff | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 pr-6839.diff

diff --git a/pr-6839.diff b/pr-6839.diff
deleted file mode 100644
index e69de29bb2d1d..0000000000000

From 6579e64f26d72f8e8f1009f10d8b0a791707e8d4 Mon Sep 17 00:00:00 2001
From: wwoodsTM
Date: Tue, 6 Aug 2024 02:54:57 -0600
Subject: [PATCH 16/16] Attempt at slightly optimized vector of strings DRY
 implementation

---
 common/sampling.cpp    |   4 +-
 common/sampling.h      |   5 +-
 include/llama.h        |  33 +++++--
 src/llama-sampling.cpp | 210 +++++++++++++++++++++++++++++++++--------
 src/llama-sampling.h   |   6 +-
 src/llama.cpp          |   4 +-
 6 files changed, 208 insertions(+), 54 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index c626ca03c11e1..bb7de3769d145 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -433,10 +433,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
     {
         const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
         if (penalty_tokens_used_size) {
-            llama_sample_dry(&cur_p,
+            llama_sample_dry(ctx_main, &cur_p,
                 penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                 penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
-                params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
+                params.dry_seq_breakers);
         }
     }
 
diff --git a/common/sampling.h b/common/sampling.h
index 80c2568cf2b41..1f864c4764764 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -46,6 +46,8 @@ typedef struct llama_sampling_params {
     uint32_t    dry_allowed_length = 2;
     int32_t     dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
 
+    std::vector<std::string> dry_seq_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
+
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
         llama_sampler_type::TFS_Z,
@@ -63,9 +65,8 @@ typedef struct llama_sampling_params {
     float cfg_scale = 1.f; // how strong is guidance
 
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
     std::vector<llama_token> penalty_prompt_tokens;
-    std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler
+
     bool use_penalty_prompt_tokens = false;
 } llama_sampling_params;
diff --git a/include/llama.h b/include/llama.h
index 51ed8d9ee2402..81805e5c2ae7e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1085,16 +1085,17 @@ extern "C" {
             float   p,
            size_t   min_keep);
 
-    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
-    LLAMA_API void llama_sample_dry(
-            llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-            size_t last_tokens_size,
-            float dry_base,
-            float dry_multiplier,
-            int dry_allowed_length,
-            const llama_token * dry_seq_breakers,
-            size_t dry_seq_breakers_size);
+    // /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    // LLAMA_API void llama_sample_dry(
+    //         struct llama_context * ctx,
+    //         llama_token_data_array * candidates,
+    //         const llama_token * last_tokens,
+    //         size_t last_tokens_size,
+    //         float dry_base,
+    //         float dry_multiplier,
+    //         int dry_allowed_length,
+    //         const std::vector<std::string>
+    //         & dry_seq_breakers);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
@@ -1246,6 +1247,18 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
 llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
 
+/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+LLAMA_API void llama_sample_dry(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        const llama_token * last_tokens,
+        size_t last_tokens_size,
+        float dry_base,
+        float dry_multiplier,
+        int dry_allowed_length,
+        const std::vector<std::string>
+        & dry_seq_breakers);
+
 #endif // LLAMA_API_INTERNAL
 
 #endif // LLAMA_H
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 375717accbd5d..22e623fac7656 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -232,94 +232,230 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
     }
 }
 
-void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
-    // skip dry sampler if we don't have a previous token
-    if (last_tokens_size < 1) return;
+std::vector<llama_token> llama_tokenize(
+    const struct llama_context * ctx,
+    const std::string & text,
+    bool add_special,
+    bool parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+}
+
+std::vector<llama_token> llama_tokenize(
+    const struct llama_model * model,
+    const std::string & text,
+    bool add_special,
+    bool parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
+
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
+
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special) {
+    std::vector<llama_token> tokens = {token};
+    return llama_detokenize(ctx, tokens, special);
+}
 
-    // get the last token
-    auto last_token = last_tokens[last_tokens_size - 1];
+// Constants for preventing overflow
+const float FLOAT_MAX_LOG = 88.7228391f;
+const int MAX_CHAR_LEN = 40;
+const int MAX_SEQ_LEN = 20;
 
-    // if last token is part of the sequence breakers, skip whole sampler
-    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
+
+void llama_sample_dry_impl(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    if (last_tokens_size < 1) {
         return;
     }
 
-    // create an unordered map of "next tokens" <-> max match length
+    // Cache for token-to-string conversions
+    std::unordered_map<llama_token, std::string> token_to_string_cache;
+    // Store sequence breakers for more efficient lookup
+    std::unordered_multimap<std::string, std::vector<std::string>> restart_sequences;
+
+    auto detokenize_with_cache = [&](llama_token token) -> std::string {
+        auto it = token_to_string_cache.find(token);
+        if (it != token_to_string_cache.end()) {
+            return it->second;
+        }
+        std::string token_str = llama_detokenize_single(ctx, token, false);
+        token_to_string_cache[token] = token_str;
+        return token_str;
+    };
+
+    // Pre-process dry_seq_breakers
+    for (const auto& breaker : dry_seq_breakers) {
+        std::string breaker_trimmed = breaker.substr(0, MAX_CHAR_LEN);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, breaker_trimmed, false, false);
+
+        if (!tokens.empty()) {
+            std::string head = detokenize_with_cache(tokens[0]);
+            std::vector<std::string> tail;
+
+            for (size_t i = 1; i < tokens.size() && i <= MAX_SEQ_LEN; ++i) {
+                tail.push_back(detokenize_with_cache(tokens[i]));
+            }
+            restart_sequences.emplace(head, tail);
+        }
+    }
+
+    // Find max repetition length considering restart sequences
+    int rep_limit = last_tokens_size;
+
+    for (size_t i = 0; i < last_tokens_size; ++i) {
+        size_t ix = last_tokens_size - 1 - i;
+        std::string token_str = detokenize_with_cache(last_tokens[ix]);
+
+        // Check if the token is a potential sequence breaker
+        auto its = restart_sequences.equal_range(token_str);
+        if (its.first == restart_sequences.end()) continue;
+
+        int longest_match = -1;
+        // Check all potential sequence breakers starting with this token
+        for (auto it = its.first; it != its.second; ++it) {
+            int seq_len = (int)it->second.size();
+            if (seq_len > longest_match && seq_len <= i) {
+                bool match = true;
+                // Check if the following tokens match the sequence breaker
+                for (size_t offset = 0; offset < seq_len; ++offset) {
+                    if (it->second[offset] != detokenize_with_cache(last_tokens[ix + 1 + offset])) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    longest_match = seq_len;
+                }
+            }
+        }
+
+        if (longest_match >= 0) {
+            rep_limit = static_cast<int>(i) - longest_match;
+            break;
+        }
+    }
+
+    if (rep_limit <= dry_allowed_length) {
+        return;
+    }
+
+    // Store max match length for each token
     std::unordered_map<llama_token, size_t> match_lengths;
 
-    // loop through each previous token (exclude the last token)
+    // Find repeated sequences
     for (size_t i = 0; i < last_tokens_size - 1; ++i) {
-        // skip if the compare token is not the same as the last token
-        if (last_tokens[i] != last_token) {
+        if (last_tokens[i] != last_tokens[last_tokens_size - 1]) {
             continue;
         }
 
-        // get the next token (i + 1 is always less than last_tokens_size)
         auto next_token = last_tokens[i + 1];
+        std::string next_token_str = detokenize_with_cache(next_token);
 
-        // if next token is part of the sequence breakers, skip
-        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
+        // Skip if next token is a sequence breaker
+        auto its = restart_sequences.equal_range(next_token_str);
+        if (its.first != restart_sequences.end()) {
             continue;
         }
 
-        // try to extend the match backwards (match length starts at 1 because last token is already matched)
         size_t match_length = 1;
 
-        // loop through the previous tokens
+        // Extend match as far as possible
         for (;; match_length++) {
-            // if we have reached the start of our last tokens, break
-            if (i < match_length) break;
+            if (i < match_length || match_length > rep_limit) {
+                break;
+            }
 
-            // compare token starts at our prev index, going backwards by match length
             auto compare_token = last_tokens[i - match_length];
+            std::string compare_token_str = detokenize_with_cache(compare_token);
 
-            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
             auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+            std::string head_token_str = detokenize_with_cache(head_token);
 
-            // break out of the match if any tokens don't match
-            if (compare_token != head_token) {
+            if (compare_token_str != head_token_str) {
                 break;
             }
 
-            // if compare token is part of the sequence breakers, break out of the match
-            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
+            // Check if we've hit a sequence breaker
+            its = restart_sequences.equal_range(compare_token_str);
+            if (its.first != restart_sequences.end()) {
                 break;
             }
         }
 
-        // Check if the next token exists in the map
+        // Update max match length for this token
         auto it = match_lengths.find(next_token);
-
         if (it == match_lengths.end()) {
-            // Key does not exist, insert the new value
             match_lengths[next_token] = match_length;
         } else {
-            // Key exists, update it with the max of the new value or the existing value
             it->second = std::max(it->second, match_length);
         }
     }
 
-    // apply penalties
+    // Calculate max safe exponent
+    int max_exponent = 0;
+    if (dry_base > 1.000001f) {
+        max_exponent = static_cast<int>(FLOAT_MAX_LOG / log(dry_base));
+    }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("DRY Sampling parameters:\n");
+    LLAMA_LOG_INFO("  dry_base: %f\n", dry_base);
+    LLAMA_LOG_INFO("  dry_multiplier: %f\n", dry_multiplier);
+    LLAMA_LOG_INFO("  dry_allowed_length: %d\n", dry_allowed_length);
+    LLAMA_LOG_INFO("  max_exponent: %d\n", max_exponent);
+    LLAMA_LOG_INFO("DRY penalties [");
+#endif
+
+    // Apply penalties
     for (const auto& pair : match_lengths) {
         auto next_token = pair.first;
         auto match_length = pair.second;
 
-        // if the match length is greater than or equal to our allowed length in config, we apply penalities
-        if (match_length >= (size_t)dry_allowed_length) {
-
-            // find our next token in the candidates->data
+        if (match_length >= static_cast<size_t>(dry_allowed_length)) {
             for (size_t i = 0; i < candidates->size; ++i) {
                 if (candidates->data[i].id == next_token) {
-                    // calculate the penalty
-                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
-
-                    // apply the dry penalty
+                    int repeat_exp = static_cast<int>(match_length - dry_allowed_length);
+                    if (max_exponent > 0 && repeat_exp > max_exponent) {
+                        repeat_exp = max_exponent;
+                    }
+                    float penalty = dry_multiplier * pow(dry_base, static_cast<float>(repeat_exp));
                     candidates->data[i].logit -= penalty;
+
+#ifdef DEBUG
+                    LLAMA_LOG_INFO("  Token %d: %s (Penalty: %.2f)", next_token, detokenize_with_cache(next_token).c_str(), penalty);
+#endif
                     break;
                 }
             }
         }
     }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("]\n");
+#endif
 }
 
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 578c472438709..48cdc086c820c 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -28,7 +28,11 @@ void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_
 void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
 void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_dry_impl      (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size);
+std::vector<llama_token> llama_tokenize(const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special);
+std::vector<llama_token> llama_tokenize(const struct llama_model * model, const std::string & text, bool add_special, bool parse_special);
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special);
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special);
+void llama_sample_dry_impl      (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers);
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
 void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
diff --git a/src/llama.cpp b/src/llama.cpp
index 86b5e4c052413..fc2009b7501c6 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18935,8 +18935,8 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
 }
 
-void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
-    llama_sample_dry_impl(candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers, dry_seq_breakers_size);
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    llama_sample_dry_impl(ctx, candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers);
 }
 
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
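
Appendix (illustration, not part of the patch series). The token-ID algorithm that patches 03-12 build up can be exercised in isolation. The sketch below restates its two phases with illustrative names (`Candidate`, `apply_dry` are not llama.cpp API): for every position whose token equals the final context token, the match is extended backwards against the context suffix, the longest match per "next token" is recorded, and then dry_multiplier * dry_base^(match_length - dry_allowed_length) is subtracted from the logit of each next token whose match length reaches dry_allowed_length — the same shape as llama_sample_dry_impl above (using the >= comparison introduced in patch 04). Requires C++17.

    // Minimal sketch of the token-ID-based DRY penalty; illustrative, not PR code.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    struct Candidate { int id; float logit; };

    static bool is_breaker(int tok, const std::vector<int> & breakers) {
        return std::find(breakers.begin(), breakers.end(), tok) != breakers.end();
    }

    static void apply_dry(std::vector<Candidate> & cand,
                          const std::vector<int> & last,
                          float base, float multiplier, size_t allowed_length,
                          const std::vector<int> & breakers) {
        if (last.size() < 2 || multiplier <= 0.0f) return;
        const int last_token = last.back();
        if (is_breaker(last_token, breakers)) return; // breaker ends all matching

        std::unordered_map<int, size_t> match_lengths; // next token -> longest match
        for (size_t i = 0; i + 1 < last.size(); ++i) {
            if (last[i] != last_token) continue;
            const int next_token = last[i + 1];
            if (is_breaker(next_token, breakers)) continue;

            // extend the match backwards from position i and from the tail
            size_t len = 1;
            for (;; ++len) {
                if (i < len) break;
                const int compare_token = last[i - len];
                const int head_token    = last[last.size() - 1 - len];
                if (compare_token != head_token) break;
                if (is_breaker(compare_token, breakers)) break;
            }
            auto it = match_lengths.find(next_token);
            if (it == match_lengths.end()) match_lengths[next_token] = len;
            else it->second = std::max(it->second, len);
        }

        for (const auto & [next_token, len] : match_lengths) {
            if (len < allowed_length) continue; // >= allowed_length gets penalized
            for (auto & c : cand) {
                if (c.id == next_token) {
                    c.logit -= multiplier * std::pow(base, float(len - allowed_length));
                    break;
                }
            }
        }
    }

    int main() {
        // context "1 2 3 1 2 3 1 2": token 3 would extend the repeated trigram
        std::vector<int> last = {1, 2, 3, 1, 2, 3, 1, 2};
        std::vector<Candidate> cand = {{1, 0.f}, {2, 0.f}, {3, 0.f}, {4, 0.f}};
        apply_dry(cand, last, 1.75f, 0.8f, 2, /*breakers=*/{});
        for (const auto & c : cand) std::printf("token %d logit %.3f\n", c.id, c.logit);
    }

In the demo, token 3 has a suffix match of length 5 and receives a penalty of 0.8 * 1.75^3, about 4.29 logits; breakers short-circuit the matching exactly as the seq_breakers checks do in the patch.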
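
A second illustration, isolating the FLOAT_MAX_LOG / max_exponent guard from patch 16. Because the penalty grows as dry_base^(match_length - dry_allowed_length), a long enough match would overflow a 32-bit float, so the patch caps the exponent before computing the power. This standalone sketch copies the constant from the patch; everything else is illustrative:

    // Sketch of the exponent clamp from PATCH 16; not PR code.
    #include <cmath>
    #include <cstdio>

    int main() {
        const float FLOAT_MAX_LOG = 88.7228391f; // ~log(FLT_MAX), as in the patch
        const float dry_base = 1.75f;
        const float dry_multiplier = 0.8f;
        const int   dry_allowed_length = 2;

        int max_exponent = 0;
        if (dry_base > 1.000001f) {
            max_exponent = (int)(FLOAT_MAX_LOG / std::log(dry_base)); // ~158 for base 1.75
        }

        for (int match_length : {3, 10, 200}) {
            int repeat_exp = match_length - dry_allowed_length;
            if (max_exponent > 0 && repeat_exp > max_exponent) {
                repeat_exp = max_exponent; // penalty saturates instead of overflowing to inf
            }
            float penalty = dry_multiplier * std::pow(dry_base, (float)repeat_exp);
            std::printf("match %3d -> penalty %g\n", match_length, penalty);
        }
    }

log(FLT_MAX) is about 88.72, so for dry_base = 1.75 the exponent saturates near 158 and the penalty tops out at a large finite value rather than becoming infinite; the dry_base > 1.000001f check also avoids a division by a near-zero logarithm.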
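
The other substantive idea in patch 16 is string-level restart sequences: each breaker string is tokenized once into a head piece plus tail pieces, stored in an unordered_multimap, and the recent context is scanned backwards so that the most recent complete breaker occurrence caps how far any repetition match may extend (rep_limit). The toy sketch below mimics that scan with a hypothetical five-entry vocabulary standing in for the real detokenizer; names and data are illustrative:

    // Toy sketch of PATCH 16's restart-sequence scan; illustrative, not PR code.
    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
        // hypothetical detokenizer: token id -> text piece
        std::vector<std::string> piece = {"Hello", ",", " world", "\n", "Hi"};
        std::vector<int> last_tokens = {0, 1, 2, 3, 0, 1, 2};

        // head piece -> tail pieces, as in restart_sequences
        std::unordered_multimap<std::string, std::vector<std::string>> restart;
        restart.emplace("\n", std::vector<std::string>{}); // single-token breaker "\n"

        int rep_limit = (int)last_tokens.size();
        for (size_t i = 0; i < last_tokens.size(); ++i) {
            size_t ix = last_tokens.size() - 1 - i; // walk backwards from the end
            auto range = restart.equal_range(piece[last_tokens[ix]]);
            int longest = -1;
            for (auto it = range.first; it != range.second; ++it) {
                size_t seq_len = it->second.size();
                if ((int)seq_len > longest && seq_len <= i) {
                    bool match = true;
                    for (size_t off = 0; off < seq_len; ++off) {
                        if (it->second[off] != piece[last_tokens[ix + 1 + off]]) {
                            match = false;
                            break;
                        }
                    }
                    if (match) longest = (int)seq_len;
                }
            }
            if (longest >= 0) { rep_limit = (int)i - longest; break; }
        }
        std::printf("rep_limit = %d\n", rep_limit); // matches may reach at most this far back
    }

With the newline breaker sitting three positions from the end of the context, rep_limit comes out to 3, so no repetition match may extend past the breaker — the same effect the raw token-ID comparison had in the earlier patches, but now robust to breakers that tokenize differently depending on surrounding text.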