diff --git a/common/common.cpp b/common/common.cpp
index ee7fbcba3c797..721bd714b571c 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -555,6 +555,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.penalty_present = std::stof(argv[i]);
         return true;
     }
+    if (arg == "--dry-multiplier") {
+        CHECK_ARG
+        sparams.dry_multiplier = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-base") {
+        CHECK_ARG
+        sparams.dry_base = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-allowed-length") {
+        CHECK_ARG
+        sparams.dry_allowed_length = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-penalty-last-n") {
+        CHECK_ARG
+        sparams.dry_penalty_last_n = std::stoi(argv[i]);
+        return true;
+    }
     if (arg == "--dynatemp-range") {
         CHECK_ARG
         sparams.dynatemp_range = std::stof(argv[i]);
@@ -1471,6 +1491,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",           "       --repeat-penalty N",     "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
     options.push_back({ "*",           "       --presence-penalty N",   "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
     options.push_back({ "*",           "       --frequency-penalty N",  "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq });
+    options.push_back({ "*",           "       --dry-multiplier N",     "DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)sparams.dry_multiplier });
+    options.push_back({ "*",           "       --dry-base N",           "DRY sampling base (default: %.1f)", (double)sparams.dry_base });
+    options.push_back({ "*",           "       --dry-allowed-length N", "DRY sampling allowed length (default: %d)", sparams.dry_allowed_length });
+    options.push_back({ "*",           "       --dry-penalty-last-n N", "DRY sampling penalty last n tokens (-1 = context size, default: %d)", sparams.dry_penalty_last_n });
+
     options.push_back({ "*",           "       --dynatemp-range N",     "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range });
     options.push_back({ "*",           "       --dynatemp-exp N",       "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent });
     options.push_back({ "*",           "       --mirostat N",           "use Mirostat sampling.\n"
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 079e405168dff..bb7de3769d145 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -362,13 +362,19 @@ static llama_token_data_array llama_sampling_prepare_impl(
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
+    // repetition penalties
     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
-
     const bool    penalize_nl    = params.penalize_nl;
 
+    // DRY sampler parameters
+    const float    dry_multiplier     = params.dry_multiplier;
+    const float    dry_base           = params.dry_base;
+    const uint32_t dry_allowed_length = params.dry_allowed_length;
+    const uint32_t dry_penalty_last_n = params.dry_penalty_last_n;
+
     auto & prev = ctx_sampling->prev;
     auto & cur  = ctx_sampling->cur;
@@ -399,26 +405,41 @@ static llama_token_data_array llama_sampling_prepare_impl(
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
-    // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
+
+    // apply repetition penalties
+    {
+        const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+        if (penalty_tokens_used_size) {
+            const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+            // repetition penalties
+            llama_sample_repetition_penalties(ctx_main, &cur_p,
+                    penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                    penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+            if (!penalize_nl) {
+                for (size_t idx = 0; idx < cur_p.size; idx++) {
+                    if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+                        cur_p.data[idx].logit = nl_logit;
+                        break;
+                    }
                 }
             }
         }
     }
 
+    // apply DRY penalties
+    {
+        const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
+        if (penalty_tokens_used_size) {
+            llama_sample_dry(ctx_main, &cur_p,
+                    penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                    penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
+                    params.dry_seq_breakers);
+        }
+    }
+
     // apply grammar checks before sampling logic
     if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
diff --git a/common/sampling.h b/common/sampling.h
index eeaa53b8bcd00..1f864c4764764 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -41,6 +41,12 @@ typedef struct llama_sampling_params {
     float       mirostat_eta      = 0.10f; // learning rate
     bool        penalize_nl       = false; // consider newlines as a repeatable token
     uint32_t    seed              = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    float       dry_multiplier     = 0.0f;  // 0.0f = disabled, recommended value: 0.8f
+    float       dry_base           = 1.75f;
+    uint32_t    dry_allowed_length = 2;
+    int32_t     dry_penalty_last_n = -1;    // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
+
+    std::vector<std::string> dry_seq_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -59,8 +65,8 @@ typedef struct llama_sampling_params {
     float cfg_scale     = 1.f; // how strong is guidance
 
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
     std::vector<llama_token> penalty_prompt_tokens;
+
     bool                     use_penalty_prompt_tokens = false;
 } llama_sampling_params;
 
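Note for readers: the penalty that `llama_sample_dry` ultimately applies (see `llama_sample_dry_impl` further down) is `dry_multiplier * dry_base^(match_length - dry_allowed_length)`. Below is a minimal standalone sketch of just that arithmetic, using the defaults declared in `sampling.h`; the helper name is hypothetical and not part of this patch.

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical standalone helper mirroring the penalty arithmetic in
// llama_sample_dry_impl; not part of this patch.
static float dry_penalty(float multiplier, float base, int allowed_length, int match_length) {
    if (match_length < allowed_length) {
        return 0.0f; // repetition still within the allowed length: no penalty
    }
    return multiplier * std::pow(base, (float)(match_length - allowed_length));
}

int main() {
    // with the defaults suggested in sampling.h (multiplier 0.8f, base 1.75f, allowed length 2):
    // match_length 2 -> 0.80, 3 -> 1.40, 4 -> 2.45, 5 -> 4.29
    for (int len = 2; len <= 5; ++len) {
        std::printf("match_length=%d penalty=%.2f\n", len, dry_penalty(0.8f, 1.75f, 2, len));
    }
    return 0;
}
```

With the recommended multiplier of 0.8, a match only two tokens past the allowed length already costs more than two logits, so the penalty ramps quickly once a repeat sets in.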
json_value(data, "top_p", default_sparams.top_p); - slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); - slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); - slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.sparams.seed = json_value(data, "seed", default_sparams.seed); - slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + slot.params.stream = json_value(data, "stream", false); + slot.params.cache_prompt = json_value(data, "cache_prompt", false); + slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict); + slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); + slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); + slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier); + slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base); + slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length); + slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n); + slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", 
diff --git a/include/llama.h b/include/llama.h
index f23355a6bc959..81805e5c2ae7e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1085,6 +1085,18 @@ extern "C" {
                            float   p,
                           size_t   min_keep);
 
+    // /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    // LLAMA_API void llama_sample_dry(
+    //         struct llama_context * ctx,
+    //       llama_token_data_array * candidates,
+    //            const llama_token * last_tokens,
+    //                       size_t   last_tokens_size,
+    //                        float   dry_base,
+    //                        float   dry_multiplier,
+    //                          int   dry_allowed_length,
+    //        const std::vector<std::string>
+    //                              & dry_seq_breakers);
+
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
             struct llama_context * ctx,
@@ -1235,6 +1247,18 @@ std::pair<std::vector<llama_grammar_candidate>, llama_partial_utf8> decode_utf8(
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
 llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
 
+/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+LLAMA_API void llama_sample_dry(
+        struct llama_context * ctx,
+      llama_token_data_array * candidates,
+           const llama_token * last_tokens,
+                      size_t   last_tokens_size,
+                       float   dry_base,
+                       float   dry_multiplier,
+                         int   dry_allowed_length,
+       const std::vector<std::string>
+                            & dry_seq_breakers);
+
 #endif // LLAMA_API_INTERNAL
 
 #endif // LLAMA_H
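The copy inside the `extern "C"` block stays commented out, presumably because the `std::vector<std::string>` parameter cannot cross a C ABI; the callable declaration lives in the `LLAMA_API_INTERNAL` section instead. A sketch of a call site against that internal API follows; the surrounding context and candidate setup are assumed, not shown by this patch.

```cpp
// Sketch only: assumes an initialized llama_context * ctx, a populated
// llama_token_data_array cur_p, and the recent tokens in last_tokens.
#define LLAMA_API_INTERNAL
#include "llama.h"

#include <string>
#include <vector>

void apply_dry(llama_context * ctx, llama_token_data_array & cur_p,
               const std::vector<llama_token> & last_tokens) {
    const std::vector<std::string> breakers = {"\n", ":", "\"", "*"}; // defaults from sampling.h
    llama_sample_dry(ctx, &cur_p,
            last_tokens.data(), last_tokens.size(),
            /*dry_base           =*/ 1.75f,
            /*dry_multiplier     =*/ 0.8f,
            /*dry_allowed_length =*/ 2,
            breakers);
}
```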
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 8910f6d6542e9..22e623fac7656 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -232,6 +232,232 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
     }
 }
 
+std::vector<llama_token> llama_tokenize(
+        const struct llama_context * ctx,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+}
+
+std::vector<llama_token> llama_tokenize(
+        const struct llama_model * model,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
+
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
+
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special) {
+    std::vector<llama_token> tokens = {token};
+    return llama_detokenize(ctx, tokens, special);
+}
+
+// Constants for preventing overflow
+const float FLOAT_MAX_LOG = 88.7228391f;
+const int MAX_CHAR_LEN = 40;
+const int MAX_SEQ_LEN = 20;
+
+
+void llama_sample_dry_impl(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    if (last_tokens_size < 1) {
+        return;
+    }
+
+    // Cache for token-to-string conversions
+    std::unordered_map<llama_token, std::string> token_to_string_cache;
+    // Store sequence breakers for more efficient lookup
+    std::unordered_multimap<std::string, std::vector<std::string>> restart_sequences;
+
+    auto detokenize_with_cache = [&](llama_token token) -> std::string {
+        auto it = token_to_string_cache.find(token);
+        if (it != token_to_string_cache.end()) {
+            return it->second;
+        }
+        std::string token_str = llama_detokenize_single(ctx, token, false);
+        token_to_string_cache[token] = token_str;
+        return token_str;
+    };
+
+    // Pre-process dry_seq_breakers
+    for (const auto& breaker : dry_seq_breakers) {
+        std::string breaker_trimmed = breaker.substr(0, MAX_CHAR_LEN);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, breaker_trimmed, false, false);
+
+        if (!tokens.empty()) {
+            std::string head = detokenize_with_cache(tokens[0]);
+            std::vector<std::string> tail;
+
+            for (size_t i = 1; i < tokens.size() && i <= MAX_SEQ_LEN; ++i) {
+                tail.push_back(detokenize_with_cache(tokens[i]));
+            }
+            restart_sequences.emplace(head, tail);
+        }
+    }
+
+    // Find max repetition length considering restart sequences
+    int rep_limit = last_tokens_size;
+
+    for (size_t i = 0; i < last_tokens_size; ++i) {
+        size_t ix = last_tokens_size - 1 - i;
+        std::string token_str = detokenize_with_cache(last_tokens[ix]);
+
+        // Check if the token is a potential sequence breaker
+        auto its = restart_sequences.equal_range(token_str);
+        if (its.first == restart_sequences.end()) continue;
+
+        int longest_match = -1;
+        // Check all potential sequence breakers starting with this token
+        for (auto it = its.first; it != its.second; ++it) {
+            int seq_len = (int)it->second.size();
+            if (seq_len > longest_match && seq_len <= i) {
+                bool match = true;
+                // Check if the following tokens match the sequence breaker
+                for (size_t offset = 0; offset < seq_len; ++offset) {
+                    if (it->second[offset] != detokenize_with_cache(last_tokens[ix + 1 + offset])) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    longest_match = seq_len;
+                }
+            }
+        }
+
+        if (longest_match >= 0) {
+            rep_limit = static_cast<int>(i) - longest_match;
+            break;
+        }
+    }
+
+    if (rep_limit <= dry_allowed_length) {
+        return;
+    }
+
+    // Store max match length for each token
+    std::unordered_map<llama_token, size_t> match_lengths;
+
+    // Find repeated sequences
+    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
+        if (last_tokens[i] != last_tokens[last_tokens_size - 1]) {
+            continue;
+        }
+
+        auto next_token = last_tokens[i + 1];
+        std::string next_token_str = detokenize_with_cache(next_token);
+
+        // Skip if next token is a sequence breaker
+        auto its = restart_sequences.equal_range(next_token_str);
+        if (its.first != restart_sequences.end()) {
+            continue;
+        }
+
+        size_t match_length = 1;
+
+        // Extend match as far as possible
+        for (;; match_length++) {
+            if (i < match_length || match_length > rep_limit) {
+                break;
+            }
+
+            auto compare_token = last_tokens[i - match_length];
+            std::string compare_token_str = detokenize_with_cache(compare_token);
+
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+            std::string head_token_str = detokenize_with_cache(head_token);
+
+            if (compare_token_str != head_token_str) {
+                break;
+            }
+
+            // Check if we've hit a sequence breaker
+            its = restart_sequences.equal_range(compare_token_str);
+            if (its.first != restart_sequences.end()) {
+                break;
+            }
+        }
+
+        // Update max match length for this token
+        auto it = match_lengths.find(next_token);
+        if (it == match_lengths.end()) {
+            match_lengths[next_token] = match_length;
+        } else {
+            it->second = std::max(it->second, match_length);
+        }
+    }
+
+    // Calculate max safe exponent
+    int max_exponent = 0;
+    if (dry_base > 1.000001f) {
+        max_exponent = static_cast<int>(FLOAT_MAX_LOG / log(dry_base));
+    }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("DRY Sampling parameters:\n");
+    LLAMA_LOG_INFO("  dry_base: %f\n", dry_base);
+    LLAMA_LOG_INFO("  dry_multiplier: %f\n", dry_multiplier);
+    LLAMA_LOG_INFO("  dry_allowed_length: %d\n", dry_allowed_length);
+    LLAMA_LOG_INFO("  max_exponent: %d\n", max_exponent);
+    LLAMA_LOG_INFO("DRY penalties [");
+#endif
+
+    // Apply penalties
+    for (const auto& pair : match_lengths) {
+        auto next_token = pair.first;
+        auto match_length = pair.second;
+
+        if (match_length >= static_cast<size_t>(dry_allowed_length)) {
+            for (size_t i = 0; i < candidates->size; ++i) {
+                if (candidates->data[i].id == next_token) {
+                    int repeat_exp = static_cast<int>(match_length - dry_allowed_length);
+                    if (max_exponent > 0 && repeat_exp > max_exponent) {
+                        repeat_exp = max_exponent;
+                    }
+                    float penalty = dry_multiplier * pow(dry_base, static_cast<float>(repeat_exp));
+                    candidates->data[i].logit -= penalty;
+
+#ifdef DEBUG
+                    LLAMA_LOG_INFO("  Token %d: %s (Penalty: %.2f)", next_token, detokenize_with_cache(next_token).c_str(), penalty);
+#endif
+                    break;
+                }
+            }
+        }
+    }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("]\n");
+#endif
+}
+
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index f7f8e3ef706bc..48cdc086c820c 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -28,6 +28,11 @@ void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_
 void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
 void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+std::vector<llama_token> llama_tokenize(const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special);
+std::vector<llama_token> llama_tokenize(const struct llama_model * model, const std::string & text, bool add_special, bool parse_special);
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special);
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special);
+void llama_sample_dry_impl      (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers);
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
 void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
diff --git a/src/llama.cpp b/src/llama.cpp
index a7b1c9ebd9e37..a8a97c0905dc4 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18948,6 +18948,10 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
 }
 
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    llama_sample_dry_impl(ctx, candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers);
+}
+
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
     llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
 }
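One detail worth flagging in review: `max_exponent` exists so that `pow(dry_base, repeat_exp)` cannot overflow a `float` on pathologically long matches; `FLOAT_MAX_LOG` is ln(FLT_MAX). A small sketch checking the constant and the resulting clamp for the default base (printed values are approximate):

```cpp
#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
    // FLOAT_MAX_LOG in llama-sampling.cpp is ln(FLT_MAX): any exponent above
    // this bound would overflow a float in pow(dry_base, exponent).
    std::printf("ln(FLT_MAX) = %f\n", std::log(FLT_MAX)); // ~88.722839

    // with the default dry_base of 1.75 the clamp kicks in around exponent 158,
    // i.e. only for extremely long repetitions
    const float dry_base = 1.75f;
    const int max_exponent = (int)(88.7228391f / std::log(dry_base));
    std::printf("max_exponent for base %.2f = %d\n", dry_base, max_exponent); // 158
    return 0;
}
```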