diff --git a/vllm/__init__.py b/vllm/__init__.py
index 45252b93e3d54..2ff721da331fd 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -14,6 +14,7 @@
                           ScoringRequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
+from vllm.spec_decode.spec_decode_params import SpecDecodeParams
 
 from .version import __version__, __version_tuple__
 
@@ -26,6 +27,7 @@
     "TextPrompt",
     "TokensPrompt",
     "SamplingParams",
+    "SpecDecodeParams",
     "RequestOutput",
     "CompletionOutput",
     "PoolingOutput",
diff --git a/vllm/config.py b/vllm/config.py
index 307cf9c8d5b2a..852d8d5cf4eff 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1589,11 +1589,11 @@ def maybe_create_spec_config(
             # TODO: current we still need extract vocab_size from target model
             # config, in future, we may try refactor it out, and set
             # draft related config as None here.
-            draft_model_config = target_model_config
+            from copy import deepcopy
+            draft_model_config = deepcopy(target_model_config)
+            draft_model_config.model = "[ngram]"
             draft_parallel_config = target_parallel_config
         else:
-            ngram_prompt_lookup_max = 0
-            ngram_prompt_lookup_min = 0
             draft_model_config = ModelConfig(
                 model=speculative_model,
                 task="draft",
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index c3bc6becf0995..6f2587baca329 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1372,6 +1372,7 @@ def schedule(
                 pooling_params=seq_group.pooling_params,
                 token_chunk_size=token_chunk_size,
                 lora_request=seq_group.lora_request,
+                spec_decode_params=seq_group.spec_decode_params,
                 computed_block_nums=common_computed_block_nums,
                 encoder_seq_data=encoder_seq_data,
                 cross_block_table=cross_block_table,
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index f50e20cf70323..111dd852a5881 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -33,6 +33,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
+from vllm.spec_decode.spec_decode_params import SpecDecodeParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import deprecate_kwargs, weak_bind
@@ -435,6 +436,7 @@ async def add_request_async(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -449,6 +451,7 @@ async def add_request_async(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -466,6 +469,7 @@ async def add_request_async(
         params: Optional[Union[SamplingParams, PoolingParams]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -517,6 +521,7 @@ async def add_request_async(
             params=params,
             arrival_time=arrival_time,
             lora_request=lora_request,
+            spec_decode_params=spec_decode_params,
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
             priority=priority,
@@ -917,6 +922,7 @@ def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -932,6 +938,7 @@ def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -950,6 +957,7 @@ async def add_request(
         params: Optional[Union[SamplingParams, PoolingParams]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -982,6 +990,7 @@ async def add_request(
             params=params,
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
+            spec_decode_params=spec_decode_params,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
@@ -995,6 +1004,7 @@ async def generate(
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -1011,6 +1021,7 @@ async def generate(
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
+            spec_decode_params: The speculative decoding parameters, if any.
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
@@ -1071,6 +1082,7 @@ async def generate(
             prompt,
             sampling_params,
             lora_request=lora_request,
+            spec_decode_params=spec_decode_params,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
@@ -1086,6 +1098,7 @@ async def encode(
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
@@ -1101,6 +1114,7 @@ async def encode(
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
+            spec_decode_params: The speculative decoding parameters, if any.
             trace_headers: OpenTelemetry trace headers.
             priority: The priority of the request.
                 Only applicable with priority scheduling.
@@ -1157,6 +1171,7 @@ async def encode(
             prompt,
             pooling_params,
             lora_request=lora_request,
+            spec_decode_params=spec_decode_params,
             trace_headers=trace_headers,
             priority=priority,
         ):
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index dc2d77d6927cd..8014f36871ba4 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -50,6 +50,7 @@
                            PoolingSequenceGroupOutput, Sequence, SequenceGroup,
                            SequenceGroupBase, SequenceGroupMetadata,
                            SequenceGroupOutput, SequenceStatus)
+from vllm.spec_decode.spec_decode_params import SpecDecodeParams
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -597,6 +598,7 @@ def _add_processed_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: float,
         lora_request: Optional[LoRARequest],
+        spec_decode_params: Optional[SpecDecodeParams],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
@@ -646,6 +648,7 @@ def _add_processed_request(
                 params,
                 arrival_time=arrival_time,
                 lora_request=lora_request,
+                spec_decode_params=spec_decode_params,
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
                 encoder_seq=encoder_seq,
@@ -657,6 +660,7 @@ def _add_processed_request(
                 params,
                 arrival_time=arrival_time,
                 lora_request=lora_request,
+                spec_decode_params=spec_decode_params,
                 prompt_adapter_request=prompt_adapter_request,
                 encoder_seq=encoder_seq,
                 priority=priority)
@@ -685,6 +689,7 @@ def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -701,6 +706,7 @@ def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -718,6 +724,7 @@ def add_request(
         params: Optional[Union[SamplingParams, PoolingParams]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -739,6 +746,7 @@ def add_request(
                 :class:`~vllm.PoolingParams` for pooling.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
+            spec_decode_params: The speculative decoding parameters, if any.
             trace_headers: OpenTelemetry trace headers.
             priority: The priority of the request.
                 Only applicable with priority scheduling.
@@ -808,6 +816,7 @@ def add_request(
             params=params,
             arrival_time=arrival_time,
             lora_request=lora_request,
+            spec_decode_params=spec_decode_params,
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
             priority=priority,
@@ -841,6 +850,7 @@ def _create_sequence_group_with_sampling(
         sampling_params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest],
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         encoder_seq: Optional[Sequence] = None,
@@ -872,6 +882,7 @@ def _create_sequence_group_with_sampling(
             arrival_time=arrival_time,
             sampling_params=sampling_params,
             lora_request=lora_request,
+            spec_decode_params=spec_decode_params,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             encoder_seq=encoder_seq,
@@ -886,6 +897,7 @@ def _create_sequence_group_with_pooling(
         pooling_params: PoolingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest],
+        spec_decode_params: Optional[SpecDecodeParams],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         encoder_seq: Optional[Sequence] = None,
         priority: int = 0,
@@ -899,6 +911,7 @@ def _create_sequence_group_with_pooling(
             seqs=[seq],
             arrival_time=arrival_time,
             lora_request=lora_request,
+            spec_decode_params=spec_decode_params,
             pooling_params=pooling_params,
             prompt_adapter_request=prompt_adapter_request,
             encoder_seq=encoder_seq,
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 420f540d0b5f4..219a6dcfb725c 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -10,6 +10,7 @@
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
+from vllm.spec_decode.spec_decode_params import SpecDecodeParams
 from vllm.utils import deprecate_kwargs
 
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
@@ -30,6 +31,7 @@ class RPCProcessRequest:
     params: Union[SamplingParams, PoolingParams]
     request_id: str
     lora_request: Optional[LoRARequest] = None
+    spec_decode_params: Optional[SpecDecodeParams] = None
     trace_headers: Optional[Mapping[str, str]] = None
     prompt_adapter_request: Optional[PromptAdapterRequest] = None
     priority: int = 0
@@ -41,6 +43,7 @@ def __init__(
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -56,6 +59,7 @@ def __init__(
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -72,6 +76,7 @@ def __init__(
         params: Optional[Union[SamplingParams, PoolingParams]] = None,
         request_id: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -89,6 +94,7 @@ def __init__(
         self.params = params
         self.request_id = request_id
         self.lora_request = lora_request
+        self.spec_decode_params = spec_decode_params
         self.trace_headers = trace_headers
         self.prompt_adapter_request = prompt_adapter_request
         self.priority = priority
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 0a046c71e86e8..f08c8b868d47b 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -39,6 +39,7 @@
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
+from vllm.spec_decode.spec_decode_params import SpecDecodeParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import deprecate_kwargs
 
@@ -421,6 +422,7 @@ def generate(
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -436,6 +438,7 @@ def generate(
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -452,6 +455,7 @@ def generate(
         sampling_params: Optional[SamplingParams] = None,
         request_id: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -470,6 +474,7 @@ def generate(
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
+            spec_decode_params: The speculative decoding parameters, if any.
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
@@ -483,8 +488,9 @@ def generate(
                and request_id is not None)
 
         return self._process_request(prompt, sampling_params, request_id,
-                                     lora_request, trace_headers,
-                                     prompt_adapter_request, priority)
+                                     lora_request, spec_decode_params,
+                                     trace_headers, prompt_adapter_request,
+                                     priority)
 
     @overload
     def encode(
@@ -493,6 +499,7 @@ def encode(
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
@@ -507,6 +514,7 @@ def encode(
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
@@ -522,6 +530,7 @@ def encode(
         pooling_params: Optional[PoolingParams] = None,
         request_id: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
         *,
@@ -539,6 +548,7 @@ def encode(
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
+            spec_decode_params: The speculative decoding parameters, if any.
             trace_headers: OpenTelemetry trace headers.
 
         Yields:
@@ -556,6 +566,7 @@ def encode(
                                   pooling_params,
                                   request_id,
                                   lora_request,
+                                  spec_decode_params,
                                   trace_headers,
                                   priority=priority))
 
@@ -565,6 +576,7 @@ async def _process_request(
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -618,6 +630,7 @@ async def _process_request(
             params=params,
             request_id=request_id,
             lora_request=lora_request,
+            spec_decode_params=spec_decode_params,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index 49a90b321dac4..a052f15261965 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -259,6 +259,7 @@ def _handle_process_request(self, request: RPCProcessRequest):
             prompt=request.prompt,
             params=request.params,
             lora_request=request.lora_request,
+            spec_decode_params=request.spec_decode_params,
             trace_headers=request.trace_headers,
             prompt_adapter_request=request.prompt_adapter_request,
             priority=request.priority)
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index a066836b92708..0a0f61c70474a 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -15,6 +15,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.spec_decode.spec_decode_params import SpecDecodeParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import collect_from_async_generator, random_uuid
 
@@ -51,6 +52,7 @@ def generate(
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -206,6 +208,7 @@ def encode(
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
diff --git a/vllm/sequence.py b/vllm/sequence.py
index cc3d96fc93a79..349ac86efb5a4 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -19,6 +19,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.spec_decode.spec_decode_params import SpecDecodeParams
 
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
 
@@ -407,6 +408,7 @@ def __init__(
         block_size: int,
         eos_token_id: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.seq_id = seq_id
@@ -414,6 +416,7 @@ def __init__(
         self.block_size = block_size
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
+        self.spec_decode_params = spec_decode_params
         self.prompt_adapter_request = prompt_adapter_request
 
         self.data = SequenceData.from_seqs(self.prompt_token_ids)
@@ -647,6 +650,7 @@ def __init__(
         arrival_time: float,
         sampling_params: Optional[SamplingParams] = None,
         lora_request: Optional[LoRARequest] = None,
+        spec_decode_params: Optional[SpecDecodeParams] = None,
         pooling_params: Optional[PoolingParams] = None,
         pooled_data: Optional[torch.Tensor] = None,
         encoder_seq: Optional[Sequence] = None,
@@ -668,6 +672,7 @@ def __init__(
             first_token_time=None,
             time_in_queue=None)
         self.lora_request = lora_request
+        self.spec_decode_params = spec_decode_params
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
         self.pooling_params = pooling_params
@@ -899,6 +904,8 @@ class SequenceGroupMetadata(
         token_chunk_size: The number of tokens to be processed (per sequence).
             None if chunking is not required.
         lora_request: LoRA request.
+        spec_decode_params: The parameters used to guide proposal generation
+            for speculative decoding.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
         state: Internal state tied to this sequence group.
@@ -924,6 +931,7 @@ class SequenceGroupMetadata(
     do_sample: bool = True
     pooling_params: Optional[PoolingParams] = None
     lora_request: Optional[LoRARequest] = None
+    spec_decode_params: Optional[SpecDecodeParams] = None
     computed_block_nums: Optional[List[int]] = None
     state: Optional[SequenceGroupState] = msgspec.field(
         default_factory=lambda: SequenceGroupState())
diff --git a/vllm/spec_decode/multi_proposer_worker.py b/vllm/spec_decode/multi_proposer_worker.py
new file mode 100644
index 0000000000000..4db3cb8000b6a
--- /dev/null
+++ b/vllm/spec_decode/multi_proposer_worker.py
@@ -0,0 +1,283 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Set, Tuple
+
+import torch
+
+from vllm.config import ParallelConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
+from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
+from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+
+logger = init_logger(__name__)
+
+
+class MultiProposerWorker(ProposerWorkerBase, LoraNotSupportedWorkerBase):
+
+    def __init__(self, *args, **kwargs):
+        self.vocab_size = kwargs["model_config"].get_vocab_size()
+        self._workers = kwargs.pop('worker_list', {})
+        # Once we support more policies, we can make this configurable.
+        self.scheduling_policy = kwargs.pop('scheduling_policy',
+                                            'proposal_latency')
+
+        draft_parallel_config: ParallelConfig = kwargs['parallel_config']
+        draft_tp = draft_parallel_config.tensor_parallel_size
+
+        # TP>1 is not supported currently because DraftModelRunner does
+        # not support TP>1.
+        # TODO: Remove this when TP>1 is supported and #5814 is fixed.
+        if draft_tp != 1:
+            raise ValueError(
+                f"speculative_draft_tensor_parallel_size cannot be "
+                f"any value other than 1 when using MultiProposerWorker. "
+                f"Got {draft_tp} instead.")
+
+    def init_device(self) -> None:
+        for worker in self._workers.values():
+            worker.init_device()
+
+    def load_model(self) -> None:
+        for worker in self._workers.values():
+            worker.load_model()
+
+    def set_include_gpu_probs_tensor(self) -> None:
+        for worker in self._workers.values():
+            if self.is_multi_step_worker_instance(worker):
+                worker.set_include_gpu_probs_tensor()
+
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
+        """No need to implement sampler_output for MultiProposerWorker,
+        as the optional proposers of MultiProposerWorker will use their
+        own Top1Proposers to call their sampler_output functions.
+        """
+        raise NotImplementedError
+
+    def get_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. Ideally, we
+        receive and process all sequences in the same batch with one specified
+        proposer. However, if multiple proposers are specified, we currently
+        use the proposer with the lowest proposal latency for the whole batch.
+        """
+
+        # If we use different proposers for different sequences in the same
+        # batch, all proposers will need to wait for the slowest proposer to
+        # finish on each batch for further scoring. This means proposers
+        # with lower acceptance rates but higher speed, like ngram, would be
+        # dragged down by the slowest proposer at every step while they still
+        # have more steps left to complete. Therefore, a better strategy is to
+        # adaptively use the fastest proposer among all specified proposers
+        # for the current batch. This could be optimized when we have multiple
+        # scorers.
+        if self.scheduling_policy == "divide_and_conquer":
+            return self._get_combined_spec_proposals(
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)
+        else:
+            chosen_proposer = self._get_proposer_for_this_step(
+                execute_model_req, scheduling_policy=self.scheduling_policy)
+
+            return self._workers[chosen_proposer].get_spec_proposals(
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Perform speculative decoding on the input batch.
+        """
+
+        # To perform KV operations, the 'non_driver_ranks' of SpecDecodeWorker
+        # might call this function with execute_model_req set to None many
+        # times.
+        if execute_model_req is None:
+            return []
+
+        # Currently, if one seq_group needs to run execute_model through
+        # MultiStepWorker, all seq_groups in the same batch have to run
+        # execute_model together. We have not found a good way to avoid this.
+        proposer: str = '[ngram]'
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+        valid_proposers = list(self._workers.keys())
+        for _, seq in enumerate(seq_group_metadata_list):
+            sd_params = seq.spec_decode_params
+            if sd_params is not None:
+                proposer = sd_params.get_proposer()
+                if proposer not in valid_proposers:
+                    logger.info(
+                        "proposer_name must be in %s, or set to None. "
+                        "Got '%s' instead. Falling back to '[ngram]'.",
+                        valid_proposers, proposer)
+                    proposer = '[ngram]'
+                    sd_params.set_proposer(proposer)
+                if self.is_multi_step_worker_instance(self._workers[proposer]):
+                    break
+        else:
+            return []
+
+        return self._workers[proposer].execute_model(execute_model_req)
+
+    def get_cache_block_size_bytes(self) -> int:
+        for worker in self._workers.values():
+            if self.is_multi_step_worker_instance(worker):
+                return worker.get_cache_block_size_bytes()
+
+        return 0
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        for worker in self._workers.values():
+            if self.is_multi_step_worker_instance(worker):
+                return worker.determine_num_available_blocks()
+
+        return -1, -1
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        for worker in self._workers.values():
+            if self.is_multi_step_worker_instance(worker):
+                worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
+        return
+
+    def _get_proposer_for_this_step(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest],
+        scheduling_policy: Optional[str] = "proposal_latency",
+    ) -> str:
+        """Get the current proposer for the given sequence batch according to
+        the required scheduling_policy.
+        """
+        chosen_proposer = '[ngram]'
+        if execute_model_req is None:
+            return chosen_proposer
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+        valid_proposers = list(self._workers.keys())
+
+        if scheduling_policy == "proposal_latency":
+            for _, seq in enumerate(seq_group_metadata_list):
+                sd_params = seq.spec_decode_params
+                if sd_params:
+                    proposer = sd_params.get_proposer()
+                    if proposer not in valid_proposers:
+                        continue
+                    else:
+                        chosen_proposer = proposer
+                        # Since MultiProposerWorker only supports Ngram as the
+                        # backup proposer currently, we should use Ngram for
+                        # the whole batch if any seq_group specifies it.
+                        # TODO: Refactor this when flexible backup speculative
+                        # model choices and latency metrics are supported.
+                        if chosen_proposer == '[ngram]':
+                            break
+
+        elif scheduling_policy == "proposal_quality":
+            # TODO: Use SpecDecodeWorkerMetrics to select the proposer with the
+            # best draft_acceptance_rate dynamically.
+            raise NotImplementedError(
+                f"scheduling_policy: '{scheduling_policy}' has not been "
+                f"implemented yet.")
+        elif scheduling_policy == "given_priority":
+            # TODO: Select the proposer according to a given priority order
+            raise NotImplementedError(
+                f"scheduling_policy: '{scheduling_policy}' has not been "
+                f"implemented yet.")
+
+        else:
+            raise ValueError(
+                f"Invalid scheduling_policy: '{scheduling_policy}'.")
+
+        return chosen_proposer
+
+    def _get_combined_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. This method
+        uses multiple speculative proposers to generate speculations and
+        returns the combined results.
+        """
+
+        proposer_requests: Dict[str, List[SequenceGroupMetadata]] = {}
+        original_indices: Dict[str, List[int]] = {}
+        valid_proposers = list(self._workers.keys())
+
+        # Split batch by proposer
+        for idx, seq in enumerate(execute_model_req.seq_group_metadata_list):
+            sd_params = seq.spec_decode_params
+            if sd_params:
+                proposer = sd_params.get_proposer()
+                if proposer not in valid_proposers:
+                    # Got an unknown proposer; use '[ngram]' as the default.
+                    proposer = '[ngram]'
+            if proposer not in proposer_requests:
+                proposer_requests[proposer] = []
+                original_indices[proposer] = []
+            proposer_requests[proposer].append(seq)
+            original_indices[proposer].append(idx)
+
+        all_proposals: Dict[str, SpeculativeProposals] = {}
+
+        # Although we use a ThreadPoolExecutor to run get_spec_proposals
+        # concurrently for now, we still need to wait for the slowest proposer
+        # to finish on each batch for further scoring.
+        # TODO: Fix this when there are multiple scorer instances available
+        # for scoring.
+        with ThreadPoolExecutor() as executor:
+            futures = {
+                executor.submit(self._workers[proposer].get_spec_proposals,
+                                execute_model_req.clone(sq_list),
+                                seq_ids_with_bonus_token_in_last_step):
+                proposer
+                for proposer, sq_list in proposer_requests.items()
+                if len(sq_list) != 0
+            }
+
+            for future in futures:
+                proposer = futures[future]
+                all_proposals[proposer] = future.result()
+
+        seq_group_metadata_length = len(
+            execute_model_req.seq_group_metadata_list)
+        merged_token_ids = [None] * seq_group_metadata_length
+        merged_probs = [None] * seq_group_metadata_length
+        merged_lens = [None] * seq_group_metadata_length
+
+        # Combine and restore the original order of the proposals
+        for proposer, indices in original_indices.items():
+            proposals = all_proposals[proposer]
+            if len(indices) != 0:
+                for i, idx in enumerate(indices):
+                    merged_token_ids[idx] = proposals.proposal_token_ids[i]
+                    merged_probs[idx] = proposals.proposal_probs[i]
+                    merged_lens[idx] = proposals.proposal_lens[i]
+
+        combined_proposals = SpeculativeProposals(
+            proposal_token_ids=torch.stack(merged_token_ids),
+            proposal_probs=torch.stack(merged_probs),
+            proposal_lens=torch.stack(merged_lens))
+        return combined_proposals
+
+    def is_multi_step_worker_instance(self, obj: ProposerWorkerBase) -> bool:
+        if isinstance(obj, MultiStepWorker):
+            return True
+        elif isinstance(obj, SmallerTpProposerWorker):
+            if hasattr(obj, '_worker'):
+                return self.is_multi_step_worker_instance(obj._worker)
+            else:
+                return False
+        else:
+            return False
diff --git a/vllm/spec_decode/spec_decode_params.py b/vllm/spec_decode/spec_decode_params.py
new file mode 100644
index 0000000000000..12ef52f6ab414
--- /dev/null
+++ b/vllm/spec_decode/spec_decode_params.py
@@ -0,0 +1,38 @@
+"""Parameters for speculative decoding."""
+import copy
+
+
+class SpecDecodeParams:
+    """
+    Parameters for Speculative Decoding choices and future features.
+
+    Args:
+        proposer_name: Name of the proposer to be used by SpecDecodeWorker.
+ """ + + def __init__( + self, + proposer_name: str, + ) -> None: + self.proposer_name = proposer_name + self._verify_args() + + def _verify_args(self) -> None: + if not isinstance(self.proposer_name, str) or not self.proposer_name: + raise ValueError("proposer_name (a non-empty string) must be " + "provided.") + + def clone(self) -> "SpecDecodeParams": + return copy.deepcopy(self) + + def get_proposer(self) -> str: + return self.proposer_name + + def set_proposer(self, proposer_name: str) -> None: + if not isinstance(proposer_name, str) or not proposer_name: + raise ValueError("proposer_name (a non-empty string) must be " + "provided.") + self.proposer_name = proposer_name + + def __repr__(self) -> str: + return (f"SpecDecodeParams(proposer_name={self.proposer_name})") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 2689802161987..eb18ad3b4e05e 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -30,6 +30,7 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker from vllm.spec_decode.mqa_scorer import MQAScorer +from vllm.spec_decode.multi_proposer_worker import MultiProposerWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase @@ -158,7 +159,7 @@ def create_worker( draft_model_config = draft_worker_kwargs["vllm_config"].model_config draft_parallel_config: ParallelConfig = draft_worker_kwargs[ 'vllm_config'].parallel_config - if ngram_prompt_lookup_max > 0: + if draft_model_config.model == '[ngram]': draft_worker_kwargs[ "device_type"] = scorer_worker.device_config.device.type proposer_worker = NGramWorker(**draft_worker_kwargs) @@ -188,6 +189,24 @@ def create_worker( proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker, draft_tp, target_tp) + worker_list: Dict[str, ProposerWorkerBase] = {} + worker_list[draft_model_config.model] = proposer_worker + + # Currently, MultiProposerWorker is designed to support NGram + # proposer as a backup to pair up with another slower but more + # accurate proposer. If NGramWorker is not configured, then we do + # not need MultiProposerWorker at this moment. More flexible + # choices will be added in the future. + if ngram_prompt_lookup_max > 0: + backup_proposer_worker = NGramWorker(**draft_worker_kwargs) + backup_proposer_worker.set_ngram_window_size( + ngram_prompt_lookup_min, ngram_prompt_lookup_max) + worker_list['[ngram]'] = backup_proposer_worker + + if len(worker_list.keys()) > 1: + proposer_worker = MultiProposerWorker(**draft_worker_kwargs, + worker_list=worker_list) + logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker))