diff --git a/common/common.cpp b/common/common.cpp
index 10ef11829cc50..8d12a8549e659 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -360,6 +360,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.min_p = std::stof(argv[i]);
+        } else if (arg == "--p-step") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.p_step = std::stof(argv[i]);
         } else if (arg == "--temp") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -970,6 +976,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --min-p N             min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
+    printf("  --p-step N            p-step sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.p_step);
     printf("  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
     printf("  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
     printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
@@ -1140,6 +1147,7 @@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector
@@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector
@@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
         {'p', llama_sampler_type::TOP_P},
         {'y', llama_sampler_type::TYPICAL_P},
         {'m', llama_sampler_type::MIN_P},
+        {'s', llama_sampler_type::P_STEP},
         {'f', llama_sampler_type::TFS_Z},
         {'t', llama_sampler_type::TEMPERATURE}
     };
@@ -1210,6 +1220,7 @@ std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
         case llama_sampler_type::TYPICAL_P:   return "typical_p";
         case llama_sampler_type::TOP_P:       return "top_p";
         case llama_sampler_type::MIN_P:       return "min_p";
+        case llama_sampler_type::P_STEP:      return "p_step";
         case llama_sampler_type::TEMPERATURE: return "temperature";
         default : return "";
     }
 }
@@ -1755,6 +1766,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
     fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
+    fprintf(stream, "p_step: %f # default: 0.0\n", sparams.p_step);
     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
     fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
diff --git a/common/sampling.cpp b/common/sampling.cpp
index de4331a1182d6..d8bc04f519cd1 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -91,10 +91,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
-            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, p_step = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
             params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
-            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
+            params.top_k, params.tfs_z, params.top_p, params.min_p, params.p_step, params.typical_p, params.temp,
             params.mirostat, params.mirostat_eta, params.mirostat_tau);
 
     return std::string(result);
@@ -128,6 +128,7 @@ static void sampler_queue(
     const int32_t top_k     = params.top_k;
     const float   top_p     = params.top_p;
     const float   min_p     = params.min_p;
+    const float   p_step    = params.p_step;
     const float   tfs_z     = params.tfs_z;
     const float   typical_p = params.typical_p;
     const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
@@ -139,6 +140,7 @@
             case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
             case llama_sampler_type::TOP_P    : llama_sample_top_p   (ctx_main, &cur_p, top_p,     min_keep); break;
             case llama_sampler_type::MIN_P    : llama_sample_min_p   (ctx_main, &cur_p, min_p,     min_keep); break;
+            case llama_sampler_type::P_STEP   : llama_sample_p_step  (ctx_main, &cur_p, p_step,    min_keep); break;
             case llama_sampler_type::TEMPERATURE:
                 if (dynatemp_range > 0) {
                     float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
diff --git a/common/sampling.h b/common/sampling.h
index 95d8753942b40..db643e6b59cbf 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -13,6 +13,7 @@ enum class llama_sampler_type : char {
     TOP_K       = 'k',
     TOP_P       = 'p',
     MIN_P       = 'm',
+    P_STEP      = 's',
     TFS_Z       = 'f',
     TYPICAL_P   = 'y',
     TEMPERATURE = 't'
@@ -26,6 +27,7 @@ typedef struct llama_sampling_params {
     int32_t     top_k     = 40;    // <= 0 to use vocab size
     float       top_p     = 0.95f; // 1.0 = disabled
     float       min_p     = 0.05f; // 0.0 = disabled
+    float       p_step    = 0.00f; // 0.0 = disabled
     float       tfs_z     = 1.00f; // 1.0 = disabled
     float       typical_p = 1.00f; // 1.0 = disabled
     float       temp      = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
@@ -46,6 +48,7 @@ typedef struct llama_sampling_params {
         llama_sampler_type::TYPICAL_P,
         llama_sampler_type::TOP_P,
         llama_sampler_type::MIN_P,
+        llama_sampler_type::P_STEP,
         llama_sampler_type::TEMPERATURE
     };
diff --git a/llama.cpp b/llama.cpp
index 259f2a3a3ea00..b7502ed4f2e72 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9593,6 +9593,34 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
+void llama_sample_p_step(struct llama_context * ctx, llama_token_data_array * candidates, float step, size_t min_keep) {
+    if (step <= 0.0f || candidates->size <= 1) {
+        return;
+    }
+
+    llama_sample_softmax(nullptr, candidates);
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    bool step_found = false;
+
+    for (size_t i = 1; i < candidates->size; ++i) {
+        if (!step_found && candidates->data[i].p < step * candidates->data[i - 1].p) {
+            step_found = true;
+        }
+
+        if (step_found && i >= min_keep) {
+            // Resize the output vector to keep only the tokens before the step
+            candidates->size = i;
+            break;
+        }
+    }
+
+    if (ctx) {
+        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    }
+}
+
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
diff --git a/llama.h b/llama.h
index 84f196b3bb625..cb17bc0eae016 100644
--- a/llama.h
+++ b/llama.h
@@ -799,6 +799,13 @@ extern "C" {
                            float   p,
                           size_t   min_keep);
 
+    /// @details P-Step sampling as described in [THIS PR]
+    LLAMA_API void llama_sample_p_step(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+                           float   step,
+                          size_t   min_keep);
+
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
             struct llama_context * ctx,
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 6374958fee8e6..d81bf9091aff9 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -101,6 +101,27 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     }
 }
 
+static void test_p_step(const std::vector<float> & probs, const std::vector<float> & expected_probs, float step) {
+    const size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        const float logit = logf(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    DUMP(&candidates_p);
+    llama_sample_p_step(nullptr, &candidates_p, step, 1);
+    DUMP(&candidates_p);
+    llama_sample_softmax(nullptr, &candidates_p);
+
+    GGML_ASSERT(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    }
+}
+
 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
@@ -149,7 +170,7 @@ static void test_repetition_penalties(
 }
 
 static void test_sampler_queue(
-    const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
+    const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p, const float p_step
 ) {
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -164,14 +185,15 @@
     const llama_token max_token_id = n_vocab-1;
 
     for (auto s : samplers_sequence) {
-        switch (s){
-            case 'k': llama_sample_top_k    (nullptr, &candidates_p, top_k, 1); break;
-            case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
-            case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
-            case 'p': llama_sample_top_p    (nullptr, &candidates_p, top_p, 1); break;
-            case 'm': llama_sample_min_p    (nullptr, &candidates_p, min_p, 1); break;
-            case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
-            default : GGML_ASSERT(false && "Unknown sampler"); break;
+        switch (s) {
+            case 'k': llama_sample_top_k    (nullptr, &candidates_p, top_k,  1); break;
+            case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
+            case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
+            case 'p': llama_sample_top_p    (nullptr, &candidates_p, top_p,  1); break;
+            case 'm': llama_sample_min_p    (nullptr, &candidates_p, min_p,  1); break;
+            case 's': llama_sample_p_step   (nullptr, &candidates_p, p_step, 1); break;
+            case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
+            default : GGML_ASSERT(false && "Unknown sampler"); break;
         }
 
         llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
@@ -218,6 +240,18 @@
             min_token_id = std::max(min_token_id, (llama_token)(n_vocab - size));
             min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
 
+            GGML_ASSERT(size == expected_size);
+            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
+            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+        } else if (s == 's') {
+            min_token_id = n_vocab;
+            int expected_size = 0;
+
+            do { // do-while because always at least one token is sampled
+                min_token_id--;
+                expected_size++;
+            } while (candidates_p.data[expected_size].p >= p_step * candidates_p.data[expected_size - 1].p);
+
             GGML_ASSERT(size == expected_size);
             GGML_ASSERT(candidates_p.data[0].id == max_token_id);
             GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
@@ -226,8 +260,8 @@
         }
     }
 
-    printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
-           samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
+    printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f p_step=%f\n",
+           samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p, p_step);
 }
 
 int main(void) {
@@ -252,6 +286,17 @@ int main(void) {
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
 
+    test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.0f);
+    test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.5f);
+    test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.6f);
+    test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.7f);
+    test_p_step({0.2f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.7f);
+    test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
+    // Disabled because of floating point nonsense: 0.3f < 0.75f * 0.4f is true!
+    //test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.75f);
+    test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
+    test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
+
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
@@ -267,33 +312,44 @@ int main(void) {
     test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
     test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
-    test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
-    test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
-    test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
-    test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f);
-    test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f);
-    test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);
-
-    test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f);
-    test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
-    test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
-    test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
-    test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);
-
-    test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "km", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f);
-    test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f);
-
-    test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
-    test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "k", 1, 1.0f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12, 1.0f);
+    test_sampler_queue(10000, "s", 10000, 1.0f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "s", 10000, 1.0f, 1.0f, 1e-12);
+
+    test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "s", 10000, 1.0000f, 1.0f, 9997.9f/9999.0f);
+    test_sampler_queue(10000, "s", 10000, 1.0000f, 1.0f, 0.1f);
+
+    test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "km", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f, 1.0f);
+    test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "ks", 100, 0.8f, 1.0f, 0.1f);
+    test_sampler_queue(10000, "sk", 100, 0.8f, 1.0f, 0.1f);
+    test_sampler_queue(10000, "sp", 100, 0.8f, 1.0f, 9997.9f/9999.0f);
+    test_sampler_queue(10000, "sp", 100, 0.8f, 1.0f, 0.1f);
+
+    test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f, 1.0f);
+    test_sampler_queue(10000, "ksp", 100, 0.8f, 1.0f, 0.1f);
+    test_sampler_queue(10000, "skp", 100, 0.8f, 1.0f, 0.1f);
+    test_sampler_queue(10000, "spk", 100, 0.8f, 1.0f, 0.1f);
 
     printf("OK\n");
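For readers skimming the patch, the cutoff rule implemented by `llama_sample_p_step` above is: after softmax (candidates sorted by descending probability), find the first position where a token's probability falls below `step` times the probability of the token right before it; everything from that position on is discarded, while at least `min_keep` tokens are always retained. Below is a minimal, self-contained sketch of that rule on plain probabilities; the helper name `p_step_cut` is hypothetical and not part of the patch.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical illustration of the p-step cutoff used by llama_sample_p_step above.
// probs must already be softmaxed and sorted in descending order.
static std::vector<float> p_step_cut(const std::vector<float> & probs, float step, size_t min_keep) {
    if (step <= 0.0f || probs.size() <= 1) {
        return probs; // 0.0 = disabled, mirroring the patch
    }
    bool   step_found = false;
    size_t keep       = probs.size();
    for (size_t i = 1; i < probs.size(); ++i) {
        if (!step_found && probs[i] < step * probs[i - 1]) {
            step_found = true; // first "step" down in probability
        }
        if (step_found && i >= min_keep) {
            keep = i; // cut here; everything from the step onward is dropped
            break;
        }
    }
    return std::vector<float>(probs.begin(), probs.begin() + keep);
}

int main() {
    // With probs {0.4, 0.3, 0.2, 0.1} and step = 0.7:
    //   0.3 >= 0.7 * 0.4  -> keep going
    //   0.2 <  0.7 * 0.3  -> step found, cut; {0.4, 0.3} survive.
    // After renormalization that is {0.4/0.7, 0.3/0.7}, matching the
    // test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.7f) case above.
    for (float p : p_step_cut({0.4f, 0.3f, 0.2f, 0.1f}, 0.7f, 1)) {
        printf("%.2f\n", p);
    }
    return 0;
}
```

On the command-line side, the patch exposes the parameter as `--p-step N` (0.0 = disabled) and as the character `'s'` in the sampler sequence, so the step check runs at whatever position it is given in the sampler queue.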