[WIP][Spec Decode] Add multi-proposer support for variable and flexible speculative decoding #7947

Open

Wants to merge 46 commits into main from the add_multi_proposers branch.

Changes shown from 1 commit (da2a189). Full commit list (46 commits):
9fe0a00
Initial commit to add multi proposers.
ShangmingCai Aug 1, 2024
a4577c6
Merge branch 'main' into add_multi_proposers
ShangmingCai Aug 1, 2024
cee57e0
Merge branch 'main' into add_multi_proposers
ShangmingCai Aug 2, 2024
4aeef1f
fix
ShangmingCai Aug 2, 2024
5efcb22
fix
ShangmingCai Aug 2, 2024
41542a2
fix
ShangmingCai Aug 2, 2024
22c951a
fix format
ShangmingCai Aug 2, 2024
ce42293
Merge branch 'main' into add_multi_proposers
ShangmingCai Aug 2, 2024
c63a795
fix
ShangmingCai Aug 2, 2024
4cd1ce3
Merge branch 'main' into add_multi_proposers
ShangmingCai Aug 5, 2024
da2a189
fix
ShangmingCai Aug 6, 2024
6370ebb
wip
ShangmingCai Aug 7, 2024
d1cba3d
Merge main.
ShangmingCai Aug 7, 2024
b6dda43
fix conflict
ShangmingCai Aug 9, 2024
d535bc4
Merge branch 'main' into add_multi_proposers
ShangmingCai Aug 9, 2024
8c8826f
Merge branch 'main' into add_multi_proposers
ShangmingCai Aug 12, 2024
51148e6
Merge branch 'main' into add_multi_proposers
ShangmingCai Aug 14, 2024
1307a50
Fix conflict
ShangmingCai Aug 20, 2024
32dd50c
Add 'disable' option
ShangmingCai Aug 20, 2024
f63679f
fix yapf
ShangmingCai Aug 20, 2024
1a9a01a
Change disable logic
ShangmingCai Aug 21, 2024
f4cbd4c
wip
ShangmingCai Aug 26, 2024
65fa9ae
fix conflict
ShangmingCai Aug 26, 2024
d954b40
revert disable option
ShangmingCai Aug 27, 2024
00ff46a
fix
ShangmingCai Aug 27, 2024
6dd887d
fix conflict and merge
ShangmingCai Aug 27, 2024
5f2c75a
fix comment
ShangmingCai Aug 27, 2024
664710d
remove test temporarily
ShangmingCai Aug 27, 2024
e63cc08
Merge branch 'main' into add_multi_proposers
ShangmingCai Aug 28, 2024
f519d78
fix ruff
ShangmingCai Aug 28, 2024
def110a
fix openai compatibility
ShangmingCai Aug 28, 2024
0ded849
fix ruff
ShangmingCai Aug 28, 2024
07ab467
fix ruff
ShangmingCai Aug 28, 2024
e157e8d
fix
ShangmingCai Aug 28, 2024
082fcc4
fix
ShangmingCai Aug 28, 2024
67518c2
fix
ShangmingCai Aug 28, 2024
9e3c58a
fix mypy
ShangmingCai Aug 28, 2024
2a10dca
Fix merge conflict
ShangmingCai Sep 5, 2024
a87bb1e
Merge branch 'main' into add_multi_proposers
ShangmingCai Sep 12, 2024
a13ebd5
Rebase
ShangmingCai Dec 18, 2024
1a7f27c
fix missed conflict.
ShangmingCai Dec 18, 2024
2cf80ad
Fix merge conflict.
ShangmingCai Dec 18, 2024
5c115c6
fix again.
ShangmingCai Dec 18, 2024
af23e7c
adapt to latest RPCProcessRequest.
ShangmingCai Dec 18, 2024
c2d4711
Fix import of SamplerOutput.
ShangmingCai Dec 18, 2024
92f3bf1
Fix worker_list.
ShangmingCai Dec 18, 2024
Viewing commit da2a189b6163f7880add0ae6337d874d92c38617 ("fix"), committed by ShangmingCai on Aug 6, 2024.
1 change: 1 addition & 0 deletions vllm/__init__.py
@@ -24,6 +24,7 @@
     "TextPrompt",
     "TokensPrompt",
     "SamplingParams",
+    "SpecDecodeParams",
     "RequestOutput",
     "CompletionOutput",
     "EmbeddingOutput",
113 changes: 19 additions & 94 deletions vllm/spec_decode/multi_proposers_worker.py
@@ -1,16 +1,13 @@
from typing import Dict, List, Optional, Set, Tuple
from concurrent.futures import ThreadPoolExecutor

import torch

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase

@@ -76,7 +73,8 @@ def get_spec_proposals(
dragged down by the slowest proposer for each step when there remain
more steps for them to complete. Therefore, a better strategy is to
use the fastest proposer adaptively among all specified proposers for
the current batch.
the current batch. This could be optimized when we have multiple
scorers.
"""
chosen_proposer = self._get_proposer_for_this_step(
execute_model_req,
@@ -90,18 +88,17 @@ def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch. This procedure
will respect the SpecDecodeParams of each sequence.
"""Perform speculative decoding on the input batch.
"""

# To perform KV operations, the 'non_driver_ranks' of SpecDecodeWorker
# might call this function with execute_model_req set to None for
# many times.
# might call this function with execute_model_req set to None many
# times.
if execute_model_req is None:
return []

# Curently, if one seq_group require to perform execute_model through
# MultiStepWorker, all seq_groups in the same batch need to perform
# Currently, if one seq_group requires to perform execute_model through
# MultiStepWorker, all seq_groups in the same batch have to perform
# execute_model together. We have not found a good way to avoid this.
proposer: str = '[ngram]'
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
@@ -160,7 +157,6 @@ def _get_proposer_for_this_step(

if schedule_policy == "popularity":
proposer_count: Dict[str, int] = {}
# Count the proposer choices
for seq in seq_group_metadata_list:
sd_params = seq.spec_decode_params
if sd_params is not None:
@@ -170,108 +166,37 @@ def _get_proposer_for_this_step(
if proposer not in proposer_count:
proposer_count[proposer] = 0
proposer_count[proposer] += 1
# Choose the proposer with the highest count
if len(proposer_count.keys()) != 0:
chosen_proposer = max(proposer_count, key=proposer_count.get)

elif schedule_policy == "proposal_latency":
for _, seq in enumerate(seq_group_metadata_list):
sd_params = seq.spec_decode_params
if sd_params:
# Since MultiProposersWorker only supports Ngram as the
# backup proposer currently, we should use Ngram for the
# whole batch if any seq_group specifies it.
# TODO: Refactor this when flexible backup speculative
# model choices and latency metrics are supported.
proposer = sd_params.get_proposer()
if proposer == '[ngram]':
chosen_proposer = proposer
break
if proposer not in valid_proposers:
continue
else:
chosen_proposer = proposer
# Since MultiProposersWorker only supports Ngram as the
# backup proposer currently, we should use Ngram for
# the whole batch if any seq_group specifies it.
# TODO: Refactor this when flexible backup speculative
# model choices and latency metrics are supported.
if chosen_proposer == '[ngram]':
break

elif schedule_policy == "proposal_quality":
# TODO: Use SpecDecodeWorkerMetrics to select the proposer with
# best draft_acceptance_rate.
# TODO: Use SpecDecodeWorkerMetrics to select the proposer with the
# best draft_acceptance_rate dynamically.
raise NotImplementedError(
f"schedule_policy: '{schedule_policy}' has not been "
f"implemented yet.")

else:
raise ValueError(f"Invalid schedule_policy: '{schedule_policy}'.")

return chosen_proposer
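
To make the "popularity" policy above concrete: it simply picks the proposer requested by the most seq_groups in the batch, falling back to '[ngram]' when nothing valid was requested. A standalone sketch of the same counting logic (a hypothetical helper with placeholder proposer names, not part of this PR):

from collections import Counter
from typing import List, Optional

def pick_by_popularity(requested: List[Optional[str]],
                       valid_proposers: List[str],
                       default: str = '[ngram]') -> str:
    # Count only proposers this MultiProposersWorker actually hosts;
    # unknown or missing choices fall through to the default.
    counts = Counter(p for p in requested if p in valid_proposers)
    return counts.most_common(1)[0][0] if counts else default

# A batch of four seq_groups: '[ngram]' wins 2-to-1 over '[draft_model]'.
print(pick_by_popularity(['[ngram]', '[draft_model]', '[ngram]', None],
                         valid_proposers=['[ngram]', '[draft_model]']))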

def _get_combined_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. This method
use multiple speculative proposers to generate speculations and return
the combined results.
"""

# Get speculative proposals from each proposer.
proposer_requests: Dict[str, List[SequenceGroupMetadata]] = {}
original_indices: Dict[str, List[int]] = {}
valid_proposers = list(self._workers.keys())

for idx, seq in enumerate(execute_model_req.seq_group_metadata_list):
sd_params = seq.spec_decode_params
if sd_params:
proposer = sd_params.get_proposer()
if proposer not in valid_proposers:
# Got unknown proposer. Use '[ngram]' as default instead.
proposer = '[ngram]'
if proposer not in proposer_requests:
proposer_requests[proposer] = []
original_indices[proposer] = []
proposer_requests[proposer].append(seq)
original_indices[proposer].append(idx)

all_proposals: Dict[str, SpeculativeProposals] = {}

# Although we use ThreadPoolExecutor to get_spec_proposals concurently,
# we still need to wait for the slowest proposer to finish on each
# batch for further scoring.
# TODO: Fix this when there are multiple scorer instances available for
# scoring.
with ThreadPoolExecutor() as executor:
futures = {
executor.submit(self._workers[proposer].get_spec_proposals,
execute_model_req.clone(sq_list),
seq_ids_with_bonus_token_in_last_step):
proposer
for proposer, sq_list in proposer_requests.items()
if len(sq_list) != 0}

for future in futures:
proposer = futures[future]
all_proposals[proposer] = future.result()

seq_group_metadata_length = len(
execute_model_req.seq_group_metadata_list)
merged_token_ids = [None] * seq_group_metadata_length
merged_probs = [None] * seq_group_metadata_length
merged_lens = [None] * seq_group_metadata_length

# Combine and restore the original order of the proposals
for proposer, indices in original_indices.items():
proposals = all_proposals[proposer]
if len(indices) != 0:
for i, idx in enumerate(indices):
merged_token_ids[idx] = proposals.proposal_token_ids[i]
merged_probs[idx] = proposals.proposal_probs[i]
merged_lens[idx] = proposals.proposal_lens[i]

combined_proposals = SpeculativeProposals(
proposal_token_ids=torch.stack(merged_token_ids),
proposal_probs=torch.stack(merged_probs),
proposal_lens=torch.stack(merged_lens)
)
return combined_proposals
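
The _get_combined_spec_proposals method shown above follows a scatter/gather pattern: split the batch by proposer, propose concurrently via ThreadPoolExecutor, then write each group's results back into the original batch order. A minimal standalone sketch of that pattern, with a hypothetical propose_fn standing in for self._workers[proposer].get_spec_proposals:

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List

def scatter_gather(items: List[str],
                   route: Callable[[str], str],
                   propose_fn: Callable[[str, List[str]], List[str]]
                   ) -> List[str]:
    # Scatter: group the original indices by the proposer handling them.
    groups: Dict[str, List[int]] = {}
    for idx, item in enumerate(items):
        groups.setdefault(route(item), []).append(idx)

    merged: List[str] = [''] * len(items)
    with ThreadPoolExecutor() as pool:
        futures = {
            pool.submit(propose_fn, name, [items[i] for i in idxs]): idxs
            for name, idxs in groups.items()
        }
        # Gather: each future's results go back to their original slots,
        # so the combined output preserves the input batch order. Note the
        # batch still waits on the slowest proposer, as the comment above
        # points out.
        for future, idxs in futures.items():
            for i, idx in enumerate(idxs):
                merged[idx] = future.result()[i]
    return merged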

def is_multi_step_worker_instance(self, obj: ProposerWorkerBase) -> bool:
if isinstance(obj, MultiStepWorker):