From 98230ef656641597007d46854e005d16ff677377 Mon Sep 17 00:00:00 2001
From: Lengyue
Date: Mon, 4 Sep 2023 04:51:30 -0400
Subject: [PATCH] Add heuristic algo for speculative

---
 examples/speculative/speculative.cpp | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index f0400c13fc211..019b77efa9c5d 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -84,7 +84,7 @@ int main(int argc, char ** argv) {
     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
 
     // how many tokens to draft each time
-    const int n_draft = params.n_draft;
+    int n_draft = params.n_draft;
 
     int n_predict = 0;
     int n_drafted = 0;
@@ -116,6 +116,8 @@ int main(int argc, char ** argv) {
 
         // sample from the drafted tokens if any
         int i_dft = 0;
+        bool all_accepted = false;
+
         while (true) {
             const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
 
@@ -141,6 +143,9 @@ int main(int argc, char ** argv) {
 
                 ++n_past_dft;
                 ++i_dft;
 
+                if (i_dft == (int) drafted.size()) {
+                    all_accepted = true;
+                }
                 continue;
             }
@@ -154,6 +159,14 @@ int main(int argc, char ** argv) {
             break;
         }
 
+        if (drafted.size() > 0 && all_accepted) {
+            n_draft += 2;
+            LOG("all drafted tokens accepted, n_draft = %d\n", n_draft);
+        } else {
+            n_draft -= 1;
+            LOG("drafted token rejected, n_draft = %d\n", n_draft);
+        }
+
         if (n_predict > params.n_predict || has_eos) {
             break;
         }