[Fix][Spec] Fix incorrect top-p normalization in batch verification (#2940)

This PR fixes a critical bug in batch verification: the top-p values were not
properly applied when renormalizing the probabilities computed by the main
model. The bug caused a large drop in speculative decoding accuracy.
MasterJH5574 authored Sep 25, 2024
1 parent 9336b4a commit 0ed2179
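
For context on why per-token configs matter: batch verification produces one
probability row per verified token (each request contributes its last committed
token plus all of its draft tokens), so the top-p values used for
renormalization must be replicated once per token rather than supplied once per
request. Below is a minimal, self-contained sketch of that bookkeeping; it is
illustrative only, and the Request struct and all values in it are assumptions,
not engine code:

#include <cstdio>
#include <vector>

// Illustrative stand-in for a request and its speculative-decoding state.
struct Request {
  double top_p;          // per-request top-p from its generation config
  int num_draft_tokens;  // draft tokens proposed for this request
};

int main() {
  std::vector<Request> requests = {{0.9, 3}, {0.5, 2}};

  // Each request contributes (num_draft_tokens + 1) probability rows:
  // one for the last committed token plus one per draft token.
  int total_verify_length = 0;
  for (const Request& r : requests) total_verify_length += r.num_draft_tokens + 1;

  // Before the fix, only one top-p entry existed per request, so only the
  // first requests.size() of the total_verify_length probability rows were
  // renormalized. The fix replicates each request's top-p once per row:
  std::vector<double> top_p_per_row;
  top_p_per_row.reserve(total_verify_length);
  for (const Request& r : requests) {
    for (int j = 0; j < r.num_draft_tokens + 1; ++j) {
      top_p_per_row.push_back(r.top_p);
    }
  }

  std::printf("rows to renormalize: %d, top-p entries: %zu\n",
              total_verify_length, top_p_per_row.size());  // prints 7 and 7
  return 0;
}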
Showing 2 changed files with 7 additions and 3 deletions.
9 changes: 7 additions & 2 deletions cpp/serve/engine_actions/batch_verify.cc
@@ -64,6 +64,7 @@ class BatchVerifyActionObj : public EngineActionObj {
     Array<RequestModelState> verify_request_mstates;
     Array<RequestModelState> draft_request_mstates;
     Array<GenerationConfig> generation_cfg;
+    Array<GenerationConfig> generation_cfg_for_top_p_norm;
     std::vector<RandomGenerator*> rngs;
     std::vector<std::vector<SampleResult>> draft_output_tokens;
     std::vector<int64_t> token_tree_parent_ptr;
@@ -76,6 +77,7 @@ class BatchVerifyActionObj : public EngineActionObj {
     draft_request_mstates.reserve(num_rsentries);
     rngs.reserve(num_rsentries);
     generation_cfg.reserve(num_rsentries);
+    generation_cfg_for_top_p_norm.reserve(total_verify_length);
     draft_output_tokens.reserve(num_rsentries);
     draft_token_slots_.clear();

@@ -90,13 +92,15 @@ class BatchVerifyActionObj : public EngineActionObj {
       draft_token_slots_.push_back(0);  // placeholder for the last committed token
       all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().GetTokenId());
       token_tree_parent_ptr.push_back(-1);
+      generation_cfg_for_top_p_norm.push_back(rsentries[i]->request->generation_cfg);
       std::vector<int> cur_draft_token_indices;
       cur_draft_token_indices.resize(draft_mstate->draft_output_tokens.size() + 1);
       std::iota(cur_draft_token_indices.begin(), cur_draft_token_indices.end(), -1);
       for (int j = 0; j < static_cast<int>(draft_mstate->draft_output_tokens.size()); ++j) {
         all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].GetTokenId());
         draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]);
         token_tree_parent_ptr.push_back(draft_mstate->draft_token_parent_idx[j] + 1);
+        generation_cfg_for_top_p_norm.push_back(rsentries[i]->request->generation_cfg);
       }
       draft_token_indices.emplace_back(std::move(cur_draft_token_indices));
       verify_request_mstates.push_back(verify_mstate);
@@ -141,10 +145,11 @@ class BatchVerifyActionObj : public EngineActionObj {
     // Note: we commit prefix cache changes here to overlap this commit with the GPU execution.
     estate->prefix_cache->CommitSequenceExtention();

-    std::vector<int> sample_indices(num_rsentries);
+    // Fill range [0, total_verify_length) into `sample_indices`.
+    std::vector<int> sample_indices(total_verify_length);
     std::iota(sample_indices.begin(), sample_indices.end(), 0);
     NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
-        probs_on_device, sample_indices, request_ids, generation_cfg);
+        probs_on_device, sample_indices, request_ids, generation_cfg_for_top_p_norm);
     auto [sample_results_arr, last_accepted_tree_node_verify_model] =
         sampler_->BatchVerifyDraftTokensWithProbAfterTopP(
             renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,
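
A note on the sample_indices change above: previously std::iota filled only the
range [0, num_rsentries), so top-p renormalization touched one probability row
per request; with the per-token config array the indices must cover every row,
i.e. [0, total_verify_length). A small standalone sketch of the two iota
patterns seen in the diff, with made-up sizes:

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Suppose two requests proposed 3 and 2 draft tokens; with one extra row
  // per request for the last committed token, there are 7 rows to verify.
  int total_verify_length = 7;

  // One sample index per probability row, matching the per-token
  // generation_cfg_for_top_p_norm array built in the loop above.
  std::vector<int> sample_indices(total_verify_length);
  std::iota(sample_indices.begin(), sample_indices.end(), 0);

  // Per-request draft token indices start at -1, labeling the last committed
  // token as -1 and the draft tokens as 0..n-1, mirroring the iota(..., -1)
  // call in the diff.
  std::vector<int> cur_draft_token_indices(3 + 1);
  std::iota(cur_draft_token_indices.begin(), cur_draft_token_indices.end(), -1);

  for (int v : sample_indices) std::printf("%d ", v);  // 0 1 2 3 4 5 6
  std::printf("\n");
  for (int v : cur_draft_token_indices) std::printf("%d ", v);  // -1 0 1 2
  std::printf("\n");
  return 0;
}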
1 change: 0 additions & 1 deletion cpp/serve/sampler/gpu_sampler.cc
@@ -133,7 +133,6 @@ class GPUSampler : public SamplerObj {
     int num_probs = probs_on_device->shape[0];
     int vocab_size = probs_on_device->shape[1];
     ICHECK_LE(num_probs, max_num_sample_);
-    ICHECK_EQ(request_ids.size(), num_samples);
     ICHECK_EQ(generation_cfg.size(), num_samples);

     // - Check if there is need for applying top p.
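
Regarding the gpu_sampler.cc change: after this fix, num_samples in
BatchRenormalizeProbsByTopP equals the number of probability rows
(total_verify_length), while request_ids presumably still carries one entry per
request, so the removed equality check on request_ids.size() would no longer
hold; the surviving check that generation_cfg.size() matches num_samples is
satisfied by the new per-token generation_cfg_for_top_p_norm array.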
