diff --git a/CMakeLists.txt b/CMakeLists.txt
index ced73ca03bfbc..335623bd2677d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/fp8/common.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
+  "csrc/prepare_inputs/advance_step.cu"
   "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
diff --git a/csrc/ops.h b/csrc/ops.h
index f9feb3deff5e4..1e94a9f45ef08 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_quick(torch::Tensor& out, torch::Tensor& input);
 
+void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
+                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
+                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu
new file mode 100644
index 0000000000000..0e537ddd6c4cd
--- /dev/null
+++ b/csrc/prepare_inputs/advance_step.cu
@@ -0,0 +1,131 @@
+/*
+ * The goal of this GPU kernel is to advance input tensors on the GPU directly
+ * PR: https://github.com/vllm-project/vllm/pull/6338
+ * Current restrictions:
+ *     1. Specialized for DraftModelRunner
+ *     2. Supports flash_attn only
+ */
+
+#include "advance_step.cuh"
+
+namespace prepare_inputs {
+
+//
+template <int const num_threads>
+__global__ void advance_step_kernel(int num_seqs, int num_queries,
+                                    int block_size, long* input_tokens_ptr,
+                                    long const* sampled_token_ids_ptr,
+                                    long* input_positions_ptr,
+                                    int* seq_lens_ptr, long* slot_mapping_ptr,
+                                    int const* block_tables_ptr,
+                                    int64_t const block_tables_stride) {
+  int num_query_blocks = div_ceil(num_queries, num_threads);
+
+  if (blockIdx.x >= num_query_blocks) {
+    return;
+  }
+
+  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
+
+  if (cur_query_id >= num_queries) {
+    return;
+  }
+
+  // Update input_tokens
+  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+
+  int seq_len = seq_lens_ptr[cur_query_id];
+  int next_seq_len = seq_len + 1;
+  int next_input_pos = next_seq_len - 1;
+
+  // Update seq_lens
+  seq_lens_ptr[cur_query_id] = next_seq_len;
+  // Update input_positions
+  input_positions_ptr[cur_query_id] = next_input_pos;
+
+  int const* seq_block_tables_ptr =
+      block_tables_ptr + block_tables_stride * cur_query_id;
+
+  int block_index = next_input_pos / block_size;
+  int block_offset = next_input_pos % block_size;
+
+  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
+  // Update slot_mapping
+  slot_mapping_ptr[cur_query_id] = slot_num;
+}
+
+inline void verify_tensor(std::string const& name, torch::Tensor& t,
+                          int64_t const size_0, int64_t const size_1,
+                          c10::ScalarType const type) {
+  bool size_0_cond = true;
+  if (size_0 != -1) {
+    size_0_cond = t.size(0) == size_0;
+  }
+
+  bool size_1_cond = true;
+  if (size_1 != -1) {
+    size_1_cond = t.size(1) == size_1;
+  }
+
+  bool is_contiguous = t.is_contiguous();
+  bool same_type = t.dtype() == type;
+
+  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
+  if (!pass) {
+    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
+                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
+                " is not as expected: shape = [", size_0, ", ", size_1,
+                "], type = ", type);
+  }
+}
+
+void advance_step(int num_seqs, int num_queries, int block_size,
+                  torch::Tensor& input_tokens,       // type: long
+                  torch::Tensor& sampled_token_ids,  // type: long
+                  torch::Tensor& input_positions,    // type: long
+                  torch::Tensor& seq_lens,           // type: int
+                  torch::Tensor& slot_mapping,       // type: long
+                  torch::Tensor& block_tables) {     // type: int
+
+  if (logging) {
+    printf("advance_step:\n");
+    printf("  num_seqs = %d\n", num_seqs);
+    printf("  num_queries = %d\n", num_queries);
+    printf("  block_size = %d\n", block_size);
+  }
+  // Verify all tensors
+  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
+                at::kLong);
+  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
+  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+  int dev = sampled_token_ids.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+
+  int blocks;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+
+  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
+      num_seqs, num_queries, block_size,
+      reinterpret_cast<long*>(input_tokens.data_ptr()),
+      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
+      reinterpret_cast<long*>(input_positions.data_ptr()),
+      reinterpret_cast<int*>(seq_lens.data_ptr()),
+      reinterpret_cast<long*>(slot_mapping.data_ptr()),
+      reinterpret_cast<int const*>(block_tables.data_ptr()),
+      block_tables.stride(0));
+}
+
+}  // namespace prepare_inputs
+
+void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
+                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
+                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
+  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
+                               sampled_token_ids, input_positions, seq_lens,
+                               slot_mapping, block_tables);
+}
\ No newline at end of file
diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh
new file mode 100644
index 0000000000000..f21574681b1ab
--- /dev/null
+++ b/csrc/prepare_inputs/advance_step.cuh
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <iostream>
+
+namespace prepare_inputs {
+
+static constexpr int max_threads = 256;
+static constexpr bool logging = false;
+
+constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
+
+}  // namespace prepare_inputs
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 9dc7cefc404ca..ff9875e0e17a3 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
 
+  // prepare_inputs advance_step
+  ops.def("advance_step", &advance_step);
+  ops.impl("advance_step", torch::kCUDA, &advance_step);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
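For readers who do not want to parse the CUDA, the per-query update performed by advance_step_kernel can be written as a plain-Python reference. This is an illustrative CPU sketch only (the real work happens in the kernel above); tensor shapes and dtypes follow the verify_tensor checks.

```python
import torch


def advance_step_reference(num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,       # [num_seqs], long
                           sampled_token_ids: torch.Tensor,  # [num_queries, 1], long
                           input_positions: torch.Tensor,    # [num_seqs], long
                           seq_lens: torch.Tensor,           # [num_seqs], int
                           slot_mapping: torch.Tensor,       # [num_seqs], long
                           block_tables: torch.Tensor) -> None:  # [num_seqs, max_blocks], int
    """CPU reference for the per-query work done by advance_step_kernel."""
    for q in range(num_queries):
        # Feed the token sampled at step t as the input token of step t+1.
        input_tokens[q] = sampled_token_ids[q, 0]

        next_seq_len = int(seq_lens[q]) + 1
        next_input_pos = next_seq_len - 1

        seq_lens[q] = next_seq_len
        input_positions[q] = next_input_pos

        # Map the new position to a physical KV-cache slot via the block table.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        slot_mapping[q] = (int(block_tables[q, block_index]) * block_size +
                           block_offset)
```

Only the first num_queries entries are written; any trailing rows up to num_seqs exist solely as CUDA-graph padding and are left untouched.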
ops.def( diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 34a6c9a393a58..da72f6d503c11 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -227,6 +227,7 @@ def get_output_from_llm_generator( maybe_assert_ngram_worker(llm) outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] tokens = [output.outputs[0].text for output in outputs] diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 9832d4f267e8a..442e40f07f0bb 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k(): assert proposals.proposal_lens.tolist() == [ k for _ in range(expected_num_proposal_seqs - 1) ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k] + + +@torch.inference_mode() +def test_use_draft_model_runner_advance_step(): + """Verify that draft model runner triggers advance step + when applicable. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + k = 5 + batch_size = 32 + block_size = 32 + num_gpu_blocks = 2048 // block_size + worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + + # Mock "_gpu_advance_step" to raise an exception when called. + exception_secret = "artificial stop" + worker.model_runner._gpu_advance_step = MagicMock() + worker.model_runner._gpu_advance_step.side_effect = ValueError( + exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + + # Fallback (should not call) when num_steps=1. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + num_steps=1) + worker.execute_model(execute_model_req=execute_model_req) + + # Expect exception if _gpu_advance_step is called. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + num_steps=k) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + call_args_list = worker.model_runner._gpu_advance_step.call_args_list + assert len(call_args_list) == 1 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4ca67224a91b8..143957f7b65f0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def advance_step(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, seq_lens: torch.Tensor, + slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return torch.ops._C.advance_step(num_seqs, num_queries, block_size, + input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, + block_tables) + + # quantization ops # awq def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6d00ea64f7cb8..5c376797a054f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -47,6 +47,32 @@ def __init__(self): # speculative decoding. 
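Given the verify_tensor constraints above and the new _custom_ops.advance_step wrapper, a minimal smoke test of the op could look like the following. Sizes and token values are arbitrary, and it assumes a CUDA build that compiled csrc/prepare_inputs/advance_step.cu; it is not part of the PR's test suite.

```python
import torch
from vllm import _custom_ops as ops

num_seqs, num_queries, block_size, max_blocks = 8, 8, 16, 4
device = "cuda"

input_tokens = torch.zeros(num_seqs, dtype=torch.long, device=device)
sampled_token_ids = torch.randint(0, 32000, (num_queries, 1),
                                  dtype=torch.long, device=device)
input_positions = torch.zeros(num_seqs, dtype=torch.long, device=device)
seq_lens = torch.full((num_seqs,), 5, dtype=torch.int32, device=device)
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device=device)
block_tables = torch.arange(num_seqs * max_blocks, dtype=torch.int32,
                            device=device).view(num_seqs, max_blocks)

ops.advance_step(num_seqs=num_seqs, num_queries=num_queries,
                 block_size=block_size, input_tokens=input_tokens,
                 sampled_token_ids=sampled_token_ids,
                 input_positions=input_positions, seq_lens=seq_lens,
                 slot_mapping=slot_mapping, block_tables=block_tables)

assert int(seq_lens[0]) == 6  # every query advanced by one token on the GPU
```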
self.include_gpu_probs_tensor = False + def _init_sampling_tensors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + """The goal here is to reuse sampling tensors between similar decode + runs. This is possible because sampling logic does not change between + decodes of the same sequences. + """ + _, vocab_size = logits.shape + + # First free any existing stored sampling tensors. + # This is necessary because some sampling tensors may + # have pinned memory. + self._sampling_tensors = None + + # Initialize new sampling tensors + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + self._sampling_tensors = sampling_tensors + self._do_penalties = do_penalties + self._do_top_p_top_k = do_top_p_top_k + self._do_min_p = do_min_p + def forward( self, logits: torch.Tensor, @@ -60,12 +86,23 @@ def forward( assert logits is not None _, vocab_size = logits.shape - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - # Prepare sampling tensors with pinned memory to avoid blocking. - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) + if not sampling_metadata.reuse_sampling_tensors: + self._init_sampling_tensors(logits, sampling_metadata) + elif self._do_penalties: + # In this case, the sampling tensors logic depends on + # "output_tokens" of a sequence. As a result, we cannot + # reuse sampling tensors, since "output_tokens" changes + # between decode runs. + self._init_sampling_tensors(logits, sampling_metadata) + + assert self._sampling_tensors is not None + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + do_min_p = self._do_min_p + + logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Apply presence and frequency penalties. if do_penalties: @@ -77,7 +114,7 @@ def forward( # Apply temperature scaling. # Use in-place division to avoid creating a new tensor. - logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) if do_top_p_top_k: logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, @@ -109,13 +146,19 @@ def forward( on_device_tensors = None # Get the logprobs query results. - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors) + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, sampling_metadata, sample_results) + + return _build_sampler_output( + sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) @property def _should_modify_greedy_probs_inplace(self) -> bool: @@ -535,24 +578,29 @@ def _sample_with_torch( # GPU<->CPU sync happens in the loop below. # This also converts the sample output to Python objects. 
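The caching contract added to Sampler.forward above boils down to a small decision: rebuild the sampling tensors unless the metadata asks for reuse, and always rebuild when penalties are active, because penalties read output_tokens, which change on every decode step. A schematic distillation follows; the _SamplerCache name is invented for illustration and simply stands in for the attributes stored on the Sampler.

```python
from typing import Callable, Optional, Tuple


class _SamplerCache:
    """Invented stand-in for the tensors the Sampler keeps between steps."""

    def __init__(self) -> None:
        self.sampling_tensors: Optional[object] = None
        self.do_penalties: bool = False

    def get(self, reuse_sampling_tensors: bool,
            build: Callable[[], Tuple[object, bool]]) -> Tuple[object, bool]:
        if not reuse_sampling_tensors or self.do_penalties:
            # Drop the old tensors first: they may hold pinned host memory.
            self.sampling_tensors = None
            self.sampling_tensors, self.do_penalties = build()
        return self.sampling_tensors, self.do_penalties
```

The switch from temperatures.unsqueeze_(dim=1) to the out-of-place unsqueeze(dim=1) fits the same contract: an in-place unsqueeze would mutate the cached tensor and leave it in the wrong shape for the next reused step.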
- for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) + if not sampling_metadata.skip_sampler_cpu_output: + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, + SamplingType.RANDOM_SEED): + sample_results = _random_sample( + seq_groups, multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + sample_results = [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + else: + sample_results = [] - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] return sample_results, sampled_token_ids_tensor @@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( sample_results: SampleResultType, sampling_metadata: SamplingMetadata, - prompt_logprobs: List[Optional[PromptLogprobs]], - sample_logprobs: List[SampleLogprobs], + prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], + sample_logprobs: Optional[List[SampleLogprobs]], on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + skip_sampler_cpu_output: bool = False, ) -> SamplerOutput: """Construct Python objects with the output of sampling. @@ -1010,22 +1059,26 @@ def _build_sampler_output( allows post-processing without copies to CPU/serialization, e.g. in speculative decoding rejection sampling. 
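With skip_sampler_cpu_output set, _sample_with_torch returns an empty sample_results list and _build_sampler_output emits no per-sequence Python objects; the only per-step result is the on-device sampled_token_ids tensor. A hedged sketch of what a downstream consumer relies on (attribute names mirror the SamplerOutput fields used in this diff; the helper itself is illustrative):

```python
import torch


def gpu_only_next_tokens(sampler_output) -> torch.Tensor:
    """Illustrative consumer of a SamplerOutput built with
    skip_sampler_cpu_output=True."""
    # CPU serialization was skipped, so there are no per-sequence outputs...
    assert len(sampler_output.outputs) == 0
    # ...but the sampled ids stayed on the GPU, shaped [num_queries, 1],
    # which is exactly what ops.advance_step consumes as sampled_token_ids.
    assert sampler_output.sampled_token_ids is not None
    return sampler_output.sampled_token_ids
```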
""" - sampler_output: List[CompletionSequenceGroupOutput] = [] - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: List[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip(parent_ids, - next_token_ids, - group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + if not skip_sampler_cpu_output: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, + sample_results, prompt_logprobs, + sample_logprobs): + seq_ids = seq_group.seq_ids + next_token_ids, parent_ids = sample_result + seq_outputs: List[SequenceOutput] = [] + for parent_id, next_token_id, logprobs in zip( + parent_ids, next_token_ids, group_sample_logprobs): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + logprobs)) + sampler_output.append( + CompletionSequenceGroupOutput(seq_outputs, + group_prompt_logprobs)) # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index c346cd0562867..29b077cf6d912 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -87,6 +87,12 @@ def sample(logits): The first tuple is [1, 2] (sampled index within original logit), and the second tuple is [0, 1] (sampled index within pruned logit). num_prompts: Number of prompt sequence groups in seq_groups. + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + serialization of token outputs. + reuse_sampling_tensors: Indicates if we want to reuse sampling + tensors that are part of the sampler forward pass. Currently, + it is mainly used for multi-step decode. 
+ """ def __init__( @@ -95,11 +101,15 @@ def __init__( selected_token_indices: torch.Tensor, categorized_sample_indices: Dict[SamplingType, torch.Tensor], num_prompts: int, + skip_sampler_cpu_output: bool = False, + reuse_sampling_tensors: bool = False, ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices self.num_prompts = num_prompts + self.skip_sampler_cpu_output = skip_sampler_cpu_output + self.reuse_sampling_tensors = reuse_sampling_tensors @staticmethod def prepare( diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 90bba96ee8acb..3cb7ec58da4c1 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,17 +2,22 @@ import torch +from vllm import _custom_ops as ops +from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput) from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) logger = init_logger(__name__) +debug_advance_input = False +enable_gpu_advance_step = True + class TP1DraftModelRunner(ModelRunner): """Specialized model runner for speculative decoding draft model. @@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner): we could get rid of most CPU-GPU synchronization and data transfer overheads by keeping model input and output tensors on GPU all the time. - This runner is still under development so there's no performance gain - at this moment. Currently we adopt a temporary solution that caches the - seq_group_metadata_list for multi-step execution, so that we can - leverage existing prepare_model_input to be compatible with the current - execution flow, but we plan to remove this cache and avoid calling - prepare_model_input in execute_model at all. - - The detail development plan includes: - 1. Use "update_model_input" to update existing model_input without - creating a new one. - 2. Improve the performance of "update_model_input" with a GPU kernel. - 3. Support TP > 1 (this requires some designs because we do not expect + TODOs: + 1. Currently supports only flash-attn, add support for other attn_backends. + 2. Support TP > 1 (this requires some designs because we do not expect any broadcasting inside execute_model). """ @@ -71,51 +67,156 @@ def __init__( return_hidden_states=return_hidden_states, ) - # TODO: Remove this cache when we are able to update model_input - # directly in advance_step. - self.cached_seq_group_metadata_list: Optional[ - List[SequenceGroupMetadata]] = None + def _update_flash_attn_metadata(self, attn_metadata, num_seqs, + num_queries): + assert isinstance(attn_metadata, FlashAttentionMetadata) - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithSamplingMetadata: - """A temporary solution that caches the seq_group_metadata_list - for multi-step execution. - TODO: In-place update model_input and remove this function. 
- """ - self.cached_seq_group_metadata_list = seq_group_metadata_list - return super().prepare_model_input( - seq_group_metadata_list, - finished_requests_ids=finished_requests_ids) + if num_seqs != num_queries: + assert num_seqs > num_queries + assert attn_metadata.use_cuda_graph + + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + assert attn_metadata.num_decode_tokens == num_seqs + assert attn_metadata.slot_mapping.shape == (num_seqs, ) + + assert len(attn_metadata.seq_lens) == num_seqs + assert attn_metadata.seq_lens_tensor.shape == (num_seqs, ) + assert attn_metadata.max_query_len == 1 + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens) + + assert attn_metadata.query_start_loc.shape == (num_queries + 1, ) + assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, ) + + assert attn_metadata.context_lens_tensor.shape == (num_queries, ) + + assert attn_metadata.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + attn_metadata.seq_lens[i] += 1 + attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens) - def update_model_input( + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _gpu_advance_step( self, model_input: ModelInputForGPUWithSamplingMetadata, last_output: SamplerOutput ) -> ModelInputForGPUWithSamplingMetadata: - """Prepare the model inputs for the next step. - TODO: In-place update model_input instead of calling - prepare_model_input. 
+ # Currently, we expect "decode mode" only + assert not model_input.is_prompt + + # Get num_seqs + num_seqs = len(model_input.seq_lens) + num_queries = len(model_input.query_lens) + + # Get output tokens GPU tensor + sampled_token_ids = last_output.sampled_token_ids + assert sampled_token_ids is not None + + # Update attn_metadata + attn_metadata = model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + # Update sampling_metadata + sampling_metadata = model_input.sampling_metadata + self._update_sampling_metadata(sampling_metadata, num_seqs, + num_queries) + + # Create new input + new_model_input = self._model_input_cls( + input_tokens=model_input.input_tokens, + input_positions=model_input.input_positions, + attn_metadata=attn_metadata, + seq_lens=attn_metadata.seq_lens, + query_lens=model_input.query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + sampling_metadata=model_input.sampling_metadata, + is_prompt=False, + ) + + # Ensure we skip CPU samples + assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True + # We can reuse sampling tensors since every decode iteration is the same + new_model_input.sampling_metadata.reuse_sampling_tensors = True + + if debug_advance_input: + logger.debug("NEW INPUT: ") + logger.debug(" input_tokens = %s", new_model_input.input_tokens) + logger.debug(" input_positions = %s", + new_model_input.input_positions) + logger.debug(" seq_lens = %d", new_model_input.seq_lens) + logger.debug(" query_lens = %d", new_model_input.query_lens) + logger.debug(" attn_metadata:") + logger.debug(" seq_lens_tensor: %s", + attn_metadata.seq_lens_tensor) + logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping) + logger.debug(" block_tables: %s", attn_metadata.block_tables) + + return new_model_input + + def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): + """Determines if draft_model_runner GPU multi-step can be used. + Currently required conditions are: + 1. Only decodes + 2. Only flash-attn + 3. No LORA + 4. No prompt_adapter_config """ + if not enable_gpu_advance_step: + return False - # Append the output token to the sequence data. 
- assert self.cached_seq_group_metadata_list is not None - for seq_group_metadata, sequence_group_outputs in zip( - self.cached_seq_group_metadata_list, last_output.outputs): - seq_group_metadata.is_prompt = False + # We allow multi-step GPU only in decode mode + for seq_group in execute_model_req.seq_group_metadata_list: + if seq_group.is_prompt: + return False - for seq_output in sequence_group_outputs.samples: - seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + # TODO: Add support for other attn backends + if self.attn_backend.get_name() != "flash-attn": + return False - token_id = seq_output.output_token - token_logprob = seq_output.logprobs[token_id] + # TODO: Add support for LORA + if self.lora_config: + return False - seq.append_token_id(token_id, token_logprob.logprob) - seq.update_num_computed_tokens(1) + # TODO: Add soft-tuning prompt adapter support + if self.prompt_adapter_config: + return False - return self.prepare_model_input(self.cached_seq_group_metadata_list) + return True @torch.inference_mode() def execute_model( @@ -125,42 +226,86 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: - # Since we do not broadcast data inside execute_model anymore, - # we need to figure out the best way to support TP > 1 in this - # case, because we will at least need to broadcast the sampled - # tokens to all workers. - if not self.is_driver_worker: - raise ValueError("TP1DraftModelRunner only supports TP=1.") + """Executes num_steps forward passes with advacement of input tensors + on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) + Optimizations used: + 1. Input tensors are updated on the GPU directly + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + them since we do batch expansion later that uses GPU outputs) + 3. Reuses sampling tensors (since we run only decodes and they have + a repeating sampling logic) + """ - if self.prompt_adapter_config: - assert model_input.prompt_adapter_requests is not None - assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters( - model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) + # When num_steps == 1, we execute the fallback here for the GPU + # advance_step, which runs prepare_inputs on CPU and for each spec + # iteration invokes this function only once + # (Look at multi-step-worker code) + is_fallback = num_steps == 1 + if not is_fallback: + # Since we do not broadcast data inside execute_model anymore, + # we need to figure out the best way to support TP > 1 in this + # case, because we will at least need to broadcast the sampled + # tokens to all workers. 
+ if not self.is_driver_worker: + raise ValueError("TP1DraftModelRunner only supports TP=1.") + + # Sanity + if self.lora_config is not None: + raise ValueError("TP1DraftModelRunner has no support for LORA") + if self.prompt_adapter_config is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "prompt_adapter_config") + if model_input.multi_modal_kwargs: + raise ValueError( + "TP1DraftModelRunner has no support for multi_modal_kwargs" + ) + else: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + # Detect exec mode + assert model_input.attn_metadata is not None + use_cuda_graph = False + if model_input.attn_metadata.num_prefills > 0: + # In this case, execute_model(..) was called directly + if num_steps > 1: + raise ValueError( + "execute_model(..) of draft_model_runner can be called " + "directly only with a single-step prefill") + else: + # We can skip CPU samples for spec token generation. + # (We do allow CPU samples for num_steps == 1 to support the + # fallback case, where supports_gpu_multi_step(..) does not pass) + model_input.sampling_metadata.skip_sampler_cpu_output = ( + not is_fallback) + + # Attn attr defines if we use cuda graphs + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + + # Get model + if use_cuda_graph: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = (self.graph_runners[model_input.virtual_engine] + [graph_batch_size]) + else: + model_executable = self.model - virtual_engine = model_input.virtual_engine outputs: List[SamplerOutput] = [] for step in range(num_steps): - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[virtual_engine][graph_batch_size]) - else: - model_executable = self.model - multi_modal_kwargs = model_input.multi_modal_kwargs or {} + + # Run model hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -181,8 +326,8 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, )) - # Prepare the inputs for the next step. + # Prepare inputs for the next step if step != num_steps - 1: - model_input = self.update_model_input(model_input, outputs[-1]) + model_input = self._gpu_advance_step(model_input, outputs[-1]) return outputs diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 11e99882e3f0b..91689324557b5 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -67,14 +67,23 @@ def sampler_output( expanded_request, indices_of_seq_with_bonus_tokens =\ self._expand_execute_model_request( execute_model_req, seq_ids_with_bonus_token_in_last_step) + # Run model sample_len times. 
model_outputs: List[SamplerOutput] = [] - if isinstance(self.model_runner, TP1DraftModelRunner): + if isinstance( + self.model_runner, TP1DraftModelRunner + ) and self.model_runner.supports_gpu_multi_step(expanded_request): + # Here we run the draft_model_runner with multi-step prepare + # on the GPU directly expanded_request.num_steps = sample_len model_outputs = self.execute_model( execute_model_req=expanded_request) else: - # TODO: Remove this branch once DraftModelRunner supports TP>1. + # Here we run multi-step directly, with every step prepared + # on the CPU. + # TODO: Remove this branch once DraftModelRunner supports TP>1 + # and other restrictions that are part of DraftModelRunner's + # supports_gpu_multi_step(..) for _ in range(sample_len): model_output: List[SamplerOutput] = super().execute_model( execute_model_req=expanded_request) @@ -171,7 +180,7 @@ def _filter_model_output( outputs=[ expanded_batch_output.outputs[i] for i in output_indices_to_retain - ], + ] if len(expanded_batch_output.outputs) > 0 else [], sampled_token_probs=( expanded_batch_output. sampled_token_probs[output_indices_to_retain]
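Taken together, TP1DraftModelRunner.execute_model now runs a GPU-resident multi-step loop roughly like the sketch below. This is heavily simplified: method names mirror the diff, everything else is schematic, and the num_steps == 1 fallback still goes through the CPU prepare path.

```python
from typing import List


def run_draft_steps(runner, model_input, kv_caches, num_steps: int) -> List:
    """Schematic of the multi-step draft loop introduced by this PR."""
    outputs: List = []
    for step in range(num_steps):
        hidden_states = runner.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
        )
        logits = runner.model.compute_logits(hidden_states,
                                             model_input.sampling_metadata)
        # With skip_sampler_cpu_output set, this returns GPU tensors only.
        outputs.append(
            runner.model.sample(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata))

        if step != num_steps - 1:
            # Advance input_tokens / input_positions / seq_lens / slot_mapping
            # in place on the GPU instead of re-running prepare_model_input.
            model_input = runner._gpu_advance_step(model_input, outputs[-1])
    return outputs
```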