From ccc8e19c3cfd9268e0890cb8b840527abc9ec867 Mon Sep 17 00:00:00 2001
From: Lllama <34464159+pi6am@users.noreply.github.com>
Date: Mon, 8 Jul 2024 02:04:59 -0700
Subject: [PATCH] Add the DRY dynamic N-gram anti-repetition sampler

The DRY (Do not Repeat Yourself) sampler is a dynamic N-gram
repetition penalty that negatively scores tokens that would extend
sequences that already appear in the context.

See this discussion for a motivation and explanation of the sampler:
https://github.com/oobabooga/text-generation-webui/pull/5677

This implementation of DRY mostly aligns with the oobabooga version
with a few modifications. It uses a more efficient linear scanning
algorithm to identify repetitions. It also supports multi-token
sequence breakers. As a limitation, this implementation reuses the
rep pen range parameter, rather than introducing a new range just for
the DRY sampler.

There is a separate change to lite.koboldai.net that exposes the DRY
sampler parameters to KoboldAI Lite, so none of the embed files have
been changed as part of this commit.
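
For reference, the scoring rule implemented in sample_dry() below
reduces to the following sketch (illustrative Python, not part of the
patch; the dry_base and dry_allowed_length defaults match the ones in
this change, while the multiplier value is only an example):

    # Logit penalty for a token that would extend a repeated sequence
    # of `repeat_len` tokens; it grows exponentially with the length.
    def dry_penalty(repeat_len, dry_multiplier=0.8, dry_base=1.75,
                    dry_allowed_length=2):
        if dry_multiplier == 0.0 or repeat_len < dry_allowed_length:
            return 0.0
        return dry_multiplier * dry_base ** (repeat_len - dry_allowed_length)

The penalty is subtracted from the logit of every candidate token that
would extend a repetition, so long repeats are suppressed much more
strongly than short ones.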
---
 common/common.h     |   4 +
 common/sampling.h   |   3 +
 expose.h            |   6 +
 gpttype_adapter.cpp | 304 +++++++++++++++++++++++++++++++++++++++++++-
 koboldcpp.py        |  25 +++-
 5 files changed, 336 insertions(+), 6 deletions(-)

diff --git a/common/common.h b/common/common.h
index 1cb6c12b3c388..5c07e761edc20 100644
--- a/common/common.h
+++ b/common/common.h
@@ -113,6 +113,10 @@ struct gpt_params {
     int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
+    float   dry_multiplier    = 0.0f;  // penalty multiplier, 0.0 = disabled
+    float   dry_base          = 1.75f; // exponential base
+    int32_t dry_allowed_length = 2;    // repeated sequences longer than this are penalized
+    std::vector<std::string> dry_restart_sequences; // DRY sequence breakers

     // DynaTemp!
     float   dynatemp_range = 0.0f; // enables DynaTemp if greater than 0. dynatemp_min = temperature - dt_range, dynatemp_max = temperature + dt_range
diff --git a/common/sampling.h b/common/sampling.h
index 50fa9ec57da2d..ed7be83204e76 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -53,6 +53,9 @@ typedef struct llama_sampling_params {
     int32_t mirostat     = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau = 5.00f; // target entropy
     float   mirostat_eta = 0.10f; // learning rate
+    float   dry_multiplier = 0.0f;  // DRY penalty scale, 0.0 = disabled
+    float   dry_base = 1.75f;       // DRY exponent base, 0.0 = disabled
+    int32_t dry_allowed_length = 2; // DRY penalizes repeated sequences longer than this
    bool    penalize_nl  = false; // consider newlines as a repeatable token

    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
diff --git a/expose.h b/expose.h
index 02a3edde27214..07d47c950b521 100644
--- a/expose.h
+++ b/expose.h
@@ -5,6 +5,7 @@ const int stop_token_max = 16;
 const int ban_token_max = 16;
 const int tensor_split_max = 16;
 const int logit_bias_max = 16;
+const int dry_seq_break_max = 16;
 const int images_max = 4;

 // match kobold's sampler list and order
@@ -17,6 +18,7 @@ enum samplers
     KCPP_SAMPLER_TYP=4,
     KCPP_SAMPLER_TEMP=5,
     KCPP_SAMPLER_REP_PEN=6,
+    KCPP_SAMPLER_DRY=7,
     KCPP_SAMPLER_MAX
 };
 enum stop_reason
@@ -89,6 +91,10 @@ struct generation_inputs
     const int mirostat = 0;
     const float mirostat_eta = 0.0f;
     const float mirostat_tau = 0.0f;
+    const float dry_multiplier = 0.0f;
+    const float dry_base = 0.0f;
+    const int dry_allowed_length = 0;
+    const char * dry_sequence_breakers[dry_seq_break_max] = {};
     const samplers sampler_order[KCPP_SAMPLER_MAX] = {};
     const int sampler_len = 0;
     const bool allow_eos_token = false;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 50057f30ccc7c..5726b1a78c0b3 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -10,6 +10,7 @@
 #include <cmath>
 #include <time.h>
 #include <mutex>
+#include <unordered_map>
 #include "model_adapter.h"
 #include "otherarch.h"
 #include "grammar-parser.h"
@@ -110,6 +111,10 @@
 static std::vector<std::string> stop_sequence;
 static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
 static std::vector<std::string> banned_tokens;
 static std::vector<int> banned_token_ids;
+static std::vector<std::string> dry_sequence_break_strings;
+static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
+static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
+static std::unordered_map<gpt_vocab::id, int> dry_max_token_repeat;
 static std::vector<llama_token_data> top_picks;
 static int remaining_tokens = 0;
 static int stopper_unused_tokens = 0;
@@ -309,6 +314,70 @@ static void print_tok_vec_str(std::vector<int> &vec)
 {
     printf("\n%s", get_tok_vec_str(vec).c_str());
 }

+// Find tokens that completely contain `str`, either as a single token, or as a sequence of tokens.
+// It's important to use a hash map for head tokens because some models have many of them.
+// For example, the Llama 3 tokenizer has 6570 tokens containing the period ('.') character.
+// Single tokens are allowed to extend past `str` at the front and back. This is to allow, for
+// instance, the token '.\n' to be a head for both '.' and '\n'. However if a head token
+// begins a multi-token sequence, the head can only extend past `str` at the beginning. The
+// tail tokens are generated by tokenizing the remainder.
+static void GetOverlappingTokenSequences(const std::string& str, std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& token_sequences) {
+    for(int v=0;v<n_vocab;++v)
+    {
+        std::string word = FileFormatTokenizeID(v, file_format, true);
+        if (word.find(str) != std::string::npos)
+        {
+            // The string is entirely contained within this single token.
+            // Ensure that this head token is mapped to at most one empty tail.
+            auto its = token_sequences.equal_range(v);
+            bool empty = false;
+            for (auto it = its.first; it != its.second; ++it) {
+                if (it->second.empty()) {
+                    empty = true;
+                    break;
+                }
+            }
+            if (!empty) {
+                token_sequences.emplace(v, std::vector<gpt_vocab::id>());
+            }
+        } else {
+            // Check whether a prefix of the string overlaps with a suffix of the token.
+            // Just do a naive O(N^2) search.
+            size_t word_len = word.size(), str_len = str.size();
+            size_t pos = -1;
+            while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
+                bool match = true;
+                size_t i;
+                for (i = 1; i < str_len && i + pos < word_len; ++i) {
+                    if (word[pos + i] != str[i]) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    // We matched to the end of the string. Since `str` is not contained in `word`,
+                    // there must be trailing letters in `str`.
+                    std::vector<gpt_vocab::id> tokenization;
+                    TokenizeString(str.substr(i), tokenization, file_format, false);
+
+                    // Ensure we don't already have a duplicate matching tokenization.
+                    auto its = token_sequences.equal_range(v);
+                    bool found = false;
+                    for (auto it = its.first; it != its.second; ++it) {
+                        if (tokenization == it->second) {
+                            found = true;
+                            break;
+                        }
+                    }
+                    if (!found)
+                    {
+                        token_sequences.emplace(v, tokenization);
+                    }
+                }
+            }
+        }
+    }
+}
+
 llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
 {
@@ -428,6 +497,194 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
     candidates->size = last_idx;
 }

+void sample_dry(int n_ctx, int rep_pen_range, float penalty_multiplier, float penalty_base, int allowed_length, const std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& restart_sequences, llama_token_data_array * candidates) {
+    if (penalty_multiplier == 0.0f || penalty_base == 0.0f) {
+        return;
+    }
+    if (rep_pen_range < 0) {
+        rep_pen_range = n_ctx;
+    }
+    auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
+    if (last_n_repeat <= allowed_length) {
+        return;
+    }
+    const llama_token * last_tokens = last_n_tokens.data() + last_n_tokens.size() - last_n_repeat;
+
+    dry_repeat_count.assign(last_n_repeat, 0);
+    dry_max_token_repeat.clear();
+
+    // Step 1: Look for restart sequences to limit the maximum repetition length.
+    // Work backwards through the context looking for any token that begins a restart sequence.
+    //
+    // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
+    // sequences that together comprise a restart sequence. This allows us to quickly check
+    // whether each token is the head of a complete sequence. Most restart sequences are actually
+    // a single token, and for these the "tail" is an empty vector.
+    //
+    // If the token is a "head", test all restart sequences that begin with this token
+    // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
+    // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
+    // longest matching sequence (if any) is used to limit the maximum repetition length.
+    //
+    // Note that in the case of a short sequence contained in a longer one, this might fail to
+    // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
+    // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
+    // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
+    //
+    // This is worst-case O(N^2) for perverse restart sequences, but typically will be O(N) since
+    // most restart sequences are a single token and we use a hash table to check for head tokens.
+
+    int rep_limit = last_n_repeat;
+    for (size_t i = 0; i < last_n_repeat; ++i) {
+        size_t ix = last_n_repeat - 1 - i;
+        auto its = restart_sequences.equal_range(last_tokens[ix]);
+        if (its.first == restart_sequences.end()) {
+            continue;
+        }
+        int longest_match = -1;
+        for (auto it = its.first; it != its.second; ++it) {
+            // Note that (*it) does not contain the head token, so seq_len will be
+            // the restart sequence length minus 1.
+            // In the common case of a single-token restart sequence, (*it) will be empty
+            // and we will trivially match.
+            int seq_len = (int)it->second.size();
+            if (seq_len > longest_match && seq_len <= i) {
+                bool match = true;
+                for (size_t offset = 0; offset < seq_len; ++offset) {
+                    // The +1 when indexing `last_tokens` is because we already matched the head.
+                    if (it->second[offset] != last_tokens[ix + 1 + offset]) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    longest_match = seq_len;
+                }
+            }
+        }
+        if (longest_match >= 0) {
+            // We found a restart sequence starting `i` tokens from the end and continuing for
+            // `longest_match` tokens.
+            rep_limit = (int)i - longest_match;
+            break;
+        }
+    }
+    if (rep_limit <= allowed_length) {
+        return;
+    }
+
+    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
+    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
+    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
+    //
+    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
+    // https://ivanyu.me/blog/2014/10/15/z-algorithm/
+    //
+    // The code below is adapted from the public domain implementation by the same author here:
+    // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
+    //
+    // This step is worst case O(N), since the Z-algorithm is linear.
+    //
+    // Example:
+    //   Last N tokens: a b c c b c y a b c
+    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
+
+    {
+        const int last = last_n_repeat - 1;
+        int rt = 0, lt = 0;
+
+        for (int k = 1; k < last_n_repeat; ++k) {
+            if (k > rt) {
+                // If k is outside the current Z-box, do naive computation.
+                int n = 0;
+                while (n + k < last_n_repeat && last_tokens[last - n] == last_tokens[last - (n+k)]) {
+                    ++n;
+                }
+                dry_repeat_count[last - k] = std::min(n, rep_limit);
+                if (n > 0) {
+                    lt = k;
+                    rt = k+n-1;
+                }
+            } else {
+                // If k is inside the current Z-box, consider two cases.
+
+                int p = k - lt; // Pair index.
+                int right_part_len = rt - k + 1;
+
+                if (dry_repeat_count[last - p] < right_part_len) {
+                    int n = std::min(dry_repeat_count[last - p], rep_limit);
+                    dry_repeat_count[last - k] = n;
+                } else {
+                    int i = rt + 1;
+                    while (i < last_n_repeat && last_tokens[last - i] == last_tokens[last - (i - k)]) {
+                        i += 1;
+                    }
+
+                    int n = std::min(i - k, rep_limit);
+                    dry_repeat_count[last - k] = n;
+
+                    lt = k;
+                    rt = i - 1;
+                }
+            }
+        }
+    }
+
+    // Step 3: Iterate over dry_repeat_count and the last N tokens, examining the maximum repeat length
+    // that would be generated by emitting each new token that would extend a sequence.
+    //
+    // Following the same example as above:
+    //   Last N tokens: a b c c b c y a b c
+    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
+    //
+    // For each non-zero count, look ahead one token. This token, if emitted, would extend the repetition.
+    //   c: 3 -> 4 (from `a b c` to `a b c c`)
+    //   b: 1 -> 2 (from `c` to `c b`)
+    //   y: 2 -> 3 (from `b c` to `b c y`)
+
+    for (size_t i = 0; i < last_n_repeat - 1; ++i) {
+        int repeat_len = dry_repeat_count[i];
+        if (repeat_len >= allowed_length) {
+            // This token ends a repeat, so the next token would continue it.
+            // By convention, the value of `repeat_len` only includes the tokens currently
+            // in the context, not the new token that would be added.
+            gpt_vocab::id token = last_tokens[i + 1];
+            // Track the maximum sequence ending in this token.
+            const auto& it = dry_max_token_repeat.find(token);
+            if (it == dry_max_token_repeat.end() || it->second < repeat_len) {
+                dry_max_token_repeat[token] = repeat_len;
+            }
+        }
+    }
+
+    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
+
+    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
+    // Determine the max from `penalty_base` and the log of `std::numeric_limits<float>::max()`.
+    const float FLOAT_MAX_LOG = std::log(std::numeric_limits<float>::max());
+    int max_exponent = 0;
+    if (penalty_base > 1.000001f) {
+        max_exponent = FLOAT_MAX_LOG / std::log(penalty_base);
+    }
+
+    for (const auto& kvp: dry_max_token_repeat) {
+        gpt_vocab::id token = kvp.first;
+        int repeat_exp = kvp.second - allowed_length;
+        if (max_exponent > 0 && repeat_exp > max_exponent) {
+            repeat_exp = max_exponent;
+        }
+        float penalty = penalty_multiplier * pow(penalty_base, repeat_exp);
+        if (debugmode==1)
+        {
+            std::string tokenizedstr = FileFormatTokenizeID(token, file_format);
+            ::utreplace(tokenizedstr, "\n", "\\n");
+            printf("[dry] Token %d [%s] len %d, penalty %.03f\n", token, RemoveBell(tokenizedstr).c_str(), kvp.second, penalty);
+        }
+        candidates->data[token].logit -= penalty;
+    }
+}
+
 void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, llama_token_data_array * candidates_p)
 {
     auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
@@ -543,7 +800,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
 }

 int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
-int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
+int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
 {
     int id = 0;
     std::vector<llama_token_data> candidates;
@@ -620,6 +877,9 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
             case KCPP_SAMPLER_REP_PEN:
                 sample_rep_pen(n_ctx, rep_pen_range, rep_pen, rep_pen_slope, presence_penalty, &candidates_p);
                 break;
+            case KCPP_SAMPLER_DRY:
+                sample_dry(n_ctx, rep_pen_range, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
+                break;
             case KCPP_SAMPLER_TOP_K:
                 sample_top_k(&candidates_p, top_k);
                 break;
@@ ... @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     kcpp_params->mirostat = inputs.mirostat;
     kcpp_params->mirostat_eta = inputs.mirostat_eta;
     kcpp_params->mirostat_tau = inputs.mirostat_tau;
+    kcpp_params->dry_multiplier = inputs.dry_multiplier;
+    kcpp_params->dry_base = inputs.dry_base;
+    kcpp_params->dry_allowed_length = inputs.dry_allowed_length;
     kcpp_params->dynatemp_range = inputs.dynatemp_range;
     kcpp_params->dynatemp_exponent = inputs.dynatemp_exponent;
     kcpp_params->n_ctx = inputs.max_context_length;
     kcpp_params->smoothing_factor = inputs.smoothing_factor;

+    // Parse dry sequence breakers / restart sequences
+    dry_sequence_break_strings.clear();
+    dry_sequence_breakers.clear();
+    for(int x=0;x<dry_seq_break_max;++x)
+    {
+        std::string word = inputs.dry_sequence_breakers[x];
+        if(word!="")
+        {
+            dry_sequence_break_strings.push_back(word);
+        }
+    }
+    if(dry_sequence_break_strings.size()>0)
+    {
+        if(debugmode==1)
+        {
+            printf("\nProcessing %zu dry break strings...",dry_sequence_break_strings.size());
+        }
+        for (const auto& sequence_break: dry_sequence_break_strings) {
+            GetOverlappingTokenSequences(sequence_break, dry_sequence_breakers);
+        }
+        if(debugmode==1)
+        {
+            int trivial = 0, non_trivial = 0;
+            for (const auto& seq: dry_sequence_breakers) {
+                if (seq.second.empty()) {
+                    ++trivial;
+                } else {
+                    ++non_trivial;
+                }
+            }
+            printf("\nFound a total of %zu restart heads, %d trivial, %d non-trivial.\n", dry_sequence_breakers.size(), trivial, non_trivial);
+        }
+    }
+
     bool stream_sse = inputs.stream_sse;
     bool allow_regular_prints = (debugmode!=-1 && !inputs.quiet) || debugmode >= 1;
@@ -2073,6 +2370,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     {
         sampler_order = {
             KCPP_SAMPLER_REP_PEN,
+            KCPP_SAMPLER_DRY,
             KCPP_SAMPLER_TOP_K,
             KCPP_SAMPLER_TOP_A,
             KCPP_SAMPLER_TFS,
@@ -2336,7 +2634,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)

             id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_params->rep_pen_slope, presence_penalty,
             top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
-            kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta, sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);
+            kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta,
+            kcpp_params->dry_multiplier, kcpp_params->dry_base, kcpp_params->dry_allowed_length,
+            sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);

             if (llama_ctx_v4)
             {
                 empcats_step_post(llama_ctx_v4, id );
diff --git a/koboldcpp.py b/koboldcpp.py
index 9344836cf1e56..cbf51694ba670 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -16,11 +16,12 @@
 import json, sys, http.server, time, asyncio, socket, threading
 from concurrent.futures import ThreadPoolExecutor

-sampler_order_max = 7
+sampler_order_max = 8
 stop_token_max = 16
 ban_token_max = 16
 tensor_split_max = 16
 logit_bias_max = 16
+dry_seq_break_max = 16
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0
@@ -89,6 +90,10 @@ class generation_inputs(ctypes.Structure):
                 ("mirostat", ctypes.c_int),
                 ("mirostat_tau", ctypes.c_float),
                 ("mirostat_eta", ctypes.c_float),
+                ("dry_multiplier", ctypes.c_float),
+                ("dry_base", ctypes.c_float),
+                ("dry_allowed_length", ctypes.c_int),
+                ("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max),
                 ("sampler_order", ctypes.c_int * sampler_order_max),
                 ("sampler_len", ctypes.c_int),
                 ("allow_eos_token", ctypes.c_bool),
@@ -493,7 +498,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret

-def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
+def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=0.0, dry_allowed_length=2, dry_sequence_breakers=['\n'], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
     inputs = generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
@@ -541,6 +546,14 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
         inputs.mirostat_eta = mirostat_eta
     else:
         inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
+    inputs.dry_multiplier = dry_multiplier
+    inputs.dry_base = dry_base
+    inputs.dry_allowed_length = dry_allowed_length
+    for n in range(dry_seq_break_max):
+        if n < len(dry_sequence_breakers):
+            inputs.dry_sequence_breakers[n] = dry_sequence_breakers[n].encode("UTF-8")
+        else:
+            inputs.dry_sequence_breakers[n] = "".encode("UTF-8")
     if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
         try:
             for i, sampler in enumerate(sampler_order):
@@ -548,7 +561,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
             inputs.sampler_len = len(sampler_order)
             global showsamplerwarning
             if showsamplerwarning and inputs.mirostat==0 and inputs.sampler_len>0 and (inputs.sampler_order[0]!=6 or inputs.sampler_order[inputs.sampler_len-1]!=5):
-                print("\n(Note: Non-default sampler_order detected. Recommended sampler values are [6,0,1,3,4,2,5]. This message will only show once per session.)")
+                print("\n(Note: Sub-optimal sampler_order detected. You may have reduced quality. Recommended sampler values are [6,7,0,1,3,4,2,5]. This message will only show once per session.)")
             showsamplerwarning = False
         except TypeError as e:
             print("ERROR: sampler_order must be a list of integers: " + str(e))
@@ -977,7 +990,11 @@ def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
             mirostat=genparams.get('mirostat', 0),
             mirostat_tau=genparams.get('mirostat_tau', 5.0),
             mirostat_eta=genparams.get('mirostat_eta', 0.1),
-            sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
+            dry_multiplier=genparams.get('dry_multiplier', 0.0),
+            dry_base=genparams.get('dry_base', 1.75),
+            dry_allowed_length=genparams.get('dry_allowed_length', 2),
+            dry_sequence_breakers=genparams.get('dry_sequence_breakers', []),
+            sampler_order=genparams.get('sampler_order', [6,7,0,1,3,4,2,5]),
             seed=tryparseint(genparams.get('sampler_seed', -1)),
             stop_sequence=genparams.get('stop_sequence', []),
             use_default_badwordsids=genparams.get('use_default_badwordsids', False),
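
Note for reviewers: the new parameters can be exercised end to end
through the KoboldAI generate endpoint. The payload keys below match
the genparams fields read in koboldcpp.py above; the host, port, and
breaker strings are illustrative (5001 is the default koboldcpp port):

    import json, urllib.request

    payload = {
        "prompt": "Once upon a time",
        "max_length": 64,
        "rep_pen_range": 2048,  # DRY reuses the rep pen range parameter
        "dry_multiplier": 0.8,  # 0.0 leaves DRY disabled
        "dry_base": 1.75,
        "dry_allowed_length": 2,
        "dry_sequence_breakers": ["\n", ":", "\"", "*"],
    }
    req = urllib.request.Request("http://localhost:5001/api/v1/generate",
                                 data=json.dumps(payload).encode("utf-8"),
                                 headers={"Content-Type": "application/json"})
    print(urllib.request.urlopen(req).read().decode("utf-8"))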