diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index aed8be1d7f..e6f8cf0849 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -179,6 +179,17 @@ class BatchVerifyActionObj : public EngineActionObj { rsentries[i]->rstate->metrics.completion_tokens += accept_length; estate->metrics.spec_decode.Update(cum_verify_lengths[i + 1] - cum_verify_lengths[i], accept_length); + if (engine_config_->spec_tree_width == 1) { + // The roll back is needed for the chain draft case. + int rollback_length = + std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); + if (rollback_length > 0) { + // The last accepted token is not yet added into the draft model. + // Therefore, the rollback length for the draft model is one less. + models_[draft_model_id_]->PopNFromKVCache( + rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1); + } + } // Commit accepted tokens to the "verify_model", rollback kv cache // in the "draft_model". // NOTE: when number of small models is more than 1 (in the future),