common : add llama_sample_token helper function
ggerganov committed Sep 3, 2023
1 parent 260b4a5 · commit fdc53e2
Showing 4 changed files with 189 additions and 208 deletions.
124 changes: 124 additions & 0 deletions common/common.cpp
@@ -830,6 +830,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens)
return result;
}

//
// Sampling utils
//

llama_token llama_sample_token(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_grammar * grammar,
const struct gpt_params & params,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
int idx) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(ctx);

const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;

llama_token id = 0;

float * logits = llama_get_logits(ctx) + idx * n_vocab;

// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}

candidates.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}

llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
}

// apply penalties
if (!last_tokens.empty()) {
const float nl_logit = logits[llama_token_nl(ctx)];
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);

llama_sample_repetition_penalty(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}
}

if (grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, grammar);
}

if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &cur_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k (ctx, &cur_p, top_k, 1);
llama_sample_tail_free (ctx, &cur_p, tfs_z, 1);
llama_sample_typical (ctx, &cur_p, typical_p, 1);
llama_sample_top_p (ctx, &cur_p, top_p, 1);
llama_sample_temperature(ctx, &cur_p, temp);

{
const int n_top = 10;
LOG("top %d candidates:\n", n_top);

for (int i = 0; i < n_top; i++) {
const llama_token id = cur_p.data[i].id;
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
}
}

id = llama_sample_token(ctx, &cur_p);

LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
}
}
// printf("`%d`", candidates_p.size);

if (grammar != NULL) {
llama_grammar_accept_token(ctx, grammar, id);
}

return id;
}

//
// YAML utils
//

// returns true if successful, false otherwise
bool create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
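
The idx argument is what makes this helper reusable beyond the single-stream case: it selects which row of the logits buffer to sample from, so a caller that evaluated a batch with logits stored for every position can sample each position independently. Below is a minimal sketch of that pattern, with hypothetical surrounding setup; it assumes the context was created with logits_all enabled so that llama_get_logits(ctx) holds one n_vocab-sized row per evaluated token:

// sketch only: sample from each position of the last evaluated batch
// (hypothetical setup; assumes logits_all, so the logits buffer holds
//  n_tokens rows of n_vocab floats)
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(ctx));

for (int i = 0; i < n_tokens; ++i) {
    // NULL guidance context and grammar are simply ignored by the helper
    const llama_token id = llama_sample_token(ctx, NULL, NULL, params, last_tokens, candidates, i);
    // ... consume id ...
}
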
34 changes: 34 additions & 0 deletions common/common.h
@@ -157,6 +157,40 @@ std::string llama_detokenize_bpe(
llama_context * ctx,
const std::vector<llama_token> & tokens);

//
// Sampling utils
//

// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
//
// required:
// - ctx: context to use for sampling
// - params: sampling parameters
//
// optional:
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
// - grammar: grammar to use for sampling, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits(ctx) + idx * n_vocab
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sample_token(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_grammar * grammar,
const struct gpt_params & params,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
int idx = 0);

//
// YAML utils
//

bool create_directory_with_parents(const std::string & path);
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
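
Taken together with the comment block above, the declaration implies the following call pattern. This is a minimal sketch of a single-stream generation loop, with hypothetical surrounding variables and the evaluation call elided:

// sketch only: intended usage of the helper in a generation loop
std::vector<llama_token> last_tokens(llama_n_ctx(ctx), 0);

std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(ctx)); // allocated once, reused every call (as in main.cpp)

while (n_remain > 0) {
    // ... evaluate the pending tokens (e.g. via llama_eval) here ...

    // NULL guidance context and grammar are ignored; idx defaults to 0
    const llama_token id = llama_sample_token(ctx, NULL, NULL, params, last_tokens, candidates);

    last_tokens.erase(last_tokens.begin());
    last_tokens.push_back(id);

    if (id == llama_token_eos(ctx)) {
        break;
    }
    --n_remain;
}

Note that the helper clears and refills candidates on every call, which is why the main.cpp changes below only reserve the vector once up front.
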
134 changes: 20 additions & 114 deletions examples/main/main.cpp
@@ -425,8 +425,9 @@ int main(int argc, char ** argv) {
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");

struct llama_grammar * grammar = NULL;
grammar_parser::parse_state parsed_grammar;
llama_grammar * grammar = NULL;

if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
@@ -450,8 +451,8 @@
}

// TODO: replace with ring-buffer
std::vector<llama_token> last_n_tokens(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
std::vector<llama_token> last_tokens(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);

if (params.interactive) {
const char *control_message;
@@ -500,6 +501,11 @@
llama_reset_timings(ctx);
}

const int n_vocab = llama_n_vocab(ctx);

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (embd.size() > 0) {
@@ -537,8 +543,8 @@

LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);

// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
// insert n_left/2 tokens at the start of embd from last_tokens
embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());

LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));

@@ -637,20 +643,6 @@ int main(int argc, char ** argv) {
embd_guidance.clear();

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;

// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
need_to_save_session = false;
@@ -659,98 +651,12 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}

llama_token id = 0;

{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);

// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}

llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
}

// Apply penalties
float nl_logit = logits[llama_token_nl(ctx)];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &cur_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}

if (grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, grammar);
}

if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &cur_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k (ctx, &cur_p, top_k, 1);
llama_sample_tail_free (ctx, &cur_p, tfs_z, 1);
llama_sample_typical (ctx, &cur_p, typical_p, 1);
llama_sample_top_p (ctx, &cur_p, top_p, 1);
llama_sample_temperature(ctx, &cur_p, temp);

{
const int n_top = 10;
LOG("top %d candidates:\n", n_top);

for (int i = 0; i < n_top; i++) {
const llama_token id = cur_p.data[i].id;
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
}
}

id = llama_sample_token(ctx, &cur_p);
const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);

LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
}
}
// printf("`%d`", candidates_p.size);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);

if (grammar != NULL) {
llama_grammar_accept_token(ctx, grammar, id);
}

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);

LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_n_tokens));
}
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));

embd.push_back(id);

@@ -766,8 +672,8 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[n_consumed]);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@@ -800,7 +706,7 @@ int main(int argc, char ** argv) {
// check for reverse prompt
if (params.antiprompt.size()) {
std::string last_output;
for (auto id : last_n_tokens) {
for (auto id : last_tokens) {
last_output += llama_token_to_piece(ctx, id);
}

@@ -831,7 +737,7 @@ int main(int argc, char ** argv) {
}

// deal with end of text token in interactive mode
if (last_n_tokens.back() == llama_token_eos(ctx)) {
if (last_tokens.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");

if (params.interactive) {
@@ -933,7 +839,7 @@ int main(int argc, char ** argv) {
if (grammar != NULL) {
llama_grammar_free(grammar);

std::vector<const llama_grammar_element *> grammar_rules( parsed_grammar.c_rules());
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(),
parsed_grammar.symbol_ids.at("root"));
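
One loose end the diff carries over is the "// TODO: replace with ring-buffer" note on last_tokens, which still pays an O(n_ctx) erase() at the front of the vector for every sampled token. Purely as an illustration of that TODO, and not part of this commit, a fixed-size ring could look like the sketch below; the penalty functions read a contiguous tail of recent tokens, so a real replacement would also need a linearization step like tail() here:

// illustrative only -- a ring buffer that could back last_tokens
struct token_ring {
    std::vector<llama_token> buf;
    size_t head = 0; // slot that will be overwritten next (holds the oldest entry)

    explicit token_ring(size_t n) : buf(n, 0) {}

    void push(llama_token id) {
        buf[head] = id;
        head = (head + 1) % buf.size(); // O(1), no erase() + push_back() shuffle
    }

    // copy the last n tokens, oldest first (assumes n <= buf.size());
    // needed because llama_sample_repetition_penalty expects a contiguous span
    std::vector<llama_token> tail(size_t n) const {
        std::vector<llama_token> out(n);
        for (size_t i = 0; i < n; ++i) {
            out[i] = buf[(head + buf.size() - n + i) % buf.size()];
        }
        return out;
    }
};
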
(The diff for the fourth changed file did not load and is not shown here.)
