Add heuristic algorithm for speculative #3006
```diff
@@ -84,7 +84,7 @@ int main(int argc, char ** argv) {
     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));

     // how many tokens to draft each time
-    const int n_draft = params.n_draft;
+    int n_draft = params.n_draft;

     int n_predict = 0;
     int n_drafted = 0;
```
```diff
@@ -116,6 +116,8 @@ int main(int argc, char ** argv) {

         // sample from the drafted tokens if any
         int i_dft = 0;
+        bool all_accepted = false;
+
         while (true) {
             const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
```
```diff
@@ -141,6 +143,9 @@ int main(int argc, char ** argv) {
                 ++n_past_dft;
                 ++i_dft;

+                if (i_dft == (int) drafted.size()) {
+                    all_accepted = true;
+                }
                 continue;
             }
```
```diff
@@ -154,6 +159,14 @@ int main(int argc, char ** argv) {
             break;
         }

+        if (drafted.size() > 0 && all_accepted) {
+            n_draft += 2;
+            LOG("all drafted tokens accepted, n_draft = %d\n", n_draft);
+        } else {
+            n_draft -= 1;
+            LOG("drafted token rejected, n_draft = %d\n", n_draft);
+        }
+
         if (n_predict > params.n_predict || has_eos) {
             break;
         }
```

Here is what I got after restricting the minimum:

Full log: speculative.139912996806656.log

As a comparison, this one doesn't include the heuristic algorithm:

Full log: speculative.140657253044224.log

The models I am using are: codellama-7b.Q4_K_M.gguf and codellama-34b.Q4_K_M.gguf.

Also, can you try instead of using

I am using the cuBLAS backend, and I got the same output after changing
Why does `n_draft` go up by 2 when all drafted tokens are accepted, but decrease by 1 when a drafted token is rejected? Is there a more efficient algorithm to handle this? The current approach seems similar to a simplified version of the TCP Friendly Rate Control algorithm.

That's pretty much borrowed from Hugging Face's code. We could fine-tune it by tweaking some parameters. Since getting all tokens right is challenging, it seems reasonable to bump `n_draft` up by 2 when everything aligns and decrease it by 1 otherwise.