speculative : draft sampling improvements
ggerganov committed Aug 31, 2023
1 parent d49869d commit c4ec4eb
Showing 1 changed file with 25 additions and 18 deletions.
examples/speculative/speculative.cpp: 43 changes (25 additions, 18 deletions)
@@ -97,12 +97,14 @@ int main(int argc, char ** argv) {
         last_n_tokens.push_back(id);
     }
 
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
     bool has_eos = false;
 
     const auto t_gen_start = ggml_time_us();
 
     while (true) {
-        n_past_dft -= drafted.size();
         LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
 
         // sample from the drafted tokens if any
@@ -131,8 +133,7 @@ int main(int argc, char ** argv) {
                 logits[it->first] += it->second;
             }
 
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
+            candidates.clear();
             for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                 candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
             }
@@ -235,31 +236,37 @@ int main(int argc, char ** argv) {
         }
 
         // sample n_draft tokens from the draft model picking the best token
+        int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
             float * logits = llama_get_logits(ctx_dft);
 
-            int best_id = -1;
-            float best_logit  = -1e30f;
-            float best_logit2 = -1e30f;
-            for (int j = 0; j < n_vocab; ++j) {
-                if (logits[j] > best_logit) {
-                    best_logit2 = best_logit;
-                    best_logit  = logits[j];
-                    best_id = j;
-                }
+            candidates.clear();
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
             }
 
+            llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+            // top-2 sampling
+            llama_sample_softmax(ctx_dft, &cur_p);
+
+            for (int i = 0; i < 3; ++i) {
+                LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p);
+            }
+
-            // very low confidence in the best token
-            // TODO: better way to do this
-            if (best_logit - best_logit2 < 1.0f) {
+            // too low probability, stop drafting
+            if (cur_p.data[0].p < 2*cur_p.data[1].p) {
                 break;
             }
 
-            drafted.push_back(best_id);
+            drafted.push_back(cur_p.data[0].id);
             ++n_drafted;
 
-            llama_eval(ctx_dft, &drafted.back(), 1, n_past_dft, params.n_threads);
-            ++n_past_dft;
+            if (i < n_draft - 1) {
+                // evaluate the drafted token on the draft model
+                llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+                ++n_past_cur;
+            }
         }
 
         // evaluate the target model on the drafted tokens
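In short, the change hoists the candidates buffer out of the per-token loops, replaces the hand-rolled argmax over the draft model's logits with llama_sample_softmax over a llama_token_data_array, stops drafting as soon as the most likely draft token is less than twice as probable as the runner-up, and tracks the draft position with a local n_past_cur so n_past_dft no longer has to be rewound at the top of the outer loop.

Below is a minimal, self-contained sketch of that top-2 confidence gate, kept separate from llama.cpp: the softmax helper, the toy logits, and the explicit sort are illustrative assumptions standing in for the sorted candidates array that llama_sample_softmax produces in the diff.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// hypothetical helper: numerically stable softmax over raw logits
static std::vector<float> softmax(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float & p : probs) {
        p /= sum;
    }
    return probs;
}

int main() {
    // toy draft-model logits for a 5-token vocabulary (illustrative values only)
    const std::vector<float> logits = { 2.0f, 5.5f, 4.9f, 1.0f, 0.5f };

    const std::vector<float> probs = softmax(logits);

    // sort token ids by probability, descending - this stands in for the
    // sorted candidates array produced by llama_sample_softmax in the diff
    std::vector<int> order(probs.size());
    for (size_t i = 0; i < order.size(); ++i) {
        order[i] = (int) i;
    }
    std::sort(order.begin(), order.end(), [&](int a, int b) { return probs[a] > probs[b]; });

    const float p0 = probs[order[0]]; // best candidate
    const float p1 = probs[order[1]]; // runner-up

    // the gate from the diff: keep drafting only while p0 >= 2*p1
    if (p0 < 2*p1) {
        printf("stop drafting: best p = %.3f is not 2x the runner-up p = %.3f\n", p0, p1);
    } else {
        printf("draft token %d (p = %.3f)\n", order[0], p0);
    }

    return 0;
}

The 2x ratio mirrors the cur_p.data[0].p < 2*cur_p.data[1].p check above; the old best_logit - best_logit2 < 1.0f test was the same kind of margin check expressed on raw logits, where a gap of 1.0 corresponds to a top-two probability ratio of e (about 2.72).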
