Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

P-Step Truncation Sampling #5675

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
sparams.min_p = std::stof(argv[i]);
} else if (arg == "--p-step") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.p_step = std::stof(argv[i]);
} else if (arg == "--temp") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -970,6 +976,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
printf(" --p-step N p-step sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.p_step);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
Expand Down Expand Up @@ -1140,6 +1147,7 @@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::
{"top_p", llama_sampler_type::TOP_P},
{"typical_p", llama_sampler_type::TYPICAL_P},
{"min_p", llama_sampler_type::MIN_P},
{"p_step", llama_sampler_type::P_STEP},
{"tfs_z", llama_sampler_type::TFS_Z},
{"temperature", llama_sampler_type::TEMPERATURE}
};
Expand All @@ -1153,6 +1161,7 @@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::
{"typical-p", llama_sampler_type::TYPICAL_P},
{"typical", llama_sampler_type::TYPICAL_P},
{"min-p", llama_sampler_type::MIN_P},
{"p-step", llama_sampler_type::P_STEP},
{"tfs-z", llama_sampler_type::TFS_Z},
{"tfs", llama_sampler_type::TFS_Z},
{"temp", llama_sampler_type::TEMPERATURE}
Expand Down Expand Up @@ -1188,6 +1197,7 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
{'p', llama_sampler_type::TOP_P},
{'y', llama_sampler_type::TYPICAL_P},
{'m', llama_sampler_type::MIN_P},
{'s', llama_sampler_type::P_STEP},
{'f', llama_sampler_type::TFS_Z},
{'t', llama_sampler_type::TEMPERATURE}
};
Expand All @@ -1210,6 +1220,7 @@ std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
case llama_sampler_type::TYPICAL_P: return "typical_p";
case llama_sampler_type::TOP_P: return "top_p";
case llama_sampler_type::MIN_P: return "min_p";
case llama_sampler_type::P_STEP: return "p_step";
case llama_sampler_type::TEMPERATURE: return "temperature";
default : return "";
}
Expand Down Expand Up @@ -1755,6 +1766,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "p_step: %f # default: 0.0\n", sparams.p_step);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
Expand Down
6 changes: 4 additions & 2 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {

snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, p_step = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.p_step, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau);

return std::string(result);
Expand Down Expand Up @@ -128,6 +128,7 @@ static void sampler_queue(
const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float p_step = params.p_step;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
Expand All @@ -139,6 +140,7 @@ static void sampler_queue(
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case llama_sampler_type::P_STEP : llama_sample_p_step (ctx_main, &cur_p, p_step, min_keep); break;
case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
Expand Down
3 changes: 3 additions & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ enum class llama_sampler_type : char {
TOP_K = 'k',
TOP_P = 'p',
MIN_P = 'm',
P_STEP = 's',
TFS_Z = 'f',
TYPICAL_P = 'y',
TEMPERATURE = 't'
Expand All @@ -26,6 +27,7 @@ typedef struct llama_sampling_params {
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float p_step = 0.00f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
Expand All @@ -46,6 +48,7 @@ typedef struct llama_sampling_params {
llama_sampler_type::TYPICAL_P,
llama_sampler_type::TOP_P,
llama_sampler_type::MIN_P,
llama_sampler_type::P_STEP,
llama_sampler_type::TEMPERATURE
};

Expand Down
28 changes: 28 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9593,6 +9593,34 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_p_step(struct llama_context * ctx, llama_token_data_array * candidates, float step, size_t min_keep) {
if (step <= 0.0f || candidates->size <= 1) {
return;
}

llama_sample_softmax(nullptr, candidates);

const int64_t t_start_sample_us = ggml_time_us();

bool step_found = false;

for (size_t i = 1; i < candidates->size; ++i) {
if (!step_found && candidates->data[i].p < step * candidates->data[i - 1].p) {
step_found = true;
}

if (step_found && i >= min_keep) {
// Resize the output vector to keep only the tokens before the step
candidates->size = i;
break;
}
}

if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}

void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) {
return;
Expand Down
7 changes: 7 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,13 @@ extern "C" {
float p,
size_t min_keep);

/// @details P-Step truncation sampling: sorts candidates by probability and cuts
/// the tail at the first position where a token's probability falls below `step`
/// times the probability of the preceding token, always keeping at least
/// `min_keep` tokens. A `step` <= 0.0f disables the sampler.
LLAMA_API void llama_sample_p_step(
        struct llama_context * ctx,
        llama_token_data_array * candidates,
        float step,
        size_t min_keep);

/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,
Expand Down
132 changes: 94 additions & 38 deletions tests/test-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
}
}

static void test_p_step(const std::vector<float> & probs, const std::vector<float> & expected_probs, float step) {
const size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
}

llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p);
llama_sample_p_step(nullptr, &candidates_p, step, 1);
DUMP(&candidates_p);
llama_sample_softmax(nullptr, &candidates_p);

GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) {
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
}
}

static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates;
Expand Down Expand Up @@ -149,7 +170,7 @@ static void test_repetition_penalties(
}

static void test_sampler_queue(
const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p, const float p_step
) {
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
Expand All @@ -164,14 +185,15 @@ static void test_sampler_queue(
const llama_token max_token_id = n_vocab-1;

for (auto s : samplers_sequence) {
switch (s){
case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
default : GGML_ASSERT(false && "Unknown sampler"); break;
switch (s) {
case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
case 's': llama_sample_p_step (nullptr, &candidates_p, p_step, 1); break;
case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
default : GGML_ASSERT(false && "Unknown sampler"); break;
}

llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
Expand Down Expand Up @@ -218,6 +240,18 @@ static void test_sampler_queue(
min_token_id = std::max(min_token_id, (llama_token)(n_vocab - size));
min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));

GGML_ASSERT(size == expected_size);
GGML_ASSERT(candidates_p.data[0].id == max_token_id);
GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
} else if (s == 's') {
min_token_id = n_vocab;
int expected_size = 0;

do { // do-while because always at least one token is sampled
min_token_id--;
expected_size++;
} while (candidates_p.data[expected_size].p >= p_step * candidates_p.data[expected_size - 1].p);

GGML_ASSERT(size == expected_size);
GGML_ASSERT(candidates_p.data[0].id == max_token_id);
GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
Expand All @@ -226,8 +260,8 @@ static void test_sampler_queue(
}
}

printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f p_step=%f\n",
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p, p_step);
}

int main(void) {
Expand All @@ -252,6 +286,17 @@ int main(void) {
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);

test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.0f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.5f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.6f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.7f);
test_p_step({0.2f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.7f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
// Disabled because of floating point nonsense: 0.3f < 0.75f * 0.4f is true!
//test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.75f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);

test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
Expand All @@ -267,33 +312,44 @@ int main(void) {
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);

test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);

test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);

test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "km", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f);
test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f);

test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12, 1.0f);
test_sampler_queue(10000, "s", 10000, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "s", 10000, 1.0f, 1.0f, 1e-12);

test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f, 1.0f);
test_sampler_queue(10000, "s", 10000, 1.0000f, 1.0f, 9997.9f/9999.0f);
test_sampler_queue(10000, "s", 10000, 1.0000f, 1.0f, 0.1f);

test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "km", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f, 1.0f);
test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "ks", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "sk", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "sp", 100, 0.8f, 1.0f, 9997.9f/9999.0f);
test_sampler_queue(10000, "sp", 100, 0.8f, 1.0f, 0.1f);

test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "ksp", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "skp", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "spk", 100, 0.8f, 1.0f, 0.1f);

printf("OK\n");

Expand Down
Loading