From 98230ef656641597007d46854e005d16ff677377 Mon Sep 17 00:00:00 2001 From: Lengyue Date: Mon, 4 Sep 2023 04:51:30 -0400 Subject: [PATCH 1/5] Add heuristic algo for speculative --- examples/speculative/speculative.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index f0400c13fc211..019b77efa9c5d 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -84,7 +84,7 @@ int main(int argc, char ** argv) { //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft)); // how many tokens to draft each time - const int n_draft = params.n_draft; + int n_draft = params.n_draft; int n_predict = 0; int n_drafted = 0; @@ -116,6 +116,8 @@ int main(int argc, char ** argv) { // sample from the drafted tokens if any int i_dft = 0; + bool all_accepted = false; + while (true) { const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft); @@ -141,6 +143,9 @@ int main(int argc, char ** argv) { ++n_past_dft; ++i_dft; + if (i_dft == (int) drafted.size()) { + all_accepted = true; + } continue; } @@ -154,6 +159,14 @@ int main(int argc, char ** argv) { break; } + if (drafted.size() > 0 && all_accepted) { + n_draft += 2; + LOG("all drafted tokens accepted, n_draft = %d\n", n_draft); + } else { + n_draft -= 1; + LOG("drafted token rejected, n_draft = %d\n", n_draft); + } + if (n_predict > params.n_predict || has_eos) { break; } From 9248528d6ee56ea87918ad60e345b53d58110764 Mon Sep 17 00:00:00 2001 From: Lengyue Date: Mon, 4 Sep 2023 06:59:51 -0400 Subject: [PATCH 2/5] Constrain minimum n_draft to 2 --- examples/speculative/speculative.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 019b77efa9c5d..4610be59f2b44 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -163,7 +163,7 @@ int main(int argc, char ** argv) { n_draft += 2; LOG("all drafted tokens accepted, n_draft = %d\n", n_draft); } else { - n_draft -= 1; + n_draft = std::max(2, n_draft - 1); LOG("drafted token rejected, n_draft = %d\n", n_draft); } From dddd784c4de23af43d851f323f75c8c9bb874492 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Sep 2023 08:49:40 +0300 Subject: [PATCH 3/5] speculative : improve heuristic impl --- examples/speculative/speculative.cpp | 33 ++++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 4610be59f2b44..51562bcb17b65 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -116,7 +116,6 @@ int main(int argc, char ** argv) { // sample from the drafted tokens if any int i_dft = 0; - bool all_accepted = false; while (true) { const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft); @@ -143,9 +142,6 @@ int main(int argc, char ** argv) { ++n_past_dft; ++i_dft; - if (i_dft == (int) drafted.size()) { - all_accepted = true; - } continue; } @@ -153,20 +149,33 @@ int main(int argc, char ** argv) { llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); ++n_past_dft; + // heuristic for n_draft + { + const int n_dradt_cur = (int) drafted.size(); + const bool all_accepted = i_dft == n_dradt_cur; + + LOG("n_draft = %d\n", n_draft); + LOG("n_draft_cur = %d\n", n_dradt_cur); + LOG("i_dft = %d\n", i_dft); + LOG("all_accepted = %d\n", all_accepted); + + if (all_accepted && n_draft == n_dradt_cur) { + LOG(" - max drafted tokens accepted - n_draft += 2\n"); + n_draft += 2; + } else if (all_accepted) { + LOG(" - partially drafted tokens accepted - no change\n"); + } else { + LOG(" - drafted token rejected - n_draft -= 1\n"); + n_draft = std::max(2, n_draft - 1); + } + } + drafted.clear(); drafted.push_back(id); break; } - if (drafted.size() > 0 && all_accepted) { - n_draft += 2; - LOG("all drafted tokens accepted, n_draft = %d\n", n_draft); - } else { - n_draft = std::max(2, n_draft - 1); - LOG("drafted token rejected, n_draft = %d\n", n_draft); - } - if (n_predict > params.n_predict || has_eos) { break; } From d9559b78f3bea1a95b927eca685641c4c0082fc5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Sep 2023 08:54:08 +0300 Subject: [PATCH 4/5] speculative : be more rewarding upon guessing max drafted tokens --- examples/speculative/speculative.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index a081da936f79c..a0b836a133031 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -189,8 +189,8 @@ int main(int argc, char ** argv) { LOG("all_accepted = %d\n", all_accepted); if (all_accepted && n_draft == n_dradt_cur) { - LOG(" - max drafted tokens accepted - n_draft += 2\n"); - n_draft += 2; + LOG(" - max drafted tokens accepted - n_draft += 8\n"); + n_draft = std::min(30, n_draft + 8); } else if (all_accepted) { LOG(" - partially drafted tokens accepted - no change\n"); } else { From b5efa62504ee09d79bdbc7cb199870ffdc6e9942 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Sep 2023 08:55:49 +0300 Subject: [PATCH 5/5] speculative : fix typos --- examples/speculative/speculative.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index a0b836a133031..0fd828cc97f27 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -180,15 +180,15 @@ int main(int argc, char ** argv) { // heuristic for n_draft { - const int n_dradt_cur = (int) drafted.size(); - const bool all_accepted = i_dft == n_dradt_cur; + const int n_draft_cur = (int) drafted.size(); + const bool all_accepted = i_dft == n_draft_cur; LOG("n_draft = %d\n", n_draft); - LOG("n_draft_cur = %d\n", n_dradt_cur); + LOG("n_draft_cur = %d\n", n_draft_cur); LOG("i_dft = %d\n", i_dft); LOG("all_accepted = %d\n", all_accepted); - if (all_accepted && n_draft == n_dradt_cur) { + if (all_accepted && n_draft == n_draft_cur) { LOG(" - max drafted tokens accepted - n_draft += 8\n"); n_draft = std::min(30, n_draft + 8); } else if (all_accepted) {