Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Samplers order parameters #4285

Merged
merged 13 commits into from
Dec 5, 2023
156 changes: 72 additions & 84 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "sampling.h"

#include <functional>

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();

Expand Down Expand Up @@ -103,96 +101,89 @@ std::string llama_sampling_print(const llama_sampling_params & params) {

std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";

std::unordered_map<char, std::string> samplers_map_display {
{'k', "-> top_k "},
{'f', "-> tfs_z "},
{'y', "-> typical_p "},
{'p', "-> top_p "},
{'m', "-> min_p "},
{'t', "-> temp "}
};
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved

if (params.mirostat == 0){
for (auto s : params.samplers_sequence){
result += samplers_map_display[s];
switch (s){
case 'k':{
result += "-> top_k ";
break;
}
case 'f':{
result += "-> tfs_z ";
break;
}
case 'y':{
result += "-> typical_p ";
break;
}
case 'p':{
result += "-> top_p ";
break;
}
case 'm':{
result += "-> min_p ";
break;
}
case 't':{
result += "-> temp ";
break;
}
default: break;
}
}
} else result += "-> mirostat ";

return result;
}

// top-k stage of the sampler chain: forwards to llama_sample_top_k.
// A non-positive params.top_k is replaced by the vocabulary size of the
// model behind ctx_main before the call.
void sample_top_k(
        const llama_sampling_params & params,
        struct llama_context * ctx_main,
        llama_token_data_array & cur_p,
        size_t & min_keep) {
    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

    int32_t k = params.top_k;
    if (k <= 0) {
        k = n_vocab;
    }

    llama_sample_top_k(ctx_main, &cur_p, k, min_keep);
}

// top-p stage of the sampler chain: forwards params.top_p to llama_sample_top_p.
void sample_top_p(
        const llama_sampling_params & params,
        struct llama_context * ctx_main,
        llama_token_data_array & cur_p,
        size_t & min_keep) {
    llama_sample_top_p(ctx_main, &cur_p, params.top_p, min_keep);
}

// tail-free stage of the sampler chain: forwards params.tfs_z to llama_sample_tail_free.
void sample_tfs_z(
        const llama_sampling_params & params,
        struct llama_context * ctx_main,
        llama_token_data_array & cur_p,
        size_t & min_keep) {
    llama_sample_tail_free(ctx_main, &cur_p, params.tfs_z, min_keep);
}
// no reasons to expose this function in header
void sampler_queue(
struct llama_context * ctx_main,
const llama_sampling_params & params,
llama_token_data_array & cur_p,
size_t & min_keep) {
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

void sample_typical_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep){
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const std::string samplers_sequence = params.samplers_sequence;

for (auto s : samplers_sequence){
switch (s){
case 'k':{
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
break;
}
case 'f':{
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
break;
}
case 'y':{
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
break;
}
case 'p':{
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
break;
}
case 'm':{
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
break;
}
case 't':{
llama_sample_temp (ctx_main, &cur_p, temp);
break;
}
default: break;
}
}

const float typical_p = params.typical_p;
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
}

// min-p stage of the sampler chain: forwards params.min_p to llama_sample_min_p.
void sample_min_p(
        const llama_sampling_params & params,
        struct llama_context * ctx_main,
        llama_token_data_array & cur_p,
        size_t & min_keep) {
    llama_sample_min_p(ctx_main, &cur_p, params.min_p, min_keep);
}

// temperature stage of the sampler chain: forwards params.temp to llama_sample_temp.
// min_keep is not used here; the parameter exists so this function has the
// same signature as the other sample_* wrappers.
void sample_temp(
        const llama_sampling_params & params,
        struct llama_context * ctx_main,
        llama_token_data_array & cur_p,
        size_t & min_keep) {
    llama_sample_temp(ctx_main, &cur_p, params.temp);
}

std::unordered_map<char, std::function<void(const llama_sampling_params &, struct llama_context *, llama_token_data_array&, size_t&)>> samplers_map
{
{'k', sample_top_k},
{'f', sample_tfs_z},
{'y', sample_typical_p},
{'p', sample_top_p},
{'m', sample_min_p},
{'t', sample_temp}
};

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
Expand All @@ -211,7 +202,6 @@ llama_token llama_sampling_sample(
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
const std::string samplers_sequence = params.samplers_sequence;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
Expand Down Expand Up @@ -278,9 +268,7 @@ llama_token llama_sampling_sample(
// temperature sampling
size_t min_keep = std::max(1, params.n_probs);

for (auto s : samplers_sequence){
samplers_map[s](params, ctx_main, cur_p, min_keep);
}
sampler_queue(ctx_main, params, cur_p, min_keep);

id = llama_sample_token(ctx_main, &cur_p);

Expand Down
36 changes: 0 additions & 36 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,42 +84,6 @@ std::string llama_sampling_print(const llama_sampling_params & params);
// Print sampling order into a string
std::string llama_sampling_order_print(const llama_sampling_params & params);

// Per-sampler wrappers. They all share one signature
// (params, ctx_main, cur_p, min_keep); each reads its own setting from
// params and applies a single sampler to cur_p.

// top-k; a non-positive params.top_k is replaced by the vocabulary size
void sample_top_k(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

// top-p, driven by params.top_p
void sample_top_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

// tail-free, driven by params.tfs_z
void sample_tfs_z(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

// typical, driven by params.typical_p
void sample_typical_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

// min-p, driven by params.min_p
void sample_min_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

// temperature, driven by params.temp (min_keep is not used by this one)
void sample_temp(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
Expand Down