[Fix][Spec] Add the draft model PopN in chain speculative decoding (#2939)

This PR fixes a bug where the PopN call that erases unaccepted tokens
from the draft model was missing when the draft is in chain shape and
no tree draft is enabled.
MasterJH5574 authored Sep 25, 2024
1 parent cb333ea commit 9336b4a
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions cpp/serve/engine_actions/batch_verify.cc
@@ -179,6 +179,17 @@ class BatchVerifyActionObj : public EngineActionObj {
     rsentries[i]->rstate->metrics.completion_tokens += accept_length;
     estate->metrics.spec_decode.Update(cum_verify_lengths[i + 1] - cum_verify_lengths[i],
                                        accept_length);
+    if (engine_config_->spec_tree_width == 1) {
+      // The roll back is needed for the chain draft case.
+      int rollback_length =
+          std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0);
+      if (rollback_length > 0) {
+        // The last accepted token is not yet added into the draft model.
+        // Therefore, the rollback length for the draft model is one less.
+        models_[draft_model_id_]->PopNFromKVCache(
+            rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1);
+      }
+    }
     // Commit accepted tokens to the "verify_model", rollback kv cache
     // in the "draft_model".
     // NOTE: when number of small models is more than 1 (in the future),
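
The rollback arithmetic added in this hunk can be illustrated in isolation. The sketch below is a minimal stand-alone example, not code from the patched file: `DraftRollbackLength` and `ToyKVCache` are hypothetical names introduced here, and the toy cache only assumes that the draft model's cache holds one entry per token and supports popping entries from the end.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper (not part of the patched codebase): mirrors the
// arithmetic of the added lines. verify_length plays the role of
// cum_verify_lengths[i + 1] - cum_verify_lengths[i].
int DraftRollbackLength(int verify_length, int accept_length) {
  // Number of verified positions that were not accepted, clamped at zero.
  int rollback_length = std::max(verify_length - accept_length, 0);
  // Per the patch comment, the last accepted token has not yet been added
  // to the draft model, so the draft cache pops one entry fewer.
  return rollback_length > 0 ? rollback_length - 1 : 0;
}

// Toy stand-in for a per-sequence KV cache (assumption: one entry per token).
struct ToyKVCache {
  std::vector<int> tokens;

  // Drop the last n entries, analogous in spirit to a PopN operation.
  void PopN(int n) {
    assert(n >= 0 && n <= static_cast<int>(tokens.size()));
    tokens.resize(tokens.size() - n);
  }
};
```

For instance, with a verify length of 5 and an accept length of 2, `rollback_length` is 3 and the toy draft cache would pop 2 entries; when every verified token is accepted, nothing is popped.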