Add heuristic algorithm for speculative #3006

Merged: 6 commits, Sep 14, 2023
24 changes: 23 additions & 1 deletion examples/speculative/speculative.cpp
@@ -85,7 +85,7 @@ int main(int argc, char ** argv) {
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));

// how many tokens to draft each time
-const int n_draft = params.n_draft;
+int n_draft = params.n_draft;

int n_predict = 0;
int n_drafted = 0;
@@ -134,6 +134,7 @@ int main(int argc, char ** argv) {
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));

int i_dft = 0;

while (true) {
// sample from the target model
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
@@ -177,6 +178,27 @@ int main(int argc, char ** argv) {
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
++n_past_dft;

+// heuristic for n_draft
+{
+    const int n_draft_cur = (int) drafted.size();
+    const bool all_accepted = i_dft == n_draft_cur;
+
+    LOG("n_draft = %d\n", n_draft);
+    LOG("n_draft_cur = %d\n", n_draft_cur);
+    LOG("i_dft = %d\n", i_dft);
+    LOG("all_accepted = %d\n", all_accepted);
+
+    if (all_accepted && n_draft == n_draft_cur) {
+        LOG(" - max drafted tokens accepted - n_draft += 8\n");
+        n_draft = std::min(30, n_draft + 8);
+    } else if (all_accepted) {
+        LOG(" - partially drafted tokens accepted - no change\n");
+    } else {
+        LOG(" - drafted token rejected - n_draft -= 1\n");
+        n_draft = std::max(2, n_draft - 1);
+    }
+}

Owner

@leng-yue

I've refactored the implementation to be more contained.
Also, we were rewarding the draft even when it had not sampled all n_draft tokens, which does not seem correct.

For example, let's say n_draft is 16 currently, but the draft samples just 3 tokens because the "low-probability" check has been triggered. Even if all 3 tokens were accepted, we should not reward the draft model, because this is just a small part of what we asked it to do.
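
The update rule described here can be sketched as a standalone function. This is only an illustration that mirrors the three branches of the heuristic in the diff (cap of 30, floor of 2, step of +8 / -1); the name `update_n_draft` and its parameters are hypothetical and do not appear in the PR.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper (not part of the PR): returns the next value of n_draft.
//   n_draft       - how many tokens the drafter was asked to produce
//   n_drafted_cur - how many tokens it actually produced (drafted.size() in the diff)
//   n_accepted    - how many of those the target model accepted (i_dft in the diff)
static int update_n_draft(int n_draft, int n_drafted_cur, int n_accepted) {
    assert(n_accepted <= n_drafted_cur);

    const bool all_accepted = n_accepted == n_drafted_cur;

    if (all_accepted && n_drafted_cur == n_draft) {
        // the full request was drafted and accepted: reward, capped at 30
        return std::min(30, n_draft + 8);
    }
    if (all_accepted) {
        // the drafter stopped early but nothing was rejected: no change
        return n_draft;
    }
    // at least one drafted token was rejected: penalize, floored at 2
    return std::max(2, n_draft - 1);
}
```

Under these assumptions, `update_n_draft(16, 3, 3)` leaves the value at 16, which is the case from the example: three accepted tokens out of a request for 16 earn no reward. A full, fully accepted draft, `update_n_draft(16, 16, 16)`, raises it to 24, and a rejection, `update_n_draft(16, 16, 5)`, lowers it to 15.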

Regarding the reproducibility - we will study this more in #3014
My guess is that this behavior would occur even without the heuristic - probably it's just less likely to happen for some reason when the heuristic is disabled.

Contributor Author

In my earlier implementation, rewards were given only when all tokens were accepted. So if only 3 out of 16 tokens are accepted, the n_tokens value would be decreased by 1.

Owner

Not exactly. This check does not guarantee you have n_draft tokens accepted:

if (i_dft == (int) drafted.size()) {
    all_accepted = true;
}

The reason is that drafted.size() <= n_draft, due to another heuristic that stops drafting once the drafter becomes "unsure":

// too low probability, stop drafting
if (cur_p.data[0].p < 2*cur_p.data[1].p) {
    break;
}

So in the majority of cases, when all_accepted == true, you would have accepted fewer than n_draft tokens. That's why your n_draft would increase so much, even beyond 70 in some cases.
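
To make the effect of the "unsure" check concrete, here is a minimal sketch that counts how many tokens a drafter would emit before stopping. It assumes the check runs before each token is appended to the draft; `count_drafted` and its input format (one pair of top-1 and top-2 probabilities per step) are hypothetical and only stand in for the real sampling loop.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper (not part of the PR): number of tokens drafted before the
// low-probability check fires or the n_draft limit is reached.
// Each pair is (probability of the best candidate, probability of the second best).
static int count_drafted(const std::vector<std::pair<float, float>> & top2, int n_draft) {
    assert(n_draft >= 0);

    int n = 0;
    for (std::size_t i = 0; i < top2.size() && n < n_draft; ++i) {
        // too low probability, stop drafting (same condition as the check quoted above)
        if (top2[i].first < 2*top2[i].second) {
            break;
        }
        ++n; // the best candidate would be appended to the draft here
    }
    return n;
}
```

With n_draft = 16 and the probability pairs (0.9, 0.05), (0.8, 0.1), (0.7, 0.2), (0.4, 0.3), the fourth step fails the check, so only 3 tokens are drafted. If all 3 are then accepted, i_dft == drafted.size() holds even though far fewer than n_draft tokens were produced.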

Contributor Author

Got it.

drafted.clear();
drafted.push_back(id);
